Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YouTube URL Enhancements and Video Verification after Download #26

Merged
merged 4 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 1 addition & 15 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import traceback
from collections.abc import AsyncGenerator, Callable
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from mobius_pipeline.exceptions import BaseException
from mobius_pipeline.node.socket import Socket
from mobius_pipeline.pipeline.pipeline import Pipeline
from pydantic import BaseModel, Field, ValidationError, create_model, parse_raw_as
from ray.exceptions import RayTaskError

from aana.api.app import custom_exception_handler
from aana.api.responses import AanaJSONResponse
Expand Down Expand Up @@ -389,19 +386,8 @@ async def generator_wrapper() -> AsyncGenerator[bytes, None]:
):
output = self.process_output(output)
yield AanaJSONResponse(content=output).body
except RayTaskError as e:
yield custom_exception_handler(None, e).body
except BaseException as e:
yield custom_exception_handler(None, e)
except Exception as e:
error = e.__class__.__name__
stacktrace = traceback.format_exc()
yield AanaJSONResponse(
status_code=400,
content=ExceptionResponseModel(
error=error, message=str(e), stacktrace=stacktrace
).dict(),
).body
yield custom_exception_handler(None, e).body

return StreamingResponse(
generator_wrapper(), media_type="application/json"
Expand Down
10 changes: 3 additions & 7 deletions aana/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ async def validation_exception_handler(request: Request, exc: ValidationError):
)


def custom_exception_handler(
request: Request | None, exc_raw: BaseException | RayTaskError
):
def custom_exception_handler(request: Request | None, exc_raw: Exception):
"""This handler is used to handle custom exceptions raised in the application.

BaseException is the base exception for all the exceptions
Expand All @@ -43,7 +41,7 @@ def custom_exception_handler(

Args:
request (Request): The request object
exc_raw (Union[BaseException, RayTaskError]): The exception raised
exc_raw (Exception): The exception raised

Returns:
JSONResponse: JSON response with the error details. The response contains the following fields:
Expand All @@ -60,8 +58,6 @@ def custom_exception_handler(
stacktrace = str(exc_raw)
# get the original exception
exc: BaseException = exc_raw.cause
if not isinstance(exc, BaseException):
raise TypeError(exc)
else:
# if it is not a RayTaskError
# then we need to get the stack trace
Expand All @@ -70,7 +66,7 @@ def custom_exception_handler(
# get the data from the exception
# can be used to return additional info
# like image path, url, model name etc.
data = exc.get_data()
data = exc.get_data() if isinstance(exc, BaseException) else {}
# get the name of the class of the exception
# can be used to identify the type of the error
error = exc.__class__.__name__
Expand Down
3 changes: 3 additions & 0 deletions aana/api/responses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Any

import orjson
Expand Down Expand Up @@ -27,6 +28,8 @@ def json_serializer_default(obj: Any) -> Any:
"""
if isinstance(obj, BaseModel):
return obj.dict()
if isinstance(obj, Path):
return str(obj)
raise TypeError


Expand Down
1 change: 0 additions & 1 deletion aana/configs/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ class Settings(BaseSettings):
"""A pydantic model for SDK settings."""

tmp_data_dir: Path = Path("/tmp/aana_data") # noqa: S108
youtube_video_dir = tmp_data_dir / "youtube_videos"
image_dir = tmp_data_dir / "images"
video_dir = tmp_data_dir / "videos"

Expand Down
6 changes: 4 additions & 2 deletions aana/exceptions/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,20 @@ class DownloadException(BaseException):
url (str): the URL of the file that caused the exception
"""

def __init__(self, url: str):
def __init__(self, url: str, msg: str = ""):
"""Initialize the exception.

Args:
url (str): the URL of the file that caused the exception
msg (str): the error message
"""
super().__init__(url=url)
self.url = url
self.msg = msg

def __reduce__(self):
"""Used for pickling."""
return (self.__class__, (self.url,))
return (self.__class__, (self.url, self.msg))


class VideoException(BaseException):
Expand Down
Empty file added aana/models/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion aana/models/core/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def save(self):
raise ValueError( # noqa: TRY003
"At least one of 'path', 'url', or 'content' must be provided."
)
self.path = file_path
self.is_saved = True

def save_from_bytes(self, file_path: Path, content: bytes):
Expand All @@ -89,6 +88,7 @@ def save_from_bytes(self, file_path: Path, content: bytes):
content (bytes): the content of the media
"""
file_path.write_bytes(content)
self.path = file_path

