diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 1e7e8816..06eccb77 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,5 +1,6 @@ { "name": "Ubuntu", + "runArgs": ["--name", "${localEnv:USER}_dev_container"], "build": { "dockerfile": "Dockerfile" }, diff --git a/README.md b/README.md index e509da2a..a6152983 100644 --- a/README.md +++ b/README.md @@ -228,7 +228,7 @@ Higher level code for interacting with the ORM is available in `aana.repository. Here are the environment variables that can be used to configure the Aana SDK: - TMP_DATA_DIR: The directory to store temporary data. Default: `/tmp/aana`. - NUM_WORKERS: The number of request workers. Default: `2`. -- DB_CONFIG: The database configuration in the format `{"datastore_type": "sqlite", "datastore_config": {"path": "/path/to/sqlite.db"}}`. Currently only SQLite and PostgreSQL are supported. Default: `{"datastore_type": "sqlite", "datastore_config": {"path": "/var/lib/aana_data"}}`. +- DB_CONFIG: The database configuration in the format `{"datastore_type": "sqlite", "datastore_config": {"path": "/path/to/sqlite.db"}}`. Currently, only SQLite and PostgreSQL are supported. Default: `{"datastore_type": "sqlite", "datastore_config": {"path": "/var/lib/aana_data"}}`. - USE_DEPLOYMENT_CACHE (testing only): If set to `true`, the tests will use the deployment cache to avoid downloading the models and running the deployments. Default: `false`. - SAVE_DEPLOYMENT_CACHE (testing only): If set to `true`, the tests will save the deployment cache after running the deployments. Default: `false`. - HF_HUB_ENABLE_HF_TRANSFER: If set to `1`, the HuggingFace Transformers will use the HF Transfer library to download the models from HuggingFace Hub to speed up the process. Recommended to always set to it `1`. Default: `0`. 
# --- aana/core/models/stream.py ---
class StreamInput(BaseModel):
    """A video stream input.

    The 'url' must be provided; every other field has a default.

    Attributes:
        url (AnyUrl): the URL of the video stream. It is validated as a URL and then stored as a string.
        channel_number (int): the desired channel of the stream to be processed. Must be >= 0. Default: 0.
        extract_fps (float): the number of frames to extract per second. Must be > 0. Default: 3.0.
        media_id (MediaId): the ID of the video stream. If not provided, it will be generated automatically.
    """

    url: Annotated[
        AnyUrl,
        Field(description="The URL of the video stream."),
        # Convert the validated URL object to a plain string.
        AfterValidator(lambda x: str(x)),
    ]
    channel_number: int = Field(
        default=0,
        ge=0,
        description=("the desired channel of stream"),
    )

    extract_fps: float = Field(
        default=3.0,
        gt=0.0,
        description=(
            "The number of frames to extract per second. "
            "Can be smaller than 1. For example, 0.5 means 1 frame every 2 seconds."
        ),
    )

    media_id: MediaId = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="The ID of the video. If not provided, it will be generated automatically.",
    )

    model_config = ConfigDict(
        json_schema_extra={
            "description": ("A video Stream. \n" "The 'url' must be provided. \n")
        },
        validate_assignment=True,
        file_upload=False,
    )


# --- aana/exceptions/io.py ---
class StreamReadingException(BaseException):
    """Exception raised when there is an error reading a stream.

    Attributes:
        url (str): the URL of the stream that caused the exception
        msg (str): the error message
    """

    def __init__(self, url: str, msg: str = ""):
        """Initialize the exception.

        Args:
            url (str): the URL of the stream that caused the exception
            msg (str): the error message
        """
        # NOTE(review): only `url` is forwarded to the base class; `msg` is kept
        # on the instance only. Confirm whether the base class should receive it too.
        super().__init__(url=url)
        self.url = url
        self.msg = msg

    def __reduce__(self):
        """Used for pickling."""
        return (self.__class__, (self.url, self.msg))


# --- aana/integrations/external/av.py ---
class FramesDict(TypedDict):
    """Represents a batch of frames with ids and timestamps.

    The three lists are parallel: entry ``i`` of each list describes the same frame.
    """

    frames: list[Image]  # the extracted frames
    timestamps: list[float]  # timestamp of each frame, in seconds
    frame_ids: list[int]  # sequential ids of the extracted frames, starting at 0


def fetch_stream_frames(
    stream_input: StreamInput, batch_size: int = 2
) -> Generator[FramesDict, None, None]:
    """Generate batches of frames from a video stream using PyAV.

    Frames are sampled at (at most) ``stream_input.extract_fps`` frames per second.
    If the requested rate is higher than the average frame rate of the stream,
    every decoded frame is kept.

    This is a generator function: nothing is opened or validated until the
    first batch is requested, and the stream container is closed when the
    generator finishes, fails or is closed.

    Args:
        stream_input (StreamInput): the video stream to fetch frames from
        batch_size (int): the number of frames to yield at each iteration

    Yields:
        FramesDict: a dictionary containing the extracted frames, frame ids and
            timestamps for each batch. Every batch holds ``batch_size`` frames,
            except possibly the last one when the stream ends.

    Raises:
        ValueError: if batch_size is smaller than 1
        StreamReadingException: if the stream cannot be opened, the selected
            channel does not exist or the frame rate of the stream is unknown
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")  # noqa: TRY003

    stream_url = stream_input.url
    channel = stream_input.channel_number
    extraction_fps = stream_input.extract_fps

    try:
        stream_container = av.open(stream_url)
    except Exception as e:
        raise StreamReadingException(stream_url) from e

    # From here on the container is open: make sure it is always released,
    # including when the consumer stops iterating early.
    try:
        available_streams = [s for s in stream_container.streams if s.type == "video"]

        # Check that the selected channel is valid
        if not available_streams or channel >= len(available_streams):
            raise StreamReadingException(
                stream_url,
                msg=f"selected channel does not exist: {channel + 1} from {len(available_streams)}",
            )
        video_stream = available_streams[channel]

        # The average rate can be missing (None) or zero; without it the
        # sampling step below cannot be computed.
        average_rate = video_stream.average_rate
        if not average_rate:
            raise StreamReadingException(
                stream_url, msg="could not determine the frame rate of the stream"
            )
        avg_rate = float(average_rate)

        # It is not possible to extract more frames than the stream provides.
        extraction_fps = min(extraction_fps, avg_rate)

        # Keep one frame out of every `frame_step` decoded frames (always >= 1).
        frame_step = int(avg_rate / extraction_fps)

        frame_number = 0  # index of the decoded frame in the stream
        next_frame_id = 0  # id of the first frame of the next batch
        batch_frames = []
        batch_timestamps = []

        for packet in stream_container.demux(video_stream):
            for frame in packet.decode():
                if frame_number % frame_step == 0:
                    if frame.pts is not None:
                        timestamp = float(frame.pts * frame.time_base)  # in seconds
                    else:
                        # No presentation timestamp: estimate it from the frame index.
                        timestamp = frame_number / avg_rate
                    batch_frames.append(Image(numpy=frame.to_rgb().to_ndarray()))
                    batch_timestamps.append(timestamp)
                frame_number += 1

                if len(batch_frames) == batch_size:
                    yield FramesDict(
                        frames=batch_frames,
                        frame_ids=list(range(next_frame_id, next_frame_id + batch_size)),
                        timestamps=batch_timestamps,
                    )
                    next_frame_id += batch_size
                    batch_frames = []
                    batch_timestamps = []

        # The stream ended: do not lose the frames of the last, incomplete batch.
        if batch_frames:
            yield FramesDict(
                frames=batch_frames,
                frame_ids=list(range(next_frame_id, next_frame_id + len(batch_frames))),
                timestamps=batch_timestamps,
            )
    finally:
        stream_container.close()
# --- aana/projects/process_stream/app.py ---
import argparse

from aana.configs.deployments import hf_blip2_opt_2_7b_deployment
from aana.projects.process_stream.endpoints import CaptionStreamEndpoint
from aana.sdk import AanaSDK

# Deployments to register; each entry holds the keyword arguments of
# `AanaSDK.register_deployment`.
deployments = [
    {
        "name": "captioning_deployment",
        "instance": hf_blip2_opt_2_7b_deployment,
    },
]

# Endpoints to register; each entry holds the keyword arguments of
# `AanaSDK.register_endpoint`.
endpoints = [
    {
        "name": "caption_live_stream",
        "path": "/stream/caption_stream",
        "summary": "Process a live stream and return the captions",
        "endpoint_cls": CaptionStreamEndpoint,
    },
]

if __name__ == "__main__":
    # Runs the application.
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="127.0.0.1")
    cli_args = parser.parse_args()

    aana_app = AanaSDK(port=cli_args.port, host=cli_args.host, show_logs=True)

    # The dictionary keys match the keyword names of the register methods,
    # so every entry can be unpacked directly.
    for deployment in deployments:
        aana_app.register_deployment(**deployment)

    for endpoint in endpoints:
        aana_app.register_endpoint(**endpoint)

    aana_app.deploy(blocking=True)
# --- aana/projects/process_stream/endpoints.py ---
class CaptionStreamOutput(TypedDict):
    """One chunk of output of the caption stream endpoint."""

    captions: Annotated[list[str], Field(..., description="Captions")]
    timestamps: Annotated[
        list[float], Field(..., description="Timestamps for each caption in seconds")
    ]


class CaptionStreamEndpoint(Endpoint):
    """Endpoint that captions the frames of a video stream, batch by batch."""

    async def initialize(self):
        """Create the handle used to call the captioning deployment."""
        handle = await AanaDeploymentHandle.create("captioning_deployment")
        self.captioning_handle = handle

    async def run(
        self,
        stream: StreamInput,
    ) -> AsyncGenerator[CaptionStreamOutput, None]:
        """Caption the frames fetched from the stream.

        Args:
            stream (StreamInput): the video stream to caption

        Yields:
            CaptionStreamOutput: the captions of one batch of frames and the
                timestamps of those frames
        """
        # Frames are fetched through `run_remote` in batches of two.
        frame_batches = run_remote(fetch_stream_frames)(
            stream_input=stream, batch_size=2
        )
        async for batch in frame_batches:
            captioned = await self.captioning_handle.generate_batch(
                images=batch["frames"]
            )
            yield CaptionStreamOutput(
                captions=captioned["captions"],
                timestamps=batch["timestamps"],
            )
He is the CEO and CTO of SpaceX, " - "CEO and product architect of Tesla" - ) + return ( + " Elon Musk is a South African-born entrepreneur, inventor," + "and business magnate. He is the CEO and CTO of SpaceX, " + "CEO and product architect of Tesla" + ) elif name == "hf_phi3_mini_4k_instruct_text_gen_deployment": return ( "Elon Musk is a prominent entrepreneur and business magnate known for " @@ -33,6 +34,7 @@ def get_expected_output(name): else: raise ValueError(f"Unknown deployment name: {name}") # noqa: TRY003 + def get_expected_chat_output(name): """Gets expected output for a given text_generation model.""" if name == "vllm_llama2_7b_chat_deployment": @@ -41,8 +43,7 @@ def get_expected_chat_output(name): "and business magnate who is best known for his innovative companies in" ) elif name == "meta_llama3_8b_instruct_deployment": - return ("Elon Musk is a South African-born entrepreneur, inventor, and business magnate. He is best known for his ambitious goals to revolutionize the transportation, energy" - ) + return "Elon Musk is a South African-born entrepreneur, inventor, and business magnate. 
# --- aana/tests/units/test_frame_extraction.py ---
@pytest.mark.parametrize(
    "mode, url, channel_number, extract_fps",
    [
        (
            "hls",
            "https://live-par-2-cdn-alt.livepush.io/live/bigbuckbunnyclip/index.m3u8",
            0,
            3,
        ),
        (
            "dash",
            "https://live-par-2-cdn-alt.livepush.io/live/bigbuckbunnyclip/index.mpd",
            0,
            3,
        ),
        (
            "mp4",
            "https://live-par-2-abr.livepush.io/vod/bigbuckbunnyclip.mp4",
            0,
            3,
        ),
    ],
)
def test_fetch_stream_frames(mode, url, channel_number, extract_fps):
    """Test fetch_stream_frames.

    fetch_stream_frames is a generator function that yields a dictionary
    containing the frames, timestamps and frame_ids of the stream.

    Only the first batches are checked, because a live stream may never end.
    """
    max_batches = 11  # stop once more than 10 batches have been checked
    stream_input = StreamInput(
        url=url, channel_number=channel_number, extract_fps=extract_fps
    )
    gen_frame = fetch_stream_frames(stream_input, batch_size=1)
    total_batches = 0
    try:
        for result in gen_frame:
            assert "frames" in result
            assert "frame_ids" in result
            assert "timestamps" in result
            assert isinstance(result["frames"], list)
            assert isinstance(result["frame_ids"], list)
            assert isinstance(result["timestamps"], list)

            assert isinstance(result["frames"][0], Image)
            assert len(result["frames"]) == 1  # batch_size = 1
            assert len(result["frame_ids"]) == 1  # batch_size = 1
            assert len(result["timestamps"]) == 1  # batch_size = 1

            total_batches += 1
            if total_batches >= max_batches:
                break
    finally:
        # Stop the generator explicitly instead of leaving it suspended.
        gen_frame.close()

    # Without this check the test would pass even if no frame was extracted.
    assert total_batches > 0, f"no frames were extracted from the {mode} stream"
    print(f"{mode} is supported")


def test_fetch_stream_frames_failure():
    """Test that frames cannot be extracted from a youtube video."""
    url = "https://www.youtube.com/watch?v=T98dnE2vPdY"
    stream_input = StreamInput(url=url, channel_number=0, extract_fps=3)
    with pytest.raises(StreamReadingException):
        # The generator is lazy: the error is raised when the first batch is requested.
        next(fetch_stream_frames(stream_input, batch_size=1))


# --- aana/tests/units/test_stream_input.py ---
# ruff: noqa: S101
def test_new_stream_input_success():
    """Test that StreamInput can be created successfully."""
    stream_input = StreamInput(url="http://example.com/stream.m3u8")
    assert stream_input.url == "http://example.com/stream.m3u8"


def test_stream_input_invalid_media_id():
    """Test that StreamInput can't be created if media_id is invalid."""
    with pytest.raises(ValidationError):
        StreamInput(url="http://example.com/stream.m3u8", media_id="")


@pytest.mark.parametrize(
    "url, extract_fps",
    [("http://example.com/stream.m3u8", 0), ("http://example.com/stream.m3u8", -1)],
)
def test_stream_input_invalid_extract_fps(url, extract_fps):
    """Test that StreamInput can't be created if extract_fps is invalid."""
    with pytest.raises(ValidationError):
        StreamInput(url=url, extract_fps=extract_fps)


@pytest.mark.parametrize(
    "url, channel_number",
    [("http://example.com/stream.m3u8", -1), ("http://example.com/stream.m3u8", 0.3)],
)
def test_stream_input_invalid_channel(url, channel_number):
    """Test that StreamInput can't be created if channel number is invalid."""
    with pytest.raises(ValidationError):
        StreamInput(url=url, channel_number=channel_number)