From 56597d4eacd16c08a7dbfcde2ec16a82e7211a23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Frederik=20Hvilsh=C3=B8j?= <93145535+frederik-encord@users.noreply.github.com> Date: Tue, 17 Dec 2024 13:34:30 +0100 Subject: [PATCH] feat: add asset dependency (#34) * feat: init label row args * feat: let download_asset also download audio files * feat: allow getting frame from image sequence * feat: add dep_asset to tasks * feat: add asset dependency * fix: use temporary directory for data storage * fix: dependency issue for fastapi that didn't know how to resolve label row --- encord_agents/core/data_model.py | 14 +++++ encord_agents/core/utils.py | 88 ++++++++++++++++----------- encord_agents/fastapi/dependencies.py | 62 +++++++++++++++++-- encord_agents/gcp/dependencies.py | 38 +++++++++++- encord_agents/gcp/wrappers.py | 11 +++- encord_agents/tasks/dependencies.py | 35 +++++++++++ encord_agents/tasks/runner.py | 33 ++++++---- 7 files changed, 223 insertions(+), 58 deletions(-) diff --git a/encord_agents/core/data_model.py b/encord_agents/core/data_model.py index 7e8b6a1..6701d84 100644 --- a/encord_agents/core/data_model.py +++ b/encord_agents/core/data_model.py @@ -26,6 +26,20 @@ class LabelRowMetadataIncludeArgs(BaseModel): include_all_label_branches: bool = False +class LabelRowInitialiseLabelsArgs(BaseModel): + """ + Arguments used to specify how to initialise labels via the SDK. + + The arguments are passed to `LabelRowV2.initialise_labels`. + """ + + include_object_feature_hashes: set[str] | None = None + include_classification_feature_hashes: set[str] | None = None + include_reviews: bool = False + overwrite: bool = False + include_signed_url: bool = False + + class FrameData(BaseModel): """ Holds the data sent from the Encord Label Editor at the time of triggering the agent. diff --git a/encord_agents/core/utils.py b/encord_agents/core/utils.py index b6f1df6..ab8bdb9 100644 --- a/encord_agents/core/utils.py +++ b/encord_agents/core/utils.py @@ -2,7 +2,8 @@ from contextlib import contextmanager from functools import lru_cache from pathlib import Path -from typing import Any, Generator +from tempfile import TemporaryDirectory +from typing import Any, Generator, cast import cv2 import requests @@ -10,7 +11,7 @@ from encord.objects.ontology_labels_impl import LabelRowV2 from encord.user_client import EncordUserClient -from encord_agents.core.data_model import FrameData, LabelRowMetadataIncludeArgs +from encord_agents.core.data_model import FrameData, LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs from encord_agents.core.settings import Settings from .video import get_frame @@ -31,7 +32,9 @@ def get_user_client() -> EncordUserClient: def get_initialised_label_row( - frame_data: FrameData, include_args: LabelRowMetadataIncludeArgs | None = None + frame_data: FrameData, + include_args: LabelRowMetadataIncludeArgs | None = None, + init_args: LabelRowInitialiseLabelsArgs | None = None, ) -> LabelRowV2: """ Get an initialised label row from the frame_data information. @@ -49,6 +52,7 @@ def get_initialised_label_row( user_client = get_user_client() project = user_client.get_project(str(frame_data.project_hash)) include_args = include_args or LabelRowMetadataIncludeArgs() + init_args = init_args or LabelRowInitialiseLabelsArgs() matched_lrs = project.list_label_rows_v2(data_hashes=[frame_data.data_hash], **include_args.model_dump()) num_matches = len(matched_lrs) if num_matches > 1: @@ -56,11 +60,11 @@ def get_initialised_label_row( elif num_matches == 0: raise Exception("No label rows were matched!") lr = matched_lrs.pop() - lr.initialise_labels(include_signed_url=True) + lr.initialise_labels(**init_args.model_dump()) return lr -def _guess_file_suffix(url: str, lr: LabelRowV2) -> str: +def _guess_file_suffix(url: str, lr: LabelRowV2) -> tuple[str, str]: """ Best effort attempt to guess file suffix given a url and label row. @@ -75,8 +79,8 @@ def _guess_file_suffix(url: str, lr: LabelRowV2) -> str: - lr: the associated label row Returns: - A file suffix that can be used to store the file. For example, ".jpg" or ".mp4" - + A file type and suffix that can be used to store the file. + For example, ("image", ".jpg") or ("video", ".mp4"). """ fallback_mimetype = "video/mp4" if lr.data_type == DataType.VIDEO else "image/png" mimetype, _ = next( @@ -95,21 +99,23 @@ def _guess_file_suffix(url: str, lr: LabelRowV2) -> str: file_type, suffix = mimetype.split("/")[:2] - if file_type == "video" and lr.data_type != DataType.VIDEO: + if (file_type == "audio" and lr.data_type != DataType.AUDIO) or ( + file_type == "video" and lr.data_type != DataType.VIDEO + ): raise ValueError(f"Mimetype {mimetype} and lr data type {lr.data_type} did not match") elif file_type == "image" and lr.data_type not in { DataType.IMG_GROUP, DataType.IMAGE, }: raise ValueError(f"Mimetype {mimetype} and lr data type {lr.data_type} did not match") - elif file_type not in {"image", "video"}: - raise ValueError("File type not video or image") + elif file_type not in {"image", "video", "audio"}: + raise ValueError("File type not audio, video, or image") - return f".{suffix}" + return file_type, f".{suffix}" @contextmanager -def download_asset(lr: LabelRowV2, frame: int | None) -> Generator[Path, None, None]: +def download_asset(lr: LabelRowV2, frame: int | None = None) -> Generator[Path, None, None]: """ Download the asset associated to a label row to disk. @@ -136,35 +142,47 @@ def download_asset(lr: LabelRowV2, frame: int | None) -> Generator[Path, None, N The file path for the requested asset. """ - video_item, images_list = lr._project_client.get_data(lr.data_hash, get_signed_url=True) - if lr.data_type in [DataType.VIDEO, DataType.IMAGE] and video_item: - url = video_item["file_link"] - elif lr.data_type == DataType.IMG_GROUP and images_list: + url: str | None = None + if lr.data_link is not None and lr.data_link[:5] == "https": + url = lr.data_link + elif lr.backing_item_uuid is not None: + storage_item = get_user_client().get_storage_item(lr.backing_item_uuid, sign_url=True) + url = storage_item.get_signed_url() + + # Fallback for native image groups (they don't have a url) + is_image_sequence = lr.data_type == DataType.IMG_GROUP + if url is None: + is_image_sequence = False + _, images_list = lr._project_client.get_data(lr.data_hash, get_signed_url=True) + if images_list is None: + raise ValueError("Image list should not be none for image groups.") if frame is None: raise NotImplementedError( "Downloading entire image group is not supported. Please contact Encord at support@encord.com for help or submit a PR with an implementation." ) - url = images_list[frame]["file_link"] - else: - raise ValueError(f"Couldn't load asset of type {lr.data_type}") + image = images_list[frame] + url = cast(str | None, image.file_link) + + if url is None: + raise ValueError("Failed to get a signed url for the asset") response = requests.get(url) response.raise_for_status() - suffix = _guess_file_suffix(url, lr) - file_path = Path(lr.data_hash).with_suffix(suffix) - with open(file_path, "wb") as f: - f.write(response.content) - - files_to_unlink = [file_path] - if lr.data_type == DataType.VIDEO and frame is not None: # Get that exact frame - frame_content = get_frame(file_path, frame) - frame_file = file_path.with_name(f"{file_path.name}_{frame}").with_suffix(".png") - cv2.imwrite(frame_file.as_posix(), frame_content) - files_to_unlink.append(frame_file) - file_path = frame_file - try: + with TemporaryDirectory() as dir_name: + dir_path = Path(dir_name) + + _, suffix = _guess_file_suffix(url, lr) + file_path = dir_path / f"{lr.data_hash}{suffix}" + with open(file_path, "wb") as f: + for chunk in response.iter_content(chunk_size=4096): + if chunk: + f.write(chunk) + + if (lr.data_type == DataType.VIDEO or is_image_sequence) and frame is not None: # Get that exact frame + frame_content = get_frame(file_path, frame) + frame_file = file_path.with_name(f"{file_path.name}_{frame}").with_suffix(".png") + cv2.imwrite(frame_file.as_posix(), frame_content) + file_path = frame_file + yield file_path - finally: - for to_unlink in files_to_unlink: - to_unlink.unlink(missing_ok=True) diff --git a/encord_agents/fastapi/dependencies.py b/encord_agents/fastapi/dependencies.py index 5c147eb..b3e4af0 100644 --- a/encord_agents/fastapi/dependencies.py +++ b/encord_agents/fastapi/dependencies.py @@ -21,6 +21,7 @@ def my_agent( """ +from pathlib import Path from typing import Annotated, Callable, Generator, Iterator import cv2 @@ -34,7 +35,7 @@ def my_agent( from encord.user_client import EncordUserClient from numpy.typing import NDArray -from encord_agents.core.data_model import LabelRowMetadataIncludeArgs +from encord_agents.core.data_model import LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs from encord_agents.core.dependencies.shares import DataLookup from encord_agents.core.vision import crop_to_object @@ -76,8 +77,9 @@ def my_route( return get_user_client() -def dep_label_row_with_include_args( +def dep_label_row_with_args( label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None, + label_row_initialise_labels_args: LabelRowInitialiseLabelsArgs | None = None, ) -> Callable[[FrameData], LabelRowV2]: """ Dependency to provide an initialized label row. @@ -85,18 +87,21 @@ def dep_label_row_with_include_args( **Example:** ```python - from encord_agents.core.data_model import LabelRowMetadataIncludeArgs - from encord_agents.fastapi.depencencies import dep_label_row_with_include_args + from encord_agents.core.data_model import LabelRowMetadataIncludeArgs, LabelRowInitialiseLabelsArgs + from encord_agents.fastapi.depencencies import dep_label_row_with_args ... include_args = LabelRowMetadataIncludeArgs( include_client_metadata=True, include_workflow_graph_node=True, ) + init_args = LabelRowInitialiseLabelsArgs( + include_signed_url=True, + ) @app.post("/my-route") def my_route( - lr: Annotated[LabelRowV2, Depends(dep_label_row_with_include_args(include_args))] + lr: Annotated[LabelRowV2, Depends(dep_label_row_with_args(include_args, init_args))] ): assert lr.is_labelling_initialised # will work assert lr.client_metadata # will be available if set already @@ -113,7 +118,9 @@ def my_route( """ def wrapper(frame_data: Annotated[FrameData, Form()]) -> LabelRowV2: - return get_initialised_label_row(frame_data, label_row_metadata_include_args) + return get_initialised_label_row( + frame_data, include_args=label_row_metadata_include_args, init_args=label_row_initialise_labels_args + ) return wrapper @@ -182,6 +189,49 @@ def my_route( return np.asarray(img, dtype=np.uint8) +def dep_asset( + lr: Annotated[ + LabelRowV2, + Depends( + dep_label_row_with_args( + label_row_initialise_labels_args=LabelRowInitialiseLabelsArgs(include_signed_url=True) + ) + ), + ], +) -> Generator[Path, None, None]: + """ + Get a local file path to data asset temporarily stored till end of agent execution. + + This dependency will fetch the underlying data asset based on a signed url. + It will temporarily store the data on disk. Once the task is completed, the + asset will be removed from disk again. + + **Example:** + + ```python + from encord_agents.fastapi.depencencies import dep_asset + ... + runner = Runner(project_hash="") + + @app.post("/my-route") + def my_agent( + asset: Annotated[Path, Depends(dep_asset)], + ) -> str | None: + asset.stat() # read file stats + ... + ``` + + Returns: + The path to the asset. + + Raises: + `ValueError` if the underlying assets are not videos, images, or audio. + `EncordException` if data type not supported by SDK yet. + """ + with download_asset(lr) as asset: + yield asset + + def dep_video_iterator(lr: Annotated[LabelRowV2, Depends(dep_label_row)]) -> Generator[Iterator[Frame], None, None]: """ Dependency to inject a video frame iterator for doing things over many frames. diff --git a/encord_agents/gcp/dependencies.py b/encord_agents/gcp/dependencies.py index 8506706..c00ccc7 100644 --- a/encord_agents/gcp/dependencies.py +++ b/encord_agents/gcp/dependencies.py @@ -27,6 +27,7 @@ def my_agent( - [`label_row_v2`](https://docs.encord.com/sdk-documentation/sdk-references/LabelRowV2) is automatically loaded based on the frame data. """ +from pathlib import Path from typing import Callable, Generator, Iterator import cv2 @@ -106,11 +107,46 @@ def my_agent( return np.asarray(img, dtype=np.uint8) +def dep_asset(lr: LabelRowV2) -> Generator[Path, None, None]: + """ + Get a local file path to data asset temporarily stored till end of agent execution. + + This dependency will fetch the underlying data asset based on a signed url. + It will temporarily store the data on disk. Once the task is completed, the + asset will be removed from disk again. + + **Example:** + + ```python + from encord_agents.gcp import editor_agent + from encord_agents.gcp.dependencies import dep_asset + ... + runner = Runner(project_hash="") + + @editor_agent() + def my_agent( + asset: Annotated[Path, Depends(dep_asset)] + ) -> None: + asset.stat() # read file stats + ... + ``` + + Returns: + The path to the asset. + + Raises: + `ValueError` if the underlying assets are not videos, images, or audio. + `EncordException` if data type not supported by SDK yet. + """ + with download_asset(lr) as asset: + yield asset + + def dep_video_iterator(lr: LabelRowV2) -> Generator[Iterator[Frame], None, None]: """ Dependency to inject a video frame iterator for doing things over many frames. - **Intended use** + **Example:** ```python from encord_agents import FrameData diff --git a/encord_agents/gcp/wrappers.py b/encord_agents/gcp/wrappers.py index 06d80da..a5d428d 100644 --- a/encord_agents/gcp/wrappers.py +++ b/encord_agents/gcp/wrappers.py @@ -8,7 +8,7 @@ from flask import Request, Response, make_response from encord_agents import FrameData -from encord_agents.core.data_model import LabelRowMetadataIncludeArgs +from encord_agents.core.data_model import LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs from encord_agents.core.dependencies.models import Context from encord_agents.core.dependencies.utils import get_dependant, solve_dependencies from encord_agents.core.utils import get_user_client @@ -27,7 +27,9 @@ def generate_response() -> Response: def editor_agent( - *, label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None + *, + label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None, + label_row_initialise_labels_args: LabelRowInitialiseLabelsArgs | None = None, ) -> Callable[[AgentFunction], Callable[[Request], Response]]: """ Wrapper to make resources available for gcp editor agents. @@ -38,6 +40,8 @@ def editor_agent( Args: label_row_metadata_include_args: arguments to overwrite default arguments on `project.list_label_rows_v2()`. + label_row_initialise_labels_args: Arguments to overwrite default arguments + on `label_row.initialise_labels(...)` Returns: A wrapped function suitable for gcp functions. @@ -57,10 +61,11 @@ def wrapper(request: Request) -> Response: label_row: LabelRowV2 | None = None if dependant.needs_label_row: include_args = label_row_metadata_include_args or LabelRowMetadataIncludeArgs() + init_args = label_row_initialise_labels_args or LabelRowInitialiseLabelsArgs() label_row = project.list_label_rows_v2( data_hashes=[str(frame_data.data_hash)], **include_args.model_dump() )[0] - label_row.initialise_labels(include_signed_url=True) + label_row.initialise_labels(**init_args.model_dump()) context = Context(project=project, label_row=label_row, frame_data=frame_data) with ExitStack() as stack: diff --git a/encord_agents/tasks/dependencies.py b/encord_agents/tasks/dependencies.py index f759d35..60a3bff 100644 --- a/encord_agents/tasks/dependencies.py +++ b/encord_agents/tasks/dependencies.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from pathlib import Path from typing import Callable, Generator, Iterator import cv2 @@ -117,6 +118,40 @@ def my_agent( yield iter_video(asset) +def dep_asset(lr: LabelRowV2) -> Generator[Path, None, None]: + """ + Get a local file path to data asset temporarily stored till end of task execution. + + This dependency will fetch the underlying data asset based on a signed url. + It will temporarily store the data on disk. Once the task is completed, the + asset will be removed from disk again. + + **Example:** + + ```python + from encord_agents.tasks.dependencies import dep_asset + ... + runner = Runner(project_hash="") + + @runner.stage("") + def my_agent( + asset: Annotated[Path, Depends(dep_asset)], + ) -> str | None: + asset.stat() # read file stats + ... + ``` + + Returns: + The path to the asset. + + Raises: + `ValueError` if the underlying assets are not videos, images, or audio. + `EncordException` if data type not supported by SDK yet. + """ + with download_asset(lr) as asset: + yield asset + + @dataclass(frozen=True) class Twin: """ diff --git a/encord_agents/tasks/runner.py b/encord_agents/tasks/runner.py index 893168e..3d9d509 100644 --- a/encord_agents/tasks/runner.py +++ b/encord_agents/tasks/runner.py @@ -19,7 +19,7 @@ from typer import Abort, Option from typing_extensions import Annotated -from encord_agents.core.data_model import LabelRowMetadataIncludeArgs +from encord_agents.core.data_model import LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs from encord_agents.core.dependencies.models import Context, DecoratedCallable, Dependant from encord_agents.core.dependencies.utils import get_dependant, solve_dependencies from encord_agents.core.utils import get_user_client @@ -35,12 +35,14 @@ def __init__( callable: Callable[..., TaskAgentReturn], printable_name: str | None = None, label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None, + label_row_initialise_labels_args: LabelRowInitialiseLabelsArgs | None = None, ): self.identity = identity self.printable_name = printable_name or identity self.callable = callable self.dependant: Dependant = get_dependant(func=callable) self.label_row_metadata_include_args = label_row_metadata_include_args + self.label_row_initialise_labels_args = label_row_initialise_labels_args def __repr__(self) -> str: return f'RunnerAgent("{self.printable_name}")' @@ -130,6 +132,7 @@ def _add_stage_agent( func: Callable[..., TaskAgentReturn], printable_name: str | None, label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None, + label_row_initialise_labels_args: LabelRowInitialiseLabelsArgs | None, ) -> None: self.agents.append( RunnerAgent( @@ -137,11 +140,16 @@ def _add_stage_agent( callable=func, printable_name=printable_name, label_row_metadata_include_args=label_row_metadata_include_args, + label_row_initialise_labels_args=label_row_initialise_labels_args, ) ) def stage( - self, stage: str | UUID, *, label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None + self, + stage: str | UUID, + *, + label_row_metadata_include_args: LabelRowMetadataIncludeArgs | None = None, + label_row_initialise_labels_args: LabelRowInitialiseLabelsArgs | None = None, ) -> Callable[[DecoratedCallable], DecoratedCallable]: r""" Decorator to associate a function with an agent stage. @@ -214,6 +222,8 @@ def my_func( associated with. label_row_metadata_include_args: Arguments to be passed to `project.list_label_rows_v2(...)` + label_row_initialise_labels_args: Arguments to be passed to + `label_row.initialise_labels(...)` Returns: The decorated function. @@ -244,7 +254,9 @@ def my_func( ) def decorator(func: DecoratedCallable) -> DecoratedCallable: - self._add_stage_agent(stage, func, printable_name, label_row_metadata_include_args) + self._add_stage_agent( + stage, func, printable_name, label_row_metadata_include_args, label_row_initialise_labels_args + ) return func return decorator @@ -415,6 +427,8 @@ def {fn_name}(...): next_execution = datetime.now() + delta if delta else False for runner_agent in self.agents: + include_args = runner_agent.label_row_metadata_include_args or LabelRowMetadataIncludeArgs() + init_args = runner_agent.label_row_initialise_labels_args or LabelRowInitialiseLabelsArgs() stage = agent_stages[runner_agent.identity] batch: list[AgentTask] = [] @@ -429,9 +443,6 @@ def {fn_name}(...): if len(batch) == task_batch_size: batch_lrs = [None] * len(batch) if runner_agent.dependant.needs_label_row: - include_args = ( - runner_agent.label_row_metadata_include_args or LabelRowMetadataIncludeArgs() - ) label_rows = { UUID(lr.data_hash): lr for lr in project.list_label_rows_v2( @@ -442,7 +453,7 @@ def {fn_name}(...): with project.create_bundle() as lr_bundle: for lr in batch_lrs: if lr: - lr.initialise_labels(bundle=lr_bundle) + lr.initialise_labels(bundle=lr_bundle, **init_args.model_dump()) self._execute_tasks( project, @@ -462,18 +473,14 @@ def {fn_name}(...): UUID(lr.data_hash): lr for lr in project.list_label_rows_v2( data_hashes=[t.data_hash for t in batch], - **( - runner_agent.label_row_metadata_include_args.model_dump() - if runner_agent.label_row_metadata_include_args - else {} - ), + **include_args.model_dump(), ) } batch_lrs = [label_rows[t.data_hash] for t in batch] with project.create_bundle() as lr_bundle: for lr in batch_lrs: if lr: - lr.initialise_labels(bundle=lr_bundle) + lr.initialise_labels(bundle=lr_bundle, **init_args.model_dump()) self._execute_tasks( project, zip(batch, batch_lrs), runner_agent, num_retries, pbar_update=pbar.update )