Hr/streaming #87

Closed
wants to merge 3 commits
1 change: 1 addition & 0 deletions .devcontainer/devcontainer.json
@@ -1,5 +1,6 @@
{
"name": "Ubuntu",
"runArgs": ["--name", "${localEnv:USER}_dev_container"],
"build": {
"dockerfile": "Dockerfile"
},
2 changes: 1 addition & 1 deletion README.md
@@ -228,7 +228,7 @@ Higher level code for interacting with the ORM is available in `aana.repository.
Here are the environment variables that can be used to configure the Aana SDK:
- TMP_DATA_DIR: The directory to store temporary data. Default: `/tmp/aana`.
- NUM_WORKERS: The number of request workers. Default: `2`.
- DB_CONFIG: The database configuration in the format `{"datastore_type": "sqlite", "datastore_config": {"path": "/path/to/sqlite.db"}}`. Currently only SQLite and PostgreSQL are supported. Default: `{"datastore_type": "sqlite", "datastore_config": {"path": "/var/lib/aana_data"}}`.
- DB_CONFIG: The database configuration in the format `{"datastore_type": "sqlite", "datastore_config": {"path": "/path/to/sqlite.db"}}`. Currently, only SQLite and PostgreSQL are supported. Default: `{"datastore_type": "sqlite", "datastore_config": {"path": "/var/lib/aana_data"}}`.
- USE_DEPLOYMENT_CACHE (testing only): If set to `true`, the tests will use the deployment cache to avoid downloading the models and running the deployments. Default: `false`.
- SAVE_DEPLOYMENT_CACHE (testing only): If set to `true`, the tests will save the deployment cache after running the deployments. Default: `false`.
- HF_HUB_ENABLE_HF_TRANSFER: If set to `1`, HuggingFace Transformers will use the HF Transfer library to download models from the HuggingFace Hub, which speeds up the process. It is recommended to always set it to `1`. Default: `0`.
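For reference, here is a minimal sketch (not part of this PR) of setting these variables from Python before the SDK is imported; the values are only the documented defaults and placeholders:

```python
import json
import os

# Illustrative values only; adjust for your own deployment.
os.environ["TMP_DATA_DIR"] = "/tmp/aana"
os.environ["NUM_WORKERS"] = "2"
os.environ["DB_CONFIG"] = json.dumps(
    {"datastore_type": "sqlite", "datastore_config": {"path": "/var/lib/aana_data"}}
)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
```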
58 changes: 58 additions & 0 deletions aana/core/models/stream.py
@@ -0,0 +1,58 @@
import uuid
from typing import Annotated

from pydantic import (
AfterValidator,
AnyUrl,
BaseModel,
ConfigDict,
Field,
)

from aana.core.models.media import MediaId


class StreamInput(BaseModel):
"""A video stream input.

The 'url' must be provided.

Attributes:
media_id (MediaId): the ID of the video stream. If not provided, it will be generated automatically.
url (AnyUrl): the URL of the video stream
        channel_number (int): the desired channel of the stream to be processed
extract_fps (float): the number of frames to extract per second
"""

url: Annotated[
AnyUrl,
Field(description="The URL of the video stream."),
AfterValidator(lambda x: str(x)),
]
    channel_number: int = Field(
        default=0,
        ge=0,
        description="The desired channel of the stream to be processed.",
    )

extract_fps: float = Field(
default=3.0,
gt=0.0,
description=(
"The number of frames to extract per second. "
"Can be smaller than 1. For example, 0.5 means 1 frame every 2 seconds."
),
)

media_id: MediaId = Field(
default_factory=lambda: str(uuid.uuid4()),
description="The ID of the video. If not provided, it will be generated automatically.",
)

model_config = ConfigDict(
json_schema_extra={
"description": ("A video Stream. \n" "The 'url' must be provided. \n")
},
validate_assignment=True,
file_upload=False,
)
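As a quick orientation for reviewers, a minimal sketch of constructing the new model (not part of the diff); the RTSP URL is a placeholder:

```python
from aana.core.models.stream import StreamInput

# Hypothetical stream URL; any reachable video stream URL works.
stream = StreamInput(
    url="rtsp://camera.example.com/live",
    channel_number=0,
    extract_fps=1.0,
)
print(stream.media_id)  # auto-generated UUID when not provided
print(stream.url)       # stored as a plain string after validation
```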
23 changes: 23 additions & 0 deletions aana/exceptions/io.py
@@ -102,3 +102,26 @@ class VideoReadingException(VideoException):
"""

pass


class StreamReadingException(BaseException):
    """Exception raised when there is an error reading a stream.

    Attributes:
        url (str): the URL of the stream that caused the exception
        msg (str): the error message
    """

def __init__(self, url: str, msg: str = ""):
"""Initialize the exception.

Args:
url (str): the URL of the stream that caused the exception
msg (str): the error message
"""
super().__init__(url=url)
self.url = url
self.msg = msg

def __reduce__(self):
"""Used for pickling."""
return (self.__class__, (self.url, self.msg))
76 changes: 76 additions & 0 deletions aana/integrations/external/av.py
@@ -4,11 +4,23 @@
import wave
from collections.abc import Generator
from pathlib import Path
from typing import TypedDict

import av
import numpy as np

from aana.core.libraries.audio import AbstractAudioLibrary
from aana.core.models.image import Image
from aana.core.models.stream import StreamInput
from aana.exceptions.io import StreamReadingException


class FramesDict(TypedDict):
    """Represents a batch of frames with their frame ids and timestamps."""

    frames: list[Image]
    timestamps: list[float]
    frame_ids: list[int]


def load_audio(file: Path | None, sample_rate: int = 16000) -> bytes:
@@ -120,6 +132,70 @@ def resample_frames(frames: Generator, resampler) -> Generator:
yield from resampler.resample(frame)


def fetch_stream_frames(
    stream_input: StreamInput, batch_size: int = 2
) -> Generator[FramesDict, None, None]:
    """Generate frames from a video stream using PyAV.

    Args:
        stream_input (StreamInput): the video stream to fetch frames from
        batch_size (int): the number of frames to yield at each iteration

    Yields:
        FramesDict: a dictionary containing the extracted frames, frame ids, and timestamps for each batch
    """
stream_url = stream_input.url
channel = stream_input.channel_number
extraction_fps = stream_input.extract_fps

try:
stream_container = av.open(stream_url)
except Exception as e:
raise StreamReadingException(stream_url) from e

available_streams = [s for s in stream_container.streams if s.type == "video"]

    # Check that the selected stream channel is valid.
    if len(available_streams) == 0 or channel >= len(available_streams):
        raise StreamReadingException(
            stream_url,
            msg=(
                f"Selected channel {channel} does not exist: "
                f"only {len(available_streams)} video stream(s) are available"
            ),
        )
video_stream = available_streams[channel]

    avg_rate = float(video_stream.average_rate)

    # Cap the extraction FPS at the stream's average frame rate.
    if extraction_fps > avg_rate:
        extraction_fps = avg_rate

    # Number of source frames between two extracted frames.
    frame_rate = int(avg_rate / extraction_fps)

# read frames from the stream
frame_number = 0
batch_frames = []
batch_timestamps = []
num_batches = 0

    for packet in stream_container.demux(video_stream):
        for frame in packet.decode():
            if frame_number % frame_rate == 0:
                img = Image(numpy=frame.to_rgb().to_ndarray())
                packet_timestamp = float(frame.pts * frame.time_base)  # in seconds
                batch_frames.append(img)
                batch_timestamps.append(packet_timestamp)
            frame_number += 1
            if len(batch_frames) == batch_size:
                yield FramesDict(
                    frames=batch_frames,
                    frame_ids=list(
                        range(num_batches * batch_size, (num_batches + 1) * batch_size)
                    ),
                    timestamps=batch_timestamps,
                )
                num_batches += 1
                batch_frames = []
                batch_timestamps = []


class pyAVWrapper(AbstractAudioLibrary):
"""Class for audio handling using PyAV library."""

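To make the new generator concrete, here is a small usage sketch outside the endpoint (assumptions: a reachable placeholder RTSP URL; the generator runs for as long as the stream stays open):

```python
from aana.core.models.stream import StreamInput
from aana.integrations.external.av import fetch_stream_frames

stream = StreamInput(url="rtsp://camera.example.com/live", extract_fps=0.5)

# Iterates indefinitely for a live stream; break out when done.
for batch in fetch_stream_frames(stream, batch_size=2):
    for frame_id, ts in zip(batch["frame_ids"], batch["timestamps"], strict=True):
        print(f"frame {frame_id} at {ts:.2f}s")
    break
```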
Empty file.
48 changes: 48 additions & 0 deletions aana/projects/process_stream/app.py
@@ -0,0 +1,48 @@
import argparse

from aana.configs.deployments import hf_blip2_opt_2_7b_deployment
from aana.projects.process_stream.endpoints import (
CaptionStreamEndpoint,
)
from aana.sdk import AanaSDK

deployments = [
{
"name": "captioning_deployment",
"instance": hf_blip2_opt_2_7b_deployment,
},
]

endpoints = [
{
"name": "caption_live_stream",
"path": "/stream/caption_stream",
"summary": "Process a live stream and return the captions",
"endpoint_cls": CaptionStreamEndpoint,
},
]

if __name__ == "__main__":
"""Runs the application."""
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--port", type=int, default=8000)
arg_parser.add_argument("--host", type=str, default="127.0.0.1")
args = arg_parser.parse_args()

aana_app = AanaSDK(port=args.port, host=args.host, show_logs=True)

for deployment in deployments:
aana_app.register_deployment(
name=deployment["name"],
instance=deployment["instance"],
)

for endpoint in endpoints:
aana_app.register_endpoint(
name=endpoint["name"],
path=endpoint["path"],
summary=endpoint["summary"],
endpoint_cls=endpoint["endpoint_cls"],
)

aana_app.deploy(blocking=True)
1 change: 1 addition & 0 deletions aana/projects/process_stream/const.py
@@ -0,0 +1 @@
captioning_model_name = "hf_blip2_opt_2_7b"
46 changes: 46 additions & 0 deletions aana/projects/process_stream/endpoints.py
@@ -0,0 +1,46 @@
from collections.abc import AsyncGenerator
from typing import Annotated, TypedDict

from pydantic import Field

from aana.api.api_generation import Endpoint
from aana.core.models.stream import StreamInput
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.integrations.external.av import fetch_stream_frames
from aana.processors.remote import run_remote


class CaptionStreamOutput(TypedDict):
    """The output of the caption stream endpoint."""

captions: Annotated[list[str], Field(..., description="Captions")]
timestamps: Annotated[
list[float], Field(..., description="Timestamps for each caption in seconds")
]


class CaptionStreamEndpoint(Endpoint):
    """Caption a live video stream in batches of frames."""

async def initialize(self):
"""Initialize the endpoint."""
self.captioning_handle = await AanaDeploymentHandle.create(
"captioning_deployment"
)

async def run(
self,
stream: StreamInput,
) -> AsyncGenerator[CaptionStreamOutput, None]:
"""Transcribe video in chunks."""
async for frames_dict in run_remote(fetch_stream_frames)(
stream_input=stream, batch_size=2
):
captioning_output = await self.captioning_handle.generate_batch(
images=frames_dict["frames"]
)

yield {
"captions": captioning_output["captions"],
"timestamps": frames_dict["timestamps"],
}
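Assuming the SDK's usual client convention of posting a JSON-encoded `body` form field and reading back newline-delimited JSON (not shown in this PR), calling the new endpoint could look roughly like this; the server URL and stream address are placeholders:

```python
import json

import requests

data = {"stream": {"url": "rtsp://camera.example.com/live", "extract_fps": 0.5}}
response = requests.post(
    "http://127.0.0.1:8000/stream/caption_stream",
    data={"body": json.dumps(data)},
    stream=True,
)
for line in response.iter_lines():
    if line:
        chunk = json.loads(line)
        print(chunk["captions"], chunk["timestamps"])
```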
13 changes: 7 additions & 6 deletions aana/tests/deployments/test_text_generation_deployment.py
@@ -21,10 +21,11 @@ def get_expected_output(name):
"and business magnate who is best known for his innovative companies in"
)
elif name == "meta_llama3_8b_instruct_deployment":
return (" Elon Musk is a South African-born entrepreneur, inventor,"
"and business magnate. He is the CEO and CTO of SpaceX, "
"CEO and product architect of Tesla"
)
return (
" Elon Musk is a South African-born entrepreneur, inventor,"
"and business magnate. He is the CEO and CTO of SpaceX, "
"CEO and product architect of Tesla"
)
elif name == "hf_phi3_mini_4k_instruct_text_gen_deployment":
return (
"Elon Musk is a prominent entrepreneur and business magnate known for "
@@ -33,6 +34,7 @@
else:
raise ValueError(f"Unknown deployment name: {name}") # noqa: TRY003


def get_expected_chat_output(name):
"""Gets expected output for a given text_generation model."""
if name == "vllm_llama2_7b_chat_deployment":
@@ -41,8 +43,7 @@
"and business magnate who is best known for his innovative companies in"
)
elif name == "meta_llama3_8b_instruct_deployment":
return ("Elon Musk is a South African-born entrepreneur, inventor, and business magnate. He is best known for his ambitious goals to revolutionize the transportation, energy"
)
return "Elon Musk is a South African-born entrepreneur, inventor, and business magnate. He is best known for his ambitious goals to revolutionize the transportation, energy"
elif name == "hf_phi3_mini_4k_instruct_text_gen_deployment":
return (
"Elon Musk is a prominent entrepreneur and business magnate known for "