diff --git a/README.md b/README.md index fb9cb9a..0499564 100644 --- a/README.md +++ b/README.md @@ -206,4 +206,4 @@ If no callback_url is provided: { "message": "No Message Callback, Passing" } -``` +``` \ No newline at end of file diff --git a/lib/model/model.py b/lib/model/model.py index 094bca2..eb99df7 100644 --- a/lib/model/model.py +++ b/lib/model/model.py @@ -9,6 +9,9 @@ from lib import schemas class Model(ABC): BATCH_SIZE = 1 + def __init__(self): + self.model_name = os.environ.get("MODEL_NAME") + def get_tempfile_for_url(self, url: str) -> str: """ Loads a file based on specified URL into a named tempfile. diff --git a/lib/queue/worker.py b/lib/queue/worker.py index 7e7a4e1..3b0f2df 100644 --- a/lib/queue/worker.py +++ b/lib/queue/worker.py @@ -52,7 +52,7 @@ def safely_respond(self, model: Model) -> List[schemas.Message]: if messages_with_queues: logger.debug(f"About to respond to: ({messages_with_queues})") try: - responses = model.respond([schemas.Message(**json.loads(message.body)) for message, queue in messages_with_queues]) + responses = model.respond([schemas.Message(**{**json.loads(message.body), **{"model_name": self.model.model_name}}) for message, queue in messages_with_queues]) except Exception as e: logger.error(e) self.delete_messages(messages_with_queues) diff --git a/lib/schemas.py b/lib/schemas.py index 991aca5..985e594 100644 --- a/lib/schemas.py +++ b/lib/schemas.py @@ -1,5 +1,5 @@ from typing import Any, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, root_validator # Output hash values can be of different types. class GenericItem(BaseModel): @@ -8,9 +8,23 @@ class GenericItem(BaseModel): url: Optional[str] = None text: Optional[str] = None raw: Optional[dict] = {} + +class MediaItem(BaseModel): hash_value: Optional[Any] = None + +class VideoItem(MediaItem): bucket: Optional[str] = None outfile: Optional[str] = None class Message(BaseModel): - body: Union[GenericItem] \ No newline at end of file + body: Union[MediaItem, VideoItem] + model_name: str + @root_validator(pre=True) + def set_body(cls, values): + body = values.get("body") + model_name = values.get("model_name") + if model_name == "video__Model": + values["body"] = VideoItem(**values["body"]) + if model_name in ["audio__Model", "image__Model", "fptg__Model", "indian_sbert__Model", "mean_tokens__Model", "fasttext__Model"] + values["body"] = MediaItem(**values["body"]) + return values