-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: init label row args * feat: let download_asset also download audio files * feat: allow getting frame from image sequence * feat: add dep_asset to tasks * feat: add asset dependency * fix: use temporary directory for data storage * fix: dependency issue for fastapi that didn't know how to resolve label row
- Loading branch information
1 parent
96a3dbf
commit 56597d4
Showing
7 changed files
with
223 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,15 +2,16 @@ | |
from contextlib import contextmanager | ||
from functools import lru_cache | ||
from pathlib import Path | ||
from typing import Any, Generator | ||
from tempfile import TemporaryDirectory | ||
from typing import Any, Generator, cast | ||
|
||
import cv2 | ||
import requests | ||
from encord.constants.enums import DataType | ||
from encord.objects.ontology_labels_impl import LabelRowV2 | ||
from encord.user_client import EncordUserClient | ||
|
||
from encord_agents.core.data_model import FrameData, LabelRowMetadataIncludeArgs | ||
from encord_agents.core.data_model import FrameData, LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs | ||
from encord_agents.core.settings import Settings | ||
|
||
from .video import get_frame | ||
|
@@ -31,7 +32,9 @@ def get_user_client() -> EncordUserClient: | |
|
||
|
||
def get_initialised_label_row( | ||
frame_data: FrameData, include_args: LabelRowMetadataIncludeArgs | None = None | ||
frame_data: FrameData, | ||
include_args: LabelRowMetadataIncludeArgs | None = None, | ||
init_args: LabelRowInitialiseLabelsArgs | None = None, | ||
) -> LabelRowV2: | ||
""" | ||
Get an initialised label row from the frame_data information. | ||
|
@@ -49,18 +52,19 @@ def get_initialised_label_row( | |
user_client = get_user_client() | ||
project = user_client.get_project(str(frame_data.project_hash)) | ||
include_args = include_args or LabelRowMetadataIncludeArgs() | ||
init_args = init_args or LabelRowInitialiseLabelsArgs() | ||
matched_lrs = project.list_label_rows_v2(data_hashes=[frame_data.data_hash], **include_args.model_dump()) | ||
num_matches = len(matched_lrs) | ||
if num_matches > 1: | ||
raise Exception(f"Non unique match: matched {num_matches} label rows!") | ||
elif num_matches == 0: | ||
raise Exception("No label rows were matched!") | ||
lr = matched_lrs.pop() | ||
lr.initialise_labels(include_signed_url=True) | ||
lr.initialise_labels(**init_args.model_dump()) | ||
return lr | ||
|
||
|
||
def _guess_file_suffix(url: str, lr: LabelRowV2) -> str: | ||
def _guess_file_suffix(url: str, lr: LabelRowV2) -> tuple[str, str]: | ||
""" | ||
Best effort attempt to guess file suffix given a url and label row. | ||
|
@@ -75,8 +79,8 @@ def _guess_file_suffix(url: str, lr: LabelRowV2) -> str: | |
- lr: the associated label row | ||
Returns: | ||
A file suffix that can be used to store the file. For example, ".jpg" or ".mp4" | ||
A file type and suffix that can be used to store the file. | ||
For example, ("image", ".jpg") or ("video", ".mp4"). | ||
""" | ||
fallback_mimetype = "video/mp4" if lr.data_type == DataType.VIDEO else "image/png" | ||
mimetype, _ = next( | ||
|
@@ -95,21 +99,23 @@ def _guess_file_suffix(url: str, lr: LabelRowV2) -> str: | |
|
||
file_type, suffix = mimetype.split("/")[:2] | ||
|
||
if file_type == "video" and lr.data_type != DataType.VIDEO: | ||
if (file_type == "audio" and lr.data_type != DataType.AUDIO) or ( | ||
file_type == "video" and lr.data_type != DataType.VIDEO | ||
): | ||
raise ValueError(f"Mimetype {mimetype} and lr data type {lr.data_type} did not match") | ||
elif file_type == "image" and lr.data_type not in { | ||
DataType.IMG_GROUP, | ||
DataType.IMAGE, | ||
}: | ||
raise ValueError(f"Mimetype {mimetype} and lr data type {lr.data_type} did not match") | ||
elif file_type not in {"image", "video"}: | ||
raise ValueError("File type not video or image") | ||
elif file_type not in {"image", "video", "audio"}: | ||
raise ValueError("File type not audio, video, or image") | ||
|
||
return f".{suffix}" | ||
return file_type, f".{suffix}" | ||
|
||
|
||
@contextmanager | ||
def download_asset(lr: LabelRowV2, frame: int | None) -> Generator[Path, None, None]: | ||
def download_asset(lr: LabelRowV2, frame: int | None = None) -> Generator[Path, None, None]: | ||
""" | ||
Download the asset associated to a label row to disk. | ||
|
@@ -136,35 +142,47 @@ def download_asset(lr: LabelRowV2, frame: int | None) -> Generator[Path, None, N | |
The file path for the requested asset. | ||
""" | ||
video_item, images_list = lr._project_client.get_data(lr.data_hash, get_signed_url=True) | ||
if lr.data_type in [DataType.VIDEO, DataType.IMAGE] and video_item: | ||
url = video_item["file_link"] | ||
elif lr.data_type == DataType.IMG_GROUP and images_list: | ||
url: str | None = None | ||
if lr.data_link is not None and lr.data_link[:5] == "https": | ||
url = lr.data_link | ||
elif lr.backing_item_uuid is not None: | ||
storage_item = get_user_client().get_storage_item(lr.backing_item_uuid, sign_url=True) | ||
url = storage_item.get_signed_url() | ||
|
||
# Fallback for native image groups (they don't have a url) | ||
is_image_sequence = lr.data_type == DataType.IMG_GROUP | ||
if url is None: | ||
is_image_sequence = False | ||
_, images_list = lr._project_client.get_data(lr.data_hash, get_signed_url=True) | ||
if images_list is None: | ||
raise ValueError("Image list should not be none for image groups.") | ||
if frame is None: | ||
raise NotImplementedError( | ||
"Downloading entire image group is not supported. Please contact Encord at [email protected] for help or submit a PR with an implementation." | ||
) | ||
url = images_list[frame]["file_link"] | ||
else: | ||
raise ValueError(f"Couldn't load asset of type {lr.data_type}") | ||
image = images_list[frame] | ||
url = cast(str | None, image.file_link) | ||
|
||
if url is None: | ||
raise ValueError("Failed to get a signed url for the asset") | ||
|
||
response = requests.get(url) | ||
response.raise_for_status() | ||
|
||
suffix = _guess_file_suffix(url, lr) | ||
file_path = Path(lr.data_hash).with_suffix(suffix) | ||
with open(file_path, "wb") as f: | ||
f.write(response.content) | ||
|
||
files_to_unlink = [file_path] | ||
if lr.data_type == DataType.VIDEO and frame is not None: # Get that exact frame | ||
frame_content = get_frame(file_path, frame) | ||
frame_file = file_path.with_name(f"{file_path.name}_{frame}").with_suffix(".png") | ||
cv2.imwrite(frame_file.as_posix(), frame_content) | ||
files_to_unlink.append(frame_file) | ||
file_path = frame_file | ||
try: | ||
with TemporaryDirectory() as dir_name: | ||
dir_path = Path(dir_name) | ||
|
||
_, suffix = _guess_file_suffix(url, lr) | ||
file_path = dir_path / f"{lr.data_hash}{suffix}" | ||
with open(file_path, "wb") as f: | ||
for chunk in response.iter_content(chunk_size=4096): | ||
if chunk: | ||
f.write(chunk) | ||
|
||
if (lr.data_type == DataType.VIDEO or is_image_sequence) and frame is not None: # Get that exact frame | ||
frame_content = get_frame(file_path, frame) | ||
frame_file = file_path.with_name(f"{file_path.name}_{frame}").with_suffix(".png") | ||
cv2.imwrite(frame_file.as_posix(), frame_content) | ||
file_path = frame_file | ||
|
||
yield file_path | ||
finally: | ||
for to_unlink in files_to_unlink: | ||
to_unlink.unlink(missing_ok=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.