commit
bda65d9138
|
@ -1,6 +1,9 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import uuid
|
||||||
|
|
||||||
import replicate
|
import replicate
|
||||||
import vertexai
|
import vertexai
|
||||||
|
@ -24,6 +27,7 @@ class DescEngine(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
### IMPLEMENTATIONS
|
||||||
REPLICATE_MODELS = {
|
REPLICATE_MODELS = {
|
||||||
"blip": "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746",
|
"blip": "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746",
|
||||||
"clip_prefix_caption": "rmokady/clip_prefix_caption:9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8",
|
"clip_prefix_caption": "rmokady/clip_prefix_caption:9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8",
|
||||||
|
@ -34,7 +38,6 @@ REPLICATE_MODELS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
### IMPLEMENTATIONS
|
|
||||||
class ReplicateAPI(DescEngine):
|
class ReplicateAPI(DescEngine):
|
||||||
def __init__(self, key: str, model: str = "blip") -> None:
|
def __init__(self, key: str, model: str = "blip") -> None:
|
||||||
self.__setKey(key)
|
self.__setKey(key)
|
||||||
|
@ -73,6 +76,30 @@ class ReplicateAPI(DescEngine):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class BlipLocal(DescEngine):
|
||||||
|
def __init__(self, path: str) -> None:
|
||||||
|
self.__setPath(path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __setPath(self, path: str) -> str:
|
||||||
|
self.path = path
|
||||||
|
return self.path
|
||||||
|
|
||||||
|
def genDesc(self, imgData: bytes, src: str, context: str = None) -> str:
|
||||||
|
folderName = uuid.uuid4()
|
||||||
|
ext = src.split(".")[-1]
|
||||||
|
os.makedirs(f"{self.path}/{folderName}")
|
||||||
|
open(f"{self.path}/{folderName}/image.{ext}", "wb+").write(imgData)
|
||||||
|
subprocess.call(
|
||||||
|
f"python {self.path}/inference.py -i ./{folderName} --batch 1 --gpu 0",
|
||||||
|
cwd=f"{self.path}",
|
||||||
|
)
|
||||||
|
desc = open(f"{self.path}/{folderName}/0_captions.txt", "r").read()
|
||||||
|
shutil.rmtree(f"{self.path}/{folderName}")
|
||||||
|
desc = desc.split(",")
|
||||||
|
return desc[1]
|
||||||
|
|
||||||
|
|
||||||
class GoogleVertexAPI(DescEngine):
|
class GoogleVertexAPI(DescEngine):
|
||||||
def __init__(self, project_id: str, location: str, gac_path: str) -> None:
|
def __init__(self, project_id: str, location: str, gac_path: str) -> None:
|
||||||
self.project_id = project_id
|
self.project_id = project_id
|
||||||
|
|
|
@ -24,24 +24,33 @@ HOST1 = "http://127.0.0.1:8001"
|
||||||
def testHTML():
|
def testHTML():
|
||||||
print("TESTING HTML")
|
print("TESTING HTML")
|
||||||
|
|
||||||
|
# alt: alttext.AltTextHTML = alttext.AltTextHTML(
|
||||||
|
# # descengine.ReplicateAPI(keys.ReplicateEricKey(), "blip"),
|
||||||
|
# # ocrengine.Tesseract(),
|
||||||
|
# # langengine.PrivateGPT(HOST1),
|
||||||
|
# )
|
||||||
|
|
||||||
|
# alt: alttext.AltTextHTML = alttext.AltTextHTML(
|
||||||
|
# descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"),
|
||||||
|
# options={"version": 1},
|
||||||
|
# )
|
||||||
|
|
||||||
alt: alttext.AltTextHTML = alttext.AltTextHTML(
|
alt: alttext.AltTextHTML = alttext.AltTextHTML(
|
||||||
# descengine.GoogleVertexAPI(
|
descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"),
|
||||||
# keys.VertexProject(), keys.VertexRegion(), keys.VertexGAC()
|
ocrengine.Tesseract(),
|
||||||
# ),
|
langengine.PrivateGPT(HOST1),
|
||||||
descengine.ReplicateAPI(keys.ReplicateEricKey(), "blip"),
|
|
||||||
# ocrengine.Tesseract(),
|
|
||||||
# langengine.PrivateGPT(HOST1),
|
|
||||||
options={"version": 1},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
alt.parseFile(HTML_HUNTING)
|
alt.parseFile(HTML_HUNTING)
|
||||||
imgs = alt.getAllImgs()
|
imgs = alt.getAllImgs()
|
||||||
# src = imgs[5].attrs["src"]
|
src = imgs[4].attrs["src"]
|
||||||
# print(src)
|
print(src)
|
||||||
|
print(alt.genAltText(src))
|
||||||
|
|
||||||
# desc = alt.genDesc(alt.getImgData(src), src)
|
# desc = alt.genDesc(alt.getImgData(src), src)
|
||||||
# print(desc)
|
# print(desc)
|
||||||
associations = alt.genAltAssociations(imgs)
|
# associations = alt.genAltAssociations(imgs)
|
||||||
print(associations)
|
# print(associations)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue