Merge pull request #9 from EbookFoundation/dev-david

Added blipLocal DescEngine
pull/15/head
XxMistaCruzxX 2024-02-02 13:40:55 -05:00 committed by GitHub
commit bda65d9138
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 48 additions and 12 deletions

View File

@ -1,6 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import base64 import base64
import os import os
import shutil
import subprocess
import uuid
import replicate import replicate
import vertexai import vertexai
@ -24,6 +27,7 @@ class DescEngine(ABC):
pass pass
### IMPLEMENTATIONS
REPLICATE_MODELS = { REPLICATE_MODELS = {
"blip": "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746", "blip": "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746",
"clip_prefix_caption": "rmokady/clip_prefix_caption:9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8", "clip_prefix_caption": "rmokady/clip_prefix_caption:9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8",
@ -34,7 +38,6 @@ REPLICATE_MODELS = {
} }
### IMPLEMENTATIONS
class ReplicateAPI(DescEngine): class ReplicateAPI(DescEngine):
def __init__(self, key: str, model: str = "blip") -> None: def __init__(self, key: str, model: str = "blip") -> None:
self.__setKey(key) self.__setKey(key)
@ -73,6 +76,30 @@ class ReplicateAPI(DescEngine):
return output return output
class BlipLocal(DescEngine):
def __init__(self, path: str) -> None:
self.__setPath(path)
return None
def __setPath(self, path: str) -> str:
self.path = path
return self.path
def genDesc(self, imgData: bytes, src: str, context: str = None) -> str:
folderName = uuid.uuid4()
ext = src.split(".")[-1]
os.makedirs(f"{self.path}/{folderName}")
open(f"{self.path}/{folderName}/image.{ext}", "wb+").write(imgData)
subprocess.call(
f"python {self.path}/inference.py -i ./{folderName} --batch 1 --gpu 0",
cwd=f"{self.path}",
)
desc = open(f"{self.path}/{folderName}/0_captions.txt", "r").read()
shutil.rmtree(f"{self.path}/{folderName}")
desc = desc.split(",")
return desc[1]
class GoogleVertexAPI(DescEngine): class GoogleVertexAPI(DescEngine):
def __init__(self, project_id: str, location: str, gac_path: str) -> None: def __init__(self, project_id: str, location: str, gac_path: str) -> None:
self.project_id = project_id self.project_id = project_id

View File

@ -24,24 +24,33 @@ HOST1 = "http://127.0.0.1:8001"
def testHTML(): def testHTML():
print("TESTING HTML") print("TESTING HTML")
# alt: alttext.AltTextHTML = alttext.AltTextHTML(
# # descengine.ReplicateAPI(keys.ReplicateEricKey(), "blip"),
# # ocrengine.Tesseract(),
# # langengine.PrivateGPT(HOST1),
# )
# alt: alttext.AltTextHTML = alttext.AltTextHTML(
# descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"),
# options={"version": 1},
# )
alt: alttext.AltTextHTML = alttext.AltTextHTML( alt: alttext.AltTextHTML = alttext.AltTextHTML(
# descengine.GoogleVertexAPI( descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"),
# keys.VertexProject(), keys.VertexRegion(), keys.VertexGAC() ocrengine.Tesseract(),
# ), langengine.PrivateGPT(HOST1),
descengine.ReplicateAPI(keys.ReplicateEricKey(), "blip"),
# ocrengine.Tesseract(),
# langengine.PrivateGPT(HOST1),
options={"version": 1},
) )
alt.parseFile(HTML_HUNTING) alt.parseFile(HTML_HUNTING)
imgs = alt.getAllImgs() imgs = alt.getAllImgs()
# src = imgs[5].attrs["src"] src = imgs[4].attrs["src"]
# print(src) print(src)
print(alt.genAltText(src))
# desc = alt.genDesc(alt.getImgData(src), src) # desc = alt.genDesc(alt.getImgData(src), src)
# print(desc) # print(desc)
associations = alt.genAltAssociations(imgs) # associations = alt.genAltAssociations(imgs)
print(associations) # print(associations)
if __name__ == "__main__": if __name__ == "__main__":