Skip to content

Commit

Permalink
feat: add asset dependency (#34)
Browse files Browse the repository at this point in the history
* feat: init label row args

* feat: let download_asset also download audio files

* feat: allow getting frame from image sequence

* feat: add dep_asset to tasks

* feat: add asset dependency

* fix: use temporary directory for data storage

* fix: dependency issue for fastapi that didn't know how to resolve label row
  • Loading branch information
frederik-encord authored Dec 17, 2024
1 parent 96a3dbf commit 56597d4
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 58 deletions.
14 changes: 14 additions & 0 deletions encord_agents/core/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ class LabelRowMetadataIncludeArgs(BaseModel):
include_all_label_branches: bool = False


class LabelRowInitialiseLabelsArgs(BaseModel):
    """
    Arguments used to specify how to initialise labels via the SDK.
    The arguments are passed to `LabelRowV2.initialise_labels`.

    Callers dump an instance with `model_dump()` and splat the result into
    `initialise_labels(**...)`, so each field name is forwarded verbatim as a
    keyword argument of that SDK method.
    """

    # Presumably restricts initialisation to the listed ontology objects;
    # `None` is forwarded unchanged -- confirm semantics against the SDK.
    include_object_feature_hashes: set[str] | None = None
    # Same as above, but for classifications -- confirm against the SDK.
    include_classification_feature_hashes: set[str] | None = None
    # Forwarded as-is; off by default.
    include_reviews: bool = False
    # Forwarded as-is; off by default.
    overwrite: bool = False
    # NOTE(review): the call sites that now splat this model used to pass
    # `include_signed_url=True` unconditionally. With this default they pass
    # `False` unless the caller opts in explicitly.
    include_signed_url: bool = False


class FrameData(BaseModel):
"""
Holds the data sent from the Encord Label Editor at the time of triggering the agent.
Expand Down
88 changes: 53 additions & 35 deletions encord_agents/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
from contextlib import contextmanager
from functools import lru_cache
from pathlib import Path
from typing import Any, Generator
from tempfile import TemporaryDirectory
from typing import Any, Generator, cast

import cv2
import requests
from encord.constants.enums import DataType
from encord.objects.ontology_labels_impl import LabelRowV2
from encord.user_client import EncordUserClient

from encord_agents.core.data_model import FrameData, LabelRowMetadataIncludeArgs
from encord_agents.core.data_model import FrameData, LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs
from encord_agents.core.settings import Settings

from .video import get_frame
Expand All @@ -31,7 +32,9 @@ def get_user_client() -> EncordUserClient:


def get_initialised_label_row(
frame_data: FrameData, include_args: LabelRowMetadataIncludeArgs | None = None
frame_data: FrameData,
include_args: LabelRowMetadataIncludeArgs | None = None,
init_args: LabelRowInitialiseLabelsArgs | None = None,
) -> LabelRowV2:
"""
Get an initialised label row from the frame_data information.
Expand All @@ -49,18 +52,19 @@ def get_initialised_label_row(
user_client = get_user_client()
project = user_client.get_project(str(frame_data.project_hash))
include_args = include_args or LabelRowMetadataIncludeArgs()
init_args = init_args or LabelRowInitialiseLabelsArgs()
matched_lrs = project.list_label_rows_v2(data_hashes=[frame_data.data_hash], **include_args.model_dump())
num_matches = len(matched_lrs)
if num_matches > 1:
raise Exception(f"Non unique match: matched {num_matches} label rows!")
elif num_matches == 0:
raise Exception("No label rows were matched!")
lr = matched_lrs.pop()
lr.initialise_labels(include_signed_url=True)
lr.initialise_labels(**init_args.model_dump())
return lr


def _guess_file_suffix(url: str, lr: LabelRowV2) -> str:
def _guess_file_suffix(url: str, lr: LabelRowV2) -> tuple[str, str]:
"""
Best effort attempt to guess file suffix given a url and label row.
Expand All @@ -75,8 +79,8 @@ def _guess_file_suffix(url: str, lr: LabelRowV2) -> str:
- lr: the associated label row
Returns:
A file suffix that can be used to store the file. For example, ".jpg" or ".mp4"
A file type and suffix that can be used to store the file.
For example, ("image", ".jpg") or ("video", ".mp4").
"""
fallback_mimetype = "video/mp4" if lr.data_type == DataType.VIDEO else "image/png"
mimetype, _ = next(
Expand All @@ -95,21 +99,23 @@ def _guess_file_suffix(url: str, lr: LabelRowV2) -> str:

file_type, suffix = mimetype.split("/")[:2]

if file_type == "video" and lr.data_type != DataType.VIDEO:
if (file_type == "audio" and lr.data_type != DataType.AUDIO) or (
file_type == "video" and lr.data_type != DataType.VIDEO
):
raise ValueError(f"Mimetype {mimetype} and lr data type {lr.data_type} did not match")
elif file_type == "image" and lr.data_type not in {
DataType.IMG_GROUP,
DataType.IMAGE,
}:
raise ValueError(f"Mimetype {mimetype} and lr data type {lr.data_type} did not match")
elif file_type not in {"image", "video"}:
raise ValueError("File type not video or image")
elif file_type not in {"image", "video", "audio"}:
raise ValueError("File type not audio, video, or image")

return f".{suffix}"
return file_type, f".{suffix}"


@contextmanager
def download_asset(lr: LabelRowV2, frame: int | None) -> Generator[Path, None, None]:
def download_asset(lr: LabelRowV2, frame: int | None = None) -> Generator[Path, None, None]:
"""
Download the asset associated to a label row to disk.
Expand All @@ -136,35 +142,47 @@ def download_asset(lr: LabelRowV2, frame: int | None) -> Generator[Path, None, N
The file path for the requested asset.
"""
video_item, images_list = lr._project_client.get_data(lr.data_hash, get_signed_url=True)
if lr.data_type in [DataType.VIDEO, DataType.IMAGE] and video_item:
url = video_item["file_link"]
elif lr.data_type == DataType.IMG_GROUP and images_list:
url: str | None = None
if lr.data_link is not None and lr.data_link[:5] == "https":
url = lr.data_link
elif lr.backing_item_uuid is not None:
storage_item = get_user_client().get_storage_item(lr.backing_item_uuid, sign_url=True)
url = storage_item.get_signed_url()

# Fallback for native image groups (they don't have a url)
is_image_sequence = lr.data_type == DataType.IMG_GROUP
if url is None:
is_image_sequence = False
_, images_list = lr._project_client.get_data(lr.data_hash, get_signed_url=True)
if images_list is None:
raise ValueError("Image list should not be none for image groups.")
if frame is None:
raise NotImplementedError(
"Downloading entire image group is not supported. Please contact Encord at [email protected] for help or submit a PR with an implementation."
)
url = images_list[frame]["file_link"]
else:
raise ValueError(f"Couldn't load asset of type {lr.data_type}")
image = images_list[frame]
url = cast(str | None, image.file_link)

if url is None:
raise ValueError("Failed to get a signed url for the asset")

response = requests.get(url)
response.raise_for_status()

suffix = _guess_file_suffix(url, lr)
file_path = Path(lr.data_hash).with_suffix(suffix)
with open(file_path, "wb") as f:
f.write(response.content)

files_to_unlink = [file_path]
if lr.data_type == DataType.VIDEO and frame is not None: # Get that exact frame
frame_content = get_frame(file_path, frame)
frame_file = file_path.with_name(f"{file_path.name}_{frame}").with_suffix(".png")
cv2.imwrite(frame_file.as_posix(), frame_content)
files_to_unlink.append(frame_file)
file_path = frame_file
try:
with TemporaryDirectory() as dir_name:
dir_path = Path(dir_name)

_, suffix = _guess_file_suffix(url, lr)
file_path = dir_path / f"{lr.data_hash}{suffix}"
with open(file_path, "wb") as f:
for chunk in response.iter_content(chunk_size=4096):
if chunk:
f.write(chunk)

if (lr.data_type == DataType.VIDEO or is_image_sequence) and frame is not None: # Get that exact frame
frame_content = get_frame(file_path, frame)
frame_file = file_path.with_name(f"{file_path.name}_{frame}").with_suffix(".png")
cv2.imwrite(frame_file.as_posix(), frame_content)
file_path = frame_file

yield file_path
finally:
for to_unlink in files_to_unlink:
to_unlink.unlink(missing_ok=True)
62 changes: 56 additions & 6 deletions encord_agents/fastapi/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def my_agent(
"""

from pathlib import Path
from typing import Annotated, Callable, Generator, Iterator

import cv2
Expand All @@ -34,7 +35,7 @@ def my_agent(
from encord.user_client import EncordUserClient
from numpy.typing import NDArray

from encord_agents.core.data_model import LabelRowMetadataIncludeArgs
from encord_agents.core.data_model import LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs
from encord_agents.core.dependencies.shares import DataLookup
from encord_agents.core.vision import crop_to_object

Expand Down Expand Up @@ -76,27 +77,31 @@ def my_route(
return get_user_client()


def dep_label_row_with_include_args(
def dep_label_row_with_args(
label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None,
label_row_initialise_labels_args: LabelRowInitialiseLabelsArgs | None = None,
) -> Callable[[FrameData], LabelRowV2]:
"""
Dependency to provide an initialized label row.
**Example:**
```python
from encord_agents.core.data_model import LabelRowMetadataIncludeArgs
from encord_agents.fastapi.depencencies import dep_label_row_with_include_args
from encord_agents.core.data_model import LabelRowMetadataIncludeArgs, LabelRowInitialiseLabelsArgs
from encord_agents.fastapi.dependencies import dep_label_row_with_args
...
include_args = LabelRowMetadataIncludeArgs(
include_client_metadata=True,
include_workflow_graph_node=True,
)
init_args = LabelRowInitialiseLabelsArgs(
include_signed_url=True,
)
@app.post("/my-route")
def my_route(
lr: Annotated[LabelRowV2, Depends(dep_label_row_with_include_args(include_args))]
lr: Annotated[LabelRowV2, Depends(dep_label_row_with_args(include_args, init_args))]
):
assert lr.is_labelling_initialised # will work
assert lr.client_metadata # will be available if set already
Expand All @@ -113,7 +118,9 @@ def my_route(
"""

def wrapper(frame_data: Annotated[FrameData, Form()]) -> LabelRowV2:
return get_initialised_label_row(frame_data, label_row_metadata_include_args)
return get_initialised_label_row(
frame_data, include_args=label_row_metadata_include_args, init_args=label_row_initialise_labels_args
)

return wrapper

Expand Down Expand Up @@ -182,6 +189,49 @@ def my_route(
return np.asarray(img, dtype=np.uint8)


def dep_asset(
    lr: Annotated[
        LabelRowV2,
        Depends(
            dep_label_row_with_args(
                label_row_initialise_labels_args=LabelRowInitialiseLabelsArgs(include_signed_url=True)
            )
        ),
    ],
) -> Generator[Path, None, None]:
    """
    Get a local file path to the data asset, stored temporarily until the end of the agent execution.

    This dependency will fetch the underlying data asset based on a signed url.
    It will temporarily store the data on disk. Once the task is completed, the
    asset will be removed from disk again.

    The label row is resolved with `include_signed_url=True`, and the asset is
    downloaded without selecting a frame, i.e., the whole asset is requested.

    **Example:**

    ```python
    from encord_agents.fastapi.dependencies import dep_asset
    ...

    @app.post("/my-route")
    def my_agent(
        asset: Annotated[Path, Depends(dep_asset)],
    ) -> str | None:
        asset.stat()  # read file stats
        ...
    ```

    Returns:
        The path to the asset.

    Raises:
        `ValueError` if the underlying assets are not videos, images, or audio.
        `NotImplementedError` if the asset is an image group that can only be
            downloaded frame by frame (no frame is passed here).
        `EncordException` if data type not supported by SDK yet.
    """
    # `download_asset` is a context manager: yielding inside the `with` keeps
    # the file on disk while the route runs and cleans it up afterwards.
    with download_asset(lr) as asset:
        yield asset


def dep_video_iterator(lr: Annotated[LabelRowV2, Depends(dep_label_row)]) -> Generator[Iterator[Frame], None, None]:
"""
Dependency to inject a video frame iterator for doing things over many frames.
Expand Down
38 changes: 37 additions & 1 deletion encord_agents/gcp/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def my_agent(
- [`label_row_v2`](https://docs.encord.com/sdk-documentation/sdk-references/LabelRowV2) is automatically loaded based on the frame data.
"""

from pathlib import Path
from typing import Callable, Generator, Iterator

import cv2
Expand Down Expand Up @@ -106,11 +107,46 @@ def my_agent(
return np.asarray(img, dtype=np.uint8)


def dep_asset(lr: LabelRowV2) -> Generator[Path, None, None]:
    """
    Provide a local path to the label row's data asset for the duration of the agent call.

    The asset is downloaded to disk via `download_asset` (no frame is
    selected, so the whole asset is requested) and is removed again once the
    agent has finished.

    **Example:**

    ```python
    from encord_agents.gcp import editor_agent
    from encord_agents.gcp.dependencies import dep_asset
    ...

    @editor_agent()
    def my_agent(
        asset: Annotated[Path, Depends(dep_asset)]
    ) -> None:
        asset.stat()  # read file stats
        ...
    ```

    Returns:
        The path to the asset.

    Raises:
        `ValueError` if the underlying assets are not videos, images, or audio.
        `EncordException` if data type not supported by SDK yet.
    """
    # Yield from inside the context manager so cleanup runs after the agent returns.
    with download_asset(lr) as local_path:
        yield local_path


def dep_video_iterator(lr: LabelRowV2) -> Generator[Iterator[Frame], None, None]:
"""
Dependency to inject a video frame iterator for doing things over many frames.
**Intended use**
**Example:**
```python
from encord_agents import FrameData
Expand Down
11 changes: 8 additions & 3 deletions encord_agents/gcp/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from flask import Request, Response, make_response

from encord_agents import FrameData
from encord_agents.core.data_model import LabelRowMetadataIncludeArgs
from encord_agents.core.data_model import LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs
from encord_agents.core.dependencies.models import Context
from encord_agents.core.dependencies.utils import get_dependant, solve_dependencies
from encord_agents.core.utils import get_user_client
Expand All @@ -27,7 +27,9 @@ def generate_response() -> Response:


def editor_agent(
*, label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None
*,
label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None,
label_row_initialise_labels_args: LabelRowInitialiseLabelsArgs | None = None,
) -> Callable[[AgentFunction], Callable[[Request], Response]]:
"""
Wrapper to make resources available for gcp editor agents.
Expand All @@ -38,6 +40,8 @@ def editor_agent(
Args:
label_row_metadata_include_args: arguments to overwrite default arguments
on `project.list_label_rows_v2()`.
label_row_initialise_labels_args: Arguments to overwrite default arguments
on `label_row.initialise_labels(...)`
Returns:
A wrapped function suitable for gcp functions.
Expand All @@ -57,10 +61,11 @@ def wrapper(request: Request) -> Response:
label_row: LabelRowV2 | None = None
if dependant.needs_label_row:
include_args = label_row_metadata_include_args or LabelRowMetadataIncludeArgs()
init_args = label_row_initialise_labels_args or LabelRowInitialiseLabelsArgs()
label_row = project.list_label_rows_v2(
data_hashes=[str(frame_data.data_hash)], **include_args.model_dump()
)[0]
label_row.initialise_labels(include_signed_url=True)
label_row.initialise_labels(**init_args.model_dump())

context = Context(project=project, label_row=label_row, frame_data=frame_data)
with ExitStack() as stack:
Expand Down
Loading

0 comments on commit 56597d4

Please sign in to comment.