def save_from_content(self, file_path: Path):
"""Save the media from the content.
Expand Down
40 changes: 38 additions & 2 deletions aana/models/core/video.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import hashlib
import hashlib # noqa: I001
evanderiel marked this conversation as resolved.
Show resolved Hide resolved
from dataclasses import dataclass
from pathlib import Path
import torch, decord # noqa: F401 # See https://github.com/dmlc/decord/issues/263

from aana.configs.settings import settings
from aana.exceptions.general import VideoReadingException
from aana.models.core.media import Media


Expand All @@ -28,7 +30,12 @@ class Video(Media):
media_dir: Path | None = settings.video_dir

def validate(self):
"""Validate the video."""
"""Validate the video.

Raises:
ValueError: if none of 'path', 'url', or 'content' is provided
VideoReadingException: if the video is not valid
"""
# validate the parent class
super().validate()

Expand All @@ -44,6 +51,35 @@ def validate(self):
"At least one of 'path', 'url' or 'content' must be provided."
)

# check that the video is valid
if self.path and not self.is_video():
raise VideoReadingException(video=self)

def is_video(self) -> bool:
"""Checks if it's a valid video."""
if not self.path:
return False
try:
decord.VideoReader(str(self.path))
except Exception:
return False
return True

def save_from_url(self, file_path):
"""Save the media from the URL.

Args:
file_path (Path): the path to save the media to

Raises:
DownloadError: if the media can't be downloaded
movchan74 marked this conversation as resolved.
Show resolved Hide resolved
VideoReadingException: if the media is not a valid video
"""
super().save_from_url(file_path)
# check that the file is a video
if not self.is_video():
raise VideoReadingException(video=self)

def __repr__(self) -> str:
"""Get the representation of the video.

Expand Down
35 changes: 0 additions & 35 deletions aana/models/core/video_source.py

This file was deleted.

2 changes: 1 addition & 1 deletion aana/models/pydantic/asr_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from faster_whisper.transcribe import (
Word as WhisperWord,
)
from pydantic import BaseModel, Field

from aana.models.pydantic.base import BaseListModel
from aana.models.pydantic.time_interval import TimeInterval
from pydantic import BaseModel, Field


class AsrWord(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion aana/models/pydantic/captions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from types import MappingProxyType

from aana.models.pydantic.base import BaseListModel
from pydantic import BaseModel

from aana.models.pydantic.base import BaseListModel


class Caption(BaseModel):
"""A model for a caption."""
Expand Down
4 changes: 2 additions & 2 deletions aana/models/pydantic/image_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from types import MappingProxyType

import numpy as np
from pydantic import BaseModel, Field, ValidationError, root_validator, validator
from pydantic.error_wrappers import ErrorWrapper

from aana.models.core.image import Image
from aana.models.pydantic.base import BaseListModel
from pydantic import BaseModel, Field, ValidationError, root_validator, validator
from pydantic.error_wrappers import ErrorWrapper


class ImageInput(BaseModel):
Expand Down
5 changes: 3 additions & 2 deletions aana/models/pydantic/video_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from pathlib import Path
from types import MappingProxyType

from aana.models.core.video import Video
from aana.models.pydantic.base import BaseListModel
from pydantic import BaseModel, Field, ValidationError, root_validator, validator
from pydantic.error_wrappers import ErrorWrapper

from aana.models.core.video import Video
from aana.models.pydantic.base import BaseListModel


class VideoInput(BaseModel):
"""A video input.
Expand Down
11 changes: 4 additions & 7 deletions aana/tests/test_chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,10 @@ def test_chat_template_custom():
prompt = apply_chat_template(
tokenizer, dialog, "llama2"
) # Apply custom chat template "llama2"
assert ( # noqa: S101
prompt
== (
"<s>[INST] <<SYS>>\\nYou are a friendly chatbot who always responds in the style "
"of a pirate\\n<</SYS>>\\n\\nHow many helicopters can a human eat in one sitting? "
"[/INST] I don't know, how many? </s><s>[INST] One, but only if they're really hungry! [/INST]"
)
assert prompt == (
"<s>[INST] <<SYS>>\\nYou are a friendly chatbot who always responds in the style "
"of a pirate\\n<</SYS>>\\n\\nHow many helicopters can a human eat in one sitting? "
"[/INST] I don't know, how many? </s><s>[INST] One, but only if they're really hungry! [/INST]"
)


Expand Down
4 changes: 2 additions & 2 deletions aana/tests/test_frame_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_extract_frames_failure():
# image file instead of video file will create Video object
# but will fail in extract_frames_decord
path = resources.path("aana.tests.files.images", "Starry_Night.jpeg")
invalid_video = Video(path=path)
params = VideoParams(extract_fps=1.0, fast_mode_enabled=False)
with pytest.raises(VideoReadingException):
invalid_video = Video(path=path)
params = VideoParams(extract_fps=1.0, fast_mode_enabled=False)
extract_frames_decord(video=invalid_video, params=params)
Loading