diff --git a/aana/configs/pipeline.py b/aana/configs/pipeline.py
index a8e39454..81ca9857 100644
--- a/aana/configs/pipeline.py
+++ b/aana/configs/pipeline.py
@@ -155,7 +155,7 @@
         "name": "hf_blip2_opt_2_7b",
         "type": "ray_deployment",
         "deployment_name": "hf_blip2_deployment_opt_2_7b",
-        "method": "generate_captions",
+        "method": "generate_batch",
         "inputs": [
             {
                 "name": "images",
diff --git a/aana/deployments/hf_blip2_deployment.py b/aana/deployments/hf_blip2_deployment.py
index 800bfe14..f5b5c677 100644
--- a/aana/deployments/hf_blip2_deployment.py
+++ b/aana/deployments/hf_blip2_deployment.py
@@ -52,6 +52,17 @@ class CaptioningOutput(TypedDict):
     """
     The output of the captioning model.
 
+    Attributes:
+        caption (str): the caption
+    """
+
+    caption: str
+
+
+class CaptioningBatchOutput(TypedDict):
+    """
+    The output of the captioning model for a batch of images.
+
     Attributes:
         captions (List[str]): the list of captions
     """
@@ -82,11 +93,11 @@ async def apply_config(self, config: Dict[str, Any]):
         # and process them in parallel
         self.batch_size = config_obj.batch_size
         self.num_processing_threads = config_obj.num_processing_threads
-        # The actual inference is done in _generate_captions()
+        # The actual inference is done in _generate()
         # We use lambda because BatchProcessor expects dict as input
-        # and we use **kwargs to unpack the dict into named arguments for _generate_captions()
+        # and we use **kwargs to unpack the dict into named arguments for _generate()
         self.batch_processor = BatchProcessor(
-            process_batch=lambda request: self._generate_captions(**request),
+            process_batch=lambda request: self._generate(**request),
             batch_size=self.batch_size,
             num_threads=self.num_processing_threads,
         )
@@ -103,7 +114,26 @@ async def apply_config(self, config: Dict[str, Any]):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
 
-    async def generate_captions(self, **kwargs) -> CaptioningOutput:
+    async def generate(self, image: Image) -> CaptioningOutput:
+        """
+        Generate a caption for the given image.
+
+        Args:
+            image (Image): the image
+
+        Returns:
+            CaptioningOutput: the dictionary with one key "caption"
+            and the caption of the image as value
+
+        Raises:
+            InferenceException: if the inference fails
+        """
+        captions: CaptioningBatchOutput = await self.batch_processor.process(
+            {"images": [image]}
+        )
+        return CaptioningOutput(caption=captions["captions"][0])
+
+    async def generate_batch(self, **kwargs) -> CaptioningBatchOutput:
         """
         Generate captions for the given images.
 
@@ -111,17 +141,17 @@ async def generate_captions(self, **kwargs) -> CaptioningOutput:
             images (List[Image]): the images
 
         Returns:
-            CaptioningOutput: the dictionary with one key "captions"
+            CaptioningBatchOutput: the dictionary with one key "captions"
             and the list of captions for the images as value
 
         Raises:
             InferenceException: if the inference fails
         """
         # Call the batch processor to process the requests
-        # The actual inference is done in _generate_captions()
+        # The actual inference is done in _generate()
         return await self.batch_processor.process(kwargs)
 
-    def _generate_captions(self, images: List[Image]) -> CaptioningOutput:
+    def _generate(self, images: List[Image]) -> CaptioningBatchOutput:
         """
         Generate captions for the given images.
 
@@ -131,7 +161,7 @@ def _generate_captions(self, images: List[Image]) -> CaptioningOutput:
             images (List[Image]): the images
 
         Returns:
-            CaptioningOutput: the dictionary with one key "captions"
+            CaptioningBatchOutput: the dictionary with one key "captions"
            and the list of captions for the images as value
 
         Raises:
@@ -152,6 +182,6 @@ def _generate_captions(self, images: List[Image]) -> CaptioningOutput:
             generated_texts = [
                 generated_text.strip() for generated_text in generated_texts
             ]
-            return CaptioningOutput(captions=generated_texts)
+            return CaptioningBatchOutput(captions=generated_texts)
         except Exception as e:
             raise InferenceException(self.model_id) from e
diff --git a/aana/tests/deployments/test_hf_blip2_deployment.py b/aana/tests/deployments/test_hf_blip2_deployment.py
index 3c501967..5fc94d11 100644
--- a/aana/tests/deployments/test_hf_blip2_deployment.py
+++ b/aana/tests/deployments/test_hf_blip2_deployment.py
@@ -21,7 +21,11 @@ def ray_setup(deployment):
 
 @pytest.mark.skipif(not is_gpu_available(), reason="GPU is not available")
 @pytest.mark.asyncio
-async def test_hf_blip2_deployments():
+@pytest.mark.parametrize(
+    "image_name, expected_text",
+    [("Starry_Night.jpeg", "the starry night by vincent van gogh")],
+)
+async def test_hf_blip2_deployments(image_name, expected_text):
     for name, deployment in deployments.items():
         # skip if not a VLLM deployment
         if deployment.name != "HFBlip2Deployment":
@@ -29,14 +33,18 @@ async def test_hf_blip2_deployments(image_name, expected_text):
 
         handle = ray_setup(deployment)
 
-        path = resources.path("aana.tests.files.images", "Starry_Night.jpeg")
+        path = resources.path("aana.tests.files.images", image_name)
         image = Image(path=path, save_on_disk=False)
 
+        output = await handle.generate.remote(image=image)
+        caption = output["caption"]
+        compare_texts(expected_text, caption)
+
         images = [image] * 8
-        output = await handle.generate_captions.remote(images=images)
+        output = await handle.generate_batch.remote(images=images)
         captions = output["captions"]
         assert len(captions) == 8
 
         for caption in captions:
-            compare_texts("the starry night by vincent van gogh", caption)
+            compare_texts(expected_text, caption)
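
For reviewers, a minimal usage sketch of the reshaped captioning API, mirroring the updated test above. It assumes `handle` is a Ray Serve handle to the HF BLIP-2 deployment (as returned by `ray_setup` in the test), and the import path for `Image` is an assumption, since the test's import block is not part of this diff.

from importlib import resources

from aana.models.core.image import Image  # assumed import path; not shown in this diff


async def caption_starry_night(handle):
    # `handle` is assumed to be a Ray Serve handle to the BLIP-2 deployment,
    # e.g. the one returned by ray_setup() in the test above.
    path = resources.path("aana.tests.files.images", "Starry_Night.jpeg")
    image = Image(path=path, save_on_disk=False)

    # New single-image method: returns a dict with one key "caption".
    output = await handle.generate.remote(image=image)
    print(output["caption"])

    # Renamed batch method (formerly generate_captions): returns {"captions": List[str]}.
    batch_output = await handle.generate_batch.remote(images=[image] * 8)
    for caption in batch_output["captions"]:
        print(caption)

Note that the single-image `generate` reuses the batch path internally (it wraps the image in a one-element list and unwraps the first caption), so callers get the same batching behaviour either way.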