diff --git a/src/alttext/descengine.py b/src/alttext/descengine.py
index db96623..26f1284 100644
--- a/src/alttext/descengine.py
+++ b/src/alttext/descengine.py
@@ -1,6 +1,9 @@
 from abc import ABC, abstractmethod
 import base64
 import os
+import shutil
+import subprocess
+import uuid
 
 import replicate
 import vertexai
@@ -24,6 +27,7 @@ class DescEngine(ABC):
         pass
 
 
+### IMPLEMENTATIONS
 REPLICATE_MODELS = {
     "blip": "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746",
     "clip_prefix_caption": "rmokady/clip_prefix_caption:9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8",
@@ -34,7 +38,6 @@ REPLICATE_MODELS = {
 }
 
 
-### IMPLEMENTATIONS
 class ReplicateAPI(DescEngine):
     def __init__(self, key: str, model: str = "blip") -> None:
         self.__setKey(key)
@@ -73,6 +76,51 @@ class ReplicateAPI(DescEngine):
         return output
 
 
+class BlipLocal(DescEngine):
+    """Describe images with a local BLIP captioning checkout.
+
+    ``path`` is the directory that holds ``inference.py``; the script is run
+    as a subprocess with that directory as its working directory.
+    """
+
+    def __init__(self, path: str) -> None:
+        self.__setPath(path)
+        return None
+
+    def __setPath(self, path: str) -> str:
+        self.path = path
+        return self.path
+
+    def genDesc(self, imgData: bytes, src: str, context: str = None) -> str:
+        """Caption ``imgData`` by running ``inference.py`` on a scratch folder.
+
+        ``src`` only supplies the image file extension (the text after its
+        last "."); ``context`` is unused. Returns the second comma-separated
+        field of the ``0_captions.txt`` file read back from the scratch folder.
+
+        Raises FileNotFoundError if no captions file exists after the run.
+        """
+        folderName = str(uuid.uuid4())
+        workDir = f"{self.path}/{folderName}"
+        ext = src.split(".")[-1]
+        os.makedirs(workDir)
+        try:
+            with open(f"{workDir}/image.{ext}", "wb") as imgFile:
+                imgFile.write(imgData)
+            # Pass an argument list: a plain command string without shell=True
+            # is treated as one program name on POSIX and mis-splits paths
+            # containing spaces on Windows.
+            cmd = ["python", f"{self.path}/inference.py", "-i", f"./{folderName}"]
+            cmd += ["--batch", "1", "--gpu", "0"]
+            subprocess.call(cmd, cwd=self.path)
+            with open(f"{workDir}/0_captions.txt", "r") as captionFile:
+                desc = captionFile.read()
+        finally:
+            # Always remove the scratch folder, even when inference fails.
+            shutil.rmtree(workDir)
+        return desc.split(",")[1]
+
+
 class GoogleVertexAPI(DescEngine):
     def __init__(self, project_id: str, location: str, gac_path: str) -> None:
         self.project_id = project_id
diff --git a/tests/test.py b/tests/test.py
index 6cfe2cf..8a304ec 100644
--- a/tests/test.py
+++ b/tests/test.py
@@ -24,24 +24,33 @@ HOST1 = "http://127.0.0.1:8001"
 
 
 def testHTML():
     print("TESTING HTML")
+    # alt: alttext.AltTextHTML = alttext.AltTextHTML(
+    #     # descengine.ReplicateAPI(keys.ReplicateEricKey(), "blip"),
+    #     # ocrengine.Tesseract(),
+    #     # langengine.PrivateGPT(HOST1),
+    # )
+
+    # alt: alttext.AltTextHTML = alttext.AltTextHTML(
+    #     descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"),
+    #     options={"version": 1},
+    # )
+
     alt: alttext.AltTextHTML = alttext.AltTextHTML(
-        # descengine.GoogleVertexAPI(
-        #     keys.VertexProject(), keys.VertexRegion(), keys.VertexGAC()
-        # ),
-        descengine.ReplicateAPI(keys.ReplicateEricKey(), "blip"),
-        # ocrengine.Tesseract(),
-        # langengine.PrivateGPT(HOST1),
-        options={"version": 1},
+        descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"),
+        ocrengine.Tesseract(),
+        langengine.PrivateGPT(HOST1),
     )
+
     alt.parseFile(HTML_HUNTING)
     imgs = alt.getAllImgs()
-    # src = imgs[5].attrs["src"]
-    # print(src)
+    src = imgs[4].attrs["src"]
+    print(src)
+    print(alt.genAltText(src))
     # desc = alt.genDesc(alt.getImgData(src), src)
     # print(desc)
-    associations = alt.genAltAssociations(imgs)
-    print(associations)
+    # associations = alt.genAltAssociations(imgs)
+    # print(associations)
 
 
 if __name__ == "__main__":