Merge pull request #9 from EbookFoundation/dev-david

Added blipLocal DescEngine
pull/15/head
XxMistaCruzxX 2024-02-02 13:40:55 -05:00 committed by GitHub
commit bda65d9138
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 48 additions and 12 deletions

View File

@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
import base64
import os
import shutil
import subprocess
import uuid
import replicate
import vertexai
@ -24,6 +27,7 @@ class DescEngine(ABC):
pass
### IMPLEMENTATIONS
REPLICATE_MODELS = {
"blip": "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746",
"clip_prefix_caption": "rmokady/clip_prefix_caption:9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8",
@ -34,7 +38,6 @@ REPLICATE_MODELS = {
}
### IMPLEMENTATIONS
class ReplicateAPI(DescEngine):
def __init__(self, key: str, model: str = "blip") -> None:
self.__setKey(key)
@ -73,6 +76,30 @@ class ReplicateAPI(DescEngine):
return output
class BlipLocal(DescEngine):
    """Description engine that captions an image by running a local script.

    ``path`` is the directory that contains ``inference.py``. Every call to
    ``genDesc`` stages the image in a uniquely named scratch folder created
    under that directory and removes the folder again afterwards.
    """

    def __init__(self, path: str) -> None:
        """Remember the directory that holds ``inference.py``."""
        self.__setPath(path)
        return None

    def __setPath(self, path: str) -> str:
        """Store ``path`` on the instance and return it."""
        self.path = path
        return self.path

    def genDesc(self, imgData: bytes, src: str, context: str = None) -> str:
        """Generate a description for one image via ``inference.py``.

        Args:
            imgData: Raw bytes of the image; written unchanged to disk.
            src: Source name of the image. Only the text after its last
                ``"."`` is used, as the extension of the staged file.
            context: Accepted for interface compatibility; not used here.

        Returns:
            The second comma-separated field of ``0_captions.txt`` as
            written by the inference script.

        Raises:
            FileNotFoundError: If the script produced no ``0_captions.txt``.
            IndexError: If the captions file contains no comma.
        """
        folderName = str(uuid.uuid4())
        workDir = f"{self.path}/{folderName}"
        ext = src.split(".")[-1]
        os.makedirs(workDir)
        try:
            # Close the image file before launching the child process so the
            # bytes are guaranteed to be flushed to disk when it reads them.
            with open(f"{workDir}/image.{ext}", "wb") as imgFile:
                imgFile.write(imgData)

            # An argument list (not a single command string) works on every
            # platform and survives spaces in ``self.path``; a bare string
            # without ``shell=True`` is treated as the program name on POSIX.
            subprocess.call(
                [
                    "python",
                    f"{self.path}/inference.py",
                    "-i",
                    f"./{folderName}",
                    "--batch",
                    "1",
                    "--gpu",
                    "0",
                ],
                cwd=f"{self.path}",
            )

            with open(f"{workDir}/0_captions.txt", "r") as captionFile:
                desc = captionFile.read()
        finally:
            # Always remove the scratch folder, even when inference failed.
            shutil.rmtree(workDir)

        # NOTE(review): assumes the captions file is comma-separated with the
        # caption in the second field -- confirm against inference.py.
        return desc.split(",")[1]
class GoogleVertexAPI(DescEngine):
def __init__(self, project_id: str, location: str, gac_path: str) -> None:
self.project_id = project_id

View File

@ -24,24 +24,33 @@ HOST1 = "http://127.0.0.1:8001"
def testHTML():
print("TESTING HTML")
# alt: alttext.AltTextHTML = alttext.AltTextHTML(
# # descengine.ReplicateAPI(keys.ReplicateEricKey(), "blip"),
# # ocrengine.Tesseract(),
# # langengine.PrivateGPT(HOST1),
# )
# alt: alttext.AltTextHTML = alttext.AltTextHTML(
# descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"),
# options={"version": 1},
# )
alt: alttext.AltTextHTML = alttext.AltTextHTML(
# descengine.GoogleVertexAPI(
# keys.VertexProject(), keys.VertexRegion(), keys.VertexGAC()
# ),
descengine.ReplicateAPI(keys.ReplicateEricKey(), "blip"),
# ocrengine.Tesseract(),
# langengine.PrivateGPT(HOST1),
options={"version": 1},
descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"),
ocrengine.Tesseract(),
langengine.PrivateGPT(HOST1),
)
alt.parseFile(HTML_HUNTING)
imgs = alt.getAllImgs()
# src = imgs[5].attrs["src"]
# print(src)
src = imgs[4].attrs["src"]
print(src)
print(alt.genAltText(src))
# desc = alt.genDesc(alt.getImgData(src), src)
# print(desc)
associations = alt.genAltAssociations(imgs)
print(associations)
# associations = alt.genAltAssociations(imgs)
# print(associations)
if __name__ == "__main__":