diff --git a/src/alttext/alttext.py b/src/alttext/alttext.py index 7bec3d7..b96aba5 100644 --- a/src/alttext/alttext.py +++ b/src/alttext/alttext.py @@ -6,9 +6,9 @@ import bs4 import ebooklib from ebooklib import epub -from .descengine import DescEngine -from .ocrengine import OCREngine -from .langengine import LangEngine +from .descengine.descengine import DescEngine +from .ocrengine.ocrengine import OCREngine +from .langengine.langengine import LangEngine DEFOPTIONS = { @@ -523,7 +523,7 @@ class AltTextHTML(AltText): try: text = elem.text.strip() while text == "": - elem = elem.previous_element + elem = elem.next_element text = elem.text.strip() context[1] = text except: @@ -564,7 +564,6 @@ class AltTextHTML(AltText): if self.options["withContext"]: context = self.getContext(self.getImg(src)) desc = self.genDesc(imgdata, src, context) - chars = "" if self.ocrEngine != None: chars = self.genChars(imgdata, src).strip() diff --git a/src/alttext/descengine.py b/src/alttext/descengine.py deleted file mode 100644 index 26f1284..0000000 --- a/src/alttext/descengine.py +++ /dev/null @@ -1,133 +0,0 @@ -from abc import ABC, abstractmethod -import base64 -import os -import shutil -import subprocess -import uuid - -import replicate -import vertexai -from vertexai.vision_models import ImageTextModel, Image - - -### DESCENGINE CLASSES -class DescEngine(ABC): - @abstractmethod - def genDesc(self, imgData: bytes, src: str, context: str = None) -> str: - """Generates description for an image. - - Args: - imgData (bytes): Image data in bytes. - src (str): Source of image. - context (str, optional): Context of image. See getContext in alttext for more information. Defaults to None. - - Returns: - str: _description_ - """ - pass - - -### IMPLEMENTATIONS -REPLICATE_MODELS = { - "blip": "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746", - "clip_prefix_caption": "rmokady/clip_prefix_caption:9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8", - "clip-caption-reward": "j-min/clip-caption-reward:de37751f75135f7ebbe62548e27d6740d5155dfefdf6447db35c9865253d7e06", - "img2prompt": "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5", - "minigpt4": "daanelson/minigpt-4:b96a2f33cc8e4b0aa23eacfce731b9c41a7d9466d9ed4e167375587b54db9423", - "image-captioning-with-visual-attention": "nohamoamary/image-captioning-with-visual-attention:9bb60a6baa58801aa7cd4c4fafc95fcf1531bf59b84962aff5a718f4d1f58986", -} - - -class ReplicateAPI(DescEngine): - def __init__(self, key: str, model: str = "blip") -> None: - self.__setKey(key) - self.__setModel(model) - return None - - def __getModel(self) -> str: - return self.model - - def __setModel(self, modelName: str) -> str: - if modelName not in REPLICATE_MODELS: - raise Exception( - f"{modelName} is not a valid model. Please choose from {list(REPLICATE_MODELS.keys())}" - ) - self.model = REPLICATE_MODELS[modelName] - return self.model - - def __getKey(self) -> str: - return self.key - - def __setKey(self, key: str) -> str: - self.key = key - os.environ["REPLICATE_API_TOKEN"] = key - return self.key - - def genDesc(self, imgData: bytes, src: str, context: str = None) -> str: - base64_utf8_str = base64.b64encode(imgData).decode("utf-8") - model = self.__getModel() - ext = src.split(".")[-1] - prompt = "Create alternative-text for this image." - if context != None: - prompt = f"Create alternative-text for this image given the following context...\n{context}" - - dataurl = f"data:image/{ext};base64,{base64_utf8_str}" - output = replicate.run(model, input={"image": dataurl, "prompt": prompt}) - return output - - -class BlipLocal(DescEngine): - def __init__(self, path: str) -> None: - self.__setPath(path) - return None - - def __setPath(self, path: str) -> str: - self.path = path - return self.path - - def genDesc(self, imgData: bytes, src: str, context: str = None) -> str: - folderName = uuid.uuid4() - ext = src.split(".")[-1] - os.makedirs(f"{self.path}/{folderName}") - open(f"{self.path}/{folderName}/image.{ext}", "wb+").write(imgData) - subprocess.call( - f"python {self.path}/inference.py -i ./{folderName} --batch 1 --gpu 0", - cwd=f"{self.path}", - ) - desc = open(f"{self.path}/{folderName}/0_captions.txt", "r").read() - shutil.rmtree(f"{self.path}/{folderName}") - desc = desc.split(",") - return desc[1] - - -class GoogleVertexAPI(DescEngine): - def __init__(self, project_id: str, location: str, gac_path: str) -> None: - self.project_id = project_id - self.location = location - vertexai.init(project=self.project_id, location=self.location) - - self.gac_path = gac_path - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.gac_path - return None - - def __setProject(self, project_id: str): - self.project_id = project_id - vertexai.init(project=self.project_id, location=self.location) - - def __setLocation(self, location: str): - self.location = location - vertexai.init(project=self.project_id, location=self.location) - - def __setGAC(self, gac_path: str): - self.gac_path = gac_path - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.gac_path - - def genDesc(self, imgData: bytes, src: str, context: str = None) -> str: - model = ImageTextModel.from_pretrained("imagetext@001") - source_image = Image(imgData) - captions = model.get_captions( - image=source_image, - number_of_results=1, - language="en", - ) - return captions[0] diff --git a/src/alttext/descengine/bliplocal.py b/src/alttext/descengine/bliplocal.py new file mode 100644 index 0000000..5f6417d --- /dev/null +++ b/src/alttext/descengine/bliplocal.py @@ -0,0 +1,29 @@ +import os +import shutil +import subprocess +import uuid + +from .descengine import DescEngine + +class BlipLocal(DescEngine): + def __init__(self, path: str) -> None: + self.__setPath(path) + return None + + def __setPath(self, path: str) -> str: + self.path = path + return self.path + + def genDesc(self, imgData: bytes, src: str, context: str = None) -> str: + folderName = uuid.uuid4() + ext = src.split(".")[-1] + os.makedirs(f"{self.path}/{folderName}") + open(f"{self.path}/{folderName}/image.{ext}", "wb+").write(imgData) + subprocess.call( + f"py inference.py -i ./{folderName} --batch 1 --gpu 0", + cwd=f"{self.path}", + ) + desc = open(f"{self.path}/{folderName}/0_captions.txt", "r").read() + shutil.rmtree(f"{self.path}/{folderName}") + desc = desc.split(",") + return desc[1] \ No newline at end of file diff --git a/src/alttext/descengine/descengine.py b/src/alttext/descengine/descengine.py new file mode 100644 index 0000000..94b4f37 --- /dev/null +++ b/src/alttext/descengine/descengine.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod + +### DESCENGINE CLASSES +class DescEngine(ABC): + @abstractmethod + def genDesc(self, imgData: bytes, src: str, context: str = None) -> str: + """Generates description for an image. + + Args: + imgData (bytes): Image data in bytes. + src (str): Source of image. + context (str, optional): Context of image. See getContext in alttext for more information. Defaults to None. + + Returns: + str: _description_ + """ + pass diff --git a/src/alttext/descengine/googlevertexapi.py b/src/alttext/descengine/googlevertexapi.py new file mode 100644 index 0000000..881583d --- /dev/null +++ b/src/alttext/descengine/googlevertexapi.py @@ -0,0 +1,37 @@ +import os +import vertexai +from vertexai.vision_models import ImageTextModel, Image + +from .descengine import DescEngine + +class GoogleVertexAPI(DescEngine): + def __init__(self, project_id: str, location: str, gac_path: str) -> None: + self.project_id = project_id + self.location = location + vertexai.init(project=self.project_id, location=self.location) + + self.gac_path = gac_path + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.gac_path + return None + + def __setProject(self, project_id: str): + self.project_id = project_id + vertexai.init(project=self.project_id, location=self.location) + + def __setLocation(self, location: str): + self.location = location + vertexai.init(project=self.project_id, location=self.location) + + def __setGAC(self, gac_path: str): + self.gac_path = gac_path + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.gac_path + + def genDesc(self, imgData: bytes, src: str, context: str = None) -> str: + model = ImageTextModel.from_pretrained("imagetext@001") + source_image = Image(imgData) + captions = model.get_captions( + image=source_image, + number_of_results=1, + language="en", + ) + return captions[0] \ No newline at end of file diff --git a/src/alttext/descengine/replicateapi.py b/src/alttext/descengine/replicateapi.py new file mode 100644 index 0000000..6483360 --- /dev/null +++ b/src/alttext/descengine/replicateapi.py @@ -0,0 +1,51 @@ +import replicate +import base64 +import os + +from .descengine import DescEngine + +REPLICATE_MODELS = { + "blip": "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746", + "clip_prefix_caption": "rmokady/clip_prefix_caption:9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8", + "clip-caption-reward": "j-min/clip-caption-reward:de37751f75135f7ebbe62548e27d6740d5155dfefdf6447db35c9865253d7e06", + "img2prompt": "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5", + "minigpt4": "daanelson/minigpt-4:b96a2f33cc8e4b0aa23eacfce731b9c41a7d9466d9ed4e167375587b54db9423", + "image-captioning-with-visual-attention": "nohamoamary/image-captioning-with-visual-attention:9bb60a6baa58801aa7cd4c4fafc95fcf1531bf59b84962aff5a718f4d1f58986", +} + +class ReplicateAPI(DescEngine): + def __init__(self, key: str, model: str = "blip") -> None: + self.__setKey(key) + self.__setModel(model) + return None + + def __getModel(self) -> str: + return self.model + + def __setModel(self, modelName: str) -> str: + if modelName not in REPLICATE_MODELS: + raise Exception( + f"{modelName} is not a valid model. Please choose from {list(REPLICATE_MODELS.keys())}" + ) + self.model = REPLICATE_MODELS[modelName] + return self.model + + def __getKey(self) -> str: + return self.key + + def __setKey(self, key: str) -> str: + self.key = key + os.environ["REPLICATE_API_TOKEN"] = key + return self.key + + def genDesc(self, imgData: bytes, src: str, context: str = None) -> str: + base64_utf8_str = base64.b64encode(imgData).decode("utf-8") + model = self.__getModel() + ext = src.split(".")[-1] + prompt = "Create alternative-text for this image." + if context != None: + prompt = f"Create alternative-text for this image given the following context...\n{context}" + + dataurl = f"data:image/{ext};base64,{base64_utf8_str}" + output = replicate.run(model, input={"image": dataurl, "prompt": prompt}) + return output \ No newline at end of file diff --git a/src/alttext/langengine/langengine.py b/src/alttext/langengine/langengine.py new file mode 100644 index 0000000..599fd12 --- /dev/null +++ b/src/alttext/langengine/langengine.py @@ -0,0 +1,102 @@ +from abc import ABC, abstractmethod + +class LangEngine(ABC): + @abstractmethod + def _completion(self, prompt: str) -> str: + """Sends message to language model and returns its response. + + Args: + prompt (str): Prompt to send to language model. + + Returns: + str: Response from language model. + """ + pass + + @abstractmethod + def refineDesc(self, description: str) -> str: + """Refines description of an image. + Used in V1 Dataflow. + + Args: + description (str): Description of an image. + + Returns: + str: Refinement of description. + """ + pass + + @abstractmethod + def refineOCR(self, chars: str) -> str: + """Refines characters found in an image. + Used in V1 Dataflow. + + Args: + chars (str): Characters found in an image. + + Returns: + str: Refinement of characters. + """ + pass + + @abstractmethod + def genPrompt(self, desc: str, chars: str, context: list[str], caption: str) -> str: + """Generates prompt to send to language model in V2 Dataflow. + + Args: + desc (str): Description of an image. + chars (str): Characters found in an image. + context (list[str]): Context of an image. See getContext in alttext for more information. + caption (str): Caption of an image. + + Returns: + str: Prompt to send to language model. + """ + pass + + @abstractmethod + def refineAlt( + self, + desc: str, + chars: str = None, + context: list[str] = None, + caption: str = None, + ) -> str: + """Generates alt-text for an image. + Used in V2 Dataflow. + + Args: + desc (str): Description of an image. + chars (str, optional): Characters found in an image. Defaults to None. + context (list[str], optional): Context of an image. See getContext in alttext for more information. Defaults to None. + caption (str, optional): Caption of an image. Defaults to None. + + Returns: + str: Alt-text for an image. + """ + pass + + @abstractmethod + def ingest(self, filename: str, binary) -> bool: + """Ingests a file into the language model. + + Args: + filename (str): Name of file. + binary (_type_): Data of file. + + Returns: + bool: True if successful. + """ + pass + + @abstractmethod + def degest(self, filename: str) -> bool: + """Removes a file from the language model. + + Args: + filename (str): Name of file. + + Returns: + bool: True if successful. + """ + pass \ No newline at end of file diff --git a/src/alttext/langengine.py b/src/alttext/langengine/privategpt.py similarity index 66% rename from src/alttext/langengine.py rename to src/alttext/langengine/privategpt.py index 3e68470..4da27d4 100644 --- a/src/alttext/langengine.py +++ b/src/alttext/langengine/privategpt.py @@ -1,111 +1,7 @@ -from abc import ABC, abstractmethod import requests +from .langengine import LangEngine -### LANGENGINE CLASSES -class LangEngine(ABC): - @abstractmethod - def _completion(self, prompt: str) -> str: - """Sends message to language model and returns its response. - - Args: - prompt (str): Prompt to send to language model. - - Returns: - str: Response from language model. - """ - pass - - @abstractmethod - def refineDesc(self, description: str) -> str: - """Refines description of an image. - Used in V1 Dataflow. - - Args: - description (str): Description of an image. - - Returns: - str: Refinement of description. - """ - pass - - @abstractmethod - def refineOCR(self, chars: str) -> str: - """Refines characters found in an image. - Used in V1 Dataflow. - - Args: - chars (str): Characters found in an image. - - Returns: - str: Refinement of characters. - """ - pass - - @abstractmethod - def genPrompt(self, desc: str, chars: str, context: list[str], caption: str) -> str: - """Generates prompt to send to language model in V2 Dataflow. - - Args: - desc (str): Description of an image. - chars (str): Characters found in an image. - context (list[str]): Context of an image. See getContext in alttext for more information. - caption (str): Caption of an image. - - Returns: - str: Prompt to send to language model. - """ - pass - - @abstractmethod - def refineAlt( - self, - desc: str, - chars: str = None, - context: list[str] = None, - caption: str = None, - ) -> str: - """Generates alt-text for an image. - Used in V2 Dataflow. - - Args: - desc (str): Description of an image. - chars (str, optional): Characters found in an image. Defaults to None. - context (list[str], optional): Context of an image. See getContext in alttext for more information. Defaults to None. - caption (str, optional): Caption of an image. Defaults to None. - - Returns: - str: Alt-text for an image. - """ - pass - - @abstractmethod - def ingest(self, filename: str, binary) -> bool: - """Ingests a file into the language model. - - Args: - filename (str): Name of file. - binary (_type_): Data of file. - - Returns: - bool: True if successful. - """ - pass - - @abstractmethod - def degest(self, filename: str) -> bool: - """Removes a file from the language model. - - Args: - filename (str): Name of file. - - Returns: - bool: True if successful. - """ - pass - - -### IMPLEMENTATIONS class PrivateGPT(LangEngine): def __init__(self, host) -> None: self.host = host diff --git a/src/alttext/ocrengine.py b/src/alttext/ocrengine.py deleted file mode 100644 index b334722..0000000 --- a/src/alttext/ocrengine.py +++ /dev/null @@ -1,38 +0,0 @@ -from abc import ABC, abstractmethod -from PIL import Image -from io import BytesIO - -import pytesseract - - -### OCRENGINE ABSTRACT -class OCREngine(ABC): - @abstractmethod - def genChars(self, imgData: bytes, src: str, context: str = None) -> str: - """Searches for characters in an image. - - Args: - imgData (bytes): Image data in bytes. - src (str): Image source. - context (str, optional): Context of an image. See getContext in alttext for more information. Defaults to None. - - Returns: - str: Characters found in an image. - """ - pass - - -### IMPLEMENTATIONS -class Tesseract(OCREngine): - def __init__(self) -> None: - self.customPath = None - return None - - def _setTesseract(self, path: str) -> bool: - self.customPath = path - pytesseract.pytesseract.tesseract_cmd = path - return True - - def genChars(self, imgData: bytes, src: str, context: str = None) -> str: - image = Image.open(BytesIO(imgData)) - return pytesseract.image_to_string(image) diff --git a/src/alttext/ocrengine/ocrengine.py b/src/alttext/ocrengine/ocrengine.py new file mode 100644 index 0000000..f1dcdc5 --- /dev/null +++ b/src/alttext/ocrengine/ocrengine.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod + +class OCREngine(ABC): + @abstractmethod + def genChars(self, imgData: bytes, src: str, context: str = None) -> str: + """Searches for characters in an image. + + Args: + imgData (bytes): Image data in bytes. + src (str): Image source. + context (str, optional): Context of an image. See getContext in alttext for more information. Defaults to None. + + Returns: + str: Characters found in an image. + """ + pass \ No newline at end of file diff --git a/src/alttext/ocrengine/tesseract.py b/src/alttext/ocrengine/tesseract.py new file mode 100644 index 0000000..f42c6c8 --- /dev/null +++ b/src/alttext/ocrengine/tesseract.py @@ -0,0 +1,19 @@ +from PIL import Image +from io import BytesIO +import pytesseract + +from .ocrengine import OCREngine + +class Tesseract(OCREngine): + def __init__(self) -> None: + self.customPath = None + return None + + def _setTesseract(self, path: str) -> bool: + self.customPath = path + pytesseract.pytesseract.tesseract_cmd = path + return True + + def genChars(self, imgData: bytes, src: str, context: str = None) -> str: + image = Image.open(BytesIO(imgData)) + return pytesseract.image_to_string(image) \ No newline at end of file diff --git a/tests/test.py b/tests/test.py index 8a304ec..0e89ed3 100644 --- a/tests/test.py +++ b/tests/test.py @@ -2,9 +2,9 @@ import sys sys.path.append("../") import src.alttext.alttext as alttext -import src.alttext.descengine as descengine -import src.alttext.ocrengine as ocrengine -import src.alttext.langengine as langengine +from src.alttext.descengine.bliplocal import BlipLocal +from src.alttext.ocrengine.tesseract import Tesseract +from src.alttext.langengine.privategpt import PrivateGPT import keys # HTML BOOK FILEPATHS @@ -23,22 +23,10 @@ HOST1 = "http://127.0.0.1:8001" def testHTML(): print("TESTING HTML") - - # alt: alttext.AltTextHTML = alttext.AltTextHTML( - # # descengine.ReplicateAPI(keys.ReplicateEricKey(), "blip"), - # # ocrengine.Tesseract(), - # # langengine.PrivateGPT(HOST1), - # ) - - # alt: alttext.AltTextHTML = alttext.AltTextHTML( - # descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"), - # options={"version": 1}, - # ) - alt: alttext.AltTextHTML = alttext.AltTextHTML( - descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"), - ocrengine.Tesseract(), - langengine.PrivateGPT(HOST1), + BlipLocal("C:/Users/dacru/Desktop/ALT/image-captioning"), + Tesseract(), + PrivateGPT(HOST1), ) alt.parseFile(HTML_HUNTING)