
Commit

Changed methods of BLIP2 deployment
to be consistent with VLLM deployment
Aleksandr Movchan committed Nov 10, 2023
1 parent bc1c609 commit 8e93fde
Showing 3 changed files with 52 additions and 14 deletions.
2 changes: 1 addition & 1 deletion aana/configs/pipeline.py
@@ -155,7 +155,7 @@
"name": "hf_blip2_opt_2_7b",
"type": "ray_deployment",
"deployment_name": "hf_blip2_deployment_opt_2_7b",
"method": "generate_captions",
"method": "generate_batch",
"inputs": [
{
"name": "images",
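For context, the pipeline node touched by this hunk would read roughly as follows after the rename. Only the fields visible in the hunk (name, type, deployment_name, method, and the "images" input) come from the repository; everything else in this sketch is assumed.

# Sketch of the updated node in aana/configs/pipeline.py.
# Fields not shown in the hunk above are hypothetical placeholders.
hf_blip2_node = {
    "name": "hf_blip2_opt_2_7b",
    "type": "ray_deployment",
    "deployment_name": "hf_blip2_deployment_opt_2_7b",
    # renamed from "generate_captions" so the pipeline calls the
    # deployment's new batch method
    "method": "generate_batch",
    "inputs": [
        {"name": "images"},  # input spec truncated in the hunk; details assumed
    ],
}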
48 changes: 39 additions & 9 deletions aana/deployments/hf_blip2_deployment.py
@@ -52,6 +52,17 @@ class CaptioningOutput(TypedDict):
"""
The output of the captioning model.
Attributes:
caption (str): the caption
"""

caption: str


class CaptioningBatchOutput(TypedDict):
"""
The output of the captioning model for a batch of images.
Attributes:
captions (List[str]): the list of captions
"""
@@ -82,11 +93,11 @@ async def apply_config(self, config: Dict[str, Any]):
# and process them in parallel
self.batch_size = config_obj.batch_size
self.num_processing_threads = config_obj.num_processing_threads
# The actual inference is done in _generate_captions()
# The actual inference is done in _generate()
# We use lambda because BatchProcessor expects dict as input
# and we use **kwargs to unpack the dict into named arguments for _generate_captions()
# and we use **kwargs to unpack the dict into named arguments for _generate()
self.batch_processor = BatchProcessor(
process_batch=lambda request: self._generate_captions(**request),
process_batch=lambda request: self._generate(**request),
batch_size=self.batch_size,
num_threads=self.num_processing_threads,
)
@@ -103,25 +114,44 @@ async def apply_config(self, config: Dict[str, Any]):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)

async def generate_captions(self, **kwargs) -> CaptioningOutput:
async def generate(self, image: Image) -> CaptioningOutput:
"""
Generate captions for the given image.
Args:
image (Image): the image
Returns:
CaptioningOutput: the dictionary with one key "caption"
and the caption for the image as value
Raises:
InferenceException: if the inference fails
"""
captions: CaptioningBatchOutput = await self.batch_processor.process(
{"images": [image]}
)
return CaptioningOutput(caption=captions["captions"][0])

async def generate_batch(self, **kwargs) -> CaptioningBatchOutput:
"""
Generate captions for the given images.
Args:
images (List[Image]): the images
Returns:
CaptioningOutput: the dictionary with one key "captions"
CaptioningBatchOutput: the dictionary with one key "captions"
and the list of captions for the images as value
Raises:
InferenceException: if the inference fails
"""
# Call the batch processor to process the requests
# The actual inference is done in _generate_captions()
# The actual inference is done in _generate()
return await self.batch_processor.process(kwargs)

def _generate_captions(self, images: List[Image]) -> CaptioningOutput:
def _generate(self, images: List[Image]) -> CaptioningBatchOutput:
"""
Generate captions for the given images.
@@ -131,7 +161,7 @@ def _generate_captions(self, images: List[Image]) -> CaptioningOutput:
images (List[Image]): the images
Returns:
CaptioningOutput: the dictionary with one key "captions"
CaptioningBatchOutput: the dictionary with one key "captions"
and the list of captions for the images as value
Raises:
@@ -152,6 +182,6 @@ def _generate_captions(self, images: List[Image]) -> CaptioningOutput:
generated_texts = [
generated_text.strip() for generated_text in generated_texts
]
return CaptioningOutput(captions=generated_texts)
return CaptioningBatchOutput(captions=generated_texts)
except Exception as e:
raise InferenceException(self.model_id) from e
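Taken together, the changes in this file route single-image and batched captioning through one inference path: generate() wraps its image in a one-element batch, hands it to the BatchProcessor, and unwraps the first caption, while generate_batch() forwards the whole request to the same _generate() method. A minimal usage sketch, assuming the Image import path and eliding the handle setup (it mirrors ray_setup() in the updated test below):

from aana.models.core.image import Image  # import path assumed, not shown in the diff


async def caption_examples(handle):
    """Call the renamed HFBlip2Deployment methods through a Ray Serve handle."""
    image = Image(path="Starry_Night.jpeg", save_on_disk=False)

    # single image -> CaptioningOutput with one "caption" string
    single = await handle.generate.remote(image=image)
    print(single["caption"])

    # list of images -> CaptioningBatchOutput with a "captions" list
    batch = await handle.generate_batch.remote(images=[image] * 4)
    print(batch["captions"])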
16 changes: 12 additions & 4 deletions aana/tests/deployments/test_hf_blip2_deployment.py
@@ -21,22 +21,30 @@ def ray_setup(deployment):

@pytest.mark.skipif(not is_gpu_available(), reason="GPU is not available")
@pytest.mark.asyncio
async def test_hf_blip2_deployments():
@pytest.mark.parametrize(
"image_name, expected_text",
[("Starry_Night.jpeg", "the starry night by vincent van gogh")],
)
async def test_hf_blip2_deployments(image_name, expected_text):
for name, deployment in deployments.items():
# skip if not a BLIP2 deployment
if deployment.name != "HFBlip2Deployment":
continue

handle = ray_setup(deployment)

path = resources.path("aana.tests.files.images", "Starry_Night.jpeg")
path = resources.path("aana.tests.files.images", image_name)
image = Image(path=path, save_on_disk=False)

output = await handle.generate.remote(image=image)
caption = output["caption"]
compare_texts(expected_text, caption)

images = [image] * 8

output = await handle.generate_captions.remote(images=images)
output = await handle.generate_batch.remote(images=images)
captions = output["captions"]

assert len(captions) == 8
for caption in captions:
compare_texts("the starry night by vincent van gogh", caption)
compare_texts(expected_text, caption)
