diff --git a/tti_eval/constants.py b/tti_eval/constants.py index 33e2e42..4665f21 100644 --- a/tti_eval/constants.py +++ b/tti_eval/constants.py @@ -8,9 +8,7 @@ # If the cache directory is not explicitly specified, use the `.cache` directory located in the project's root. _TTI_EVAL_ROOT_DIR = Path(__file__).parent.parent CACHE_PATH = Path(os.environ.get("TTI_EVAL_CACHE_PATH", _TTI_EVAL_ROOT_DIR / ".cache")) -_OUTPUT_PATH = Path( - os.environ.get("TTI_EVAL_OUTPUT_PATH", _TTI_EVAL_ROOT_DIR / "output") -) +_OUTPUT_PATH = Path(os.environ.get("TTI_EVAL_OUTPUT_PATH", _TTI_EVAL_ROOT_DIR / "output")) _SOURCES_PATH = _TTI_EVAL_ROOT_DIR / "sources" diff --git a/tti_eval/dataset/types/encord_ds.py b/tti_eval/dataset/types/encord_ds.py index 4c70b21..0e7dd96 100644 --- a/tti_eval/dataset/types/encord_ds.py +++ b/tti_eval/dataset/types/encord_ds.py @@ -1,6 +1,10 @@ import json +import multiprocessing import os +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from functools import partial from pathlib import Path from typing import Any @@ -30,7 +34,13 @@ def __init__( ssh_key_path: str | None = None, **kwargs, ): - super().__init__(title, split=split, title_in_source=title_in_source, transform=transform, cache_dir=cache_dir) + super().__init__( + title, + split=split, + title_in_source=title_in_source, + transform=transform, + cache_dir=cache_dir, + ) self._setup(project_hash, classification_hash, ssh_key_path, **kwargs) def __getitem__(self, idx): @@ -191,6 +201,37 @@ def _download_label_row_image_data(data_dir: Path, project: Project, label_row: ) +def _download_label_row( + label_row: LabelRowV2, + project: Project, + data_dir: Path, + overwrite_annotations: bool, + label_rows_info: dict[str, Any], + update_pbar: Callable[[], Any], +): + if label_row.data_type not in {DataType.IMAGE, DataType.IMG_GROUP}: + return + save_annotations = False + # Trigger the images download if the label hash is not found or is None (never downloaded). + if label_row.label_hash not in label_rows_info.keys(): + _download_label_row_image_data(data_dir, project, label_row) + save_annotations = True + # Overwrite annotations only if `last_edited_at` values differ between the existing and new annotations. + elif ( + overwrite_annotations + and label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT) + != label_rows_info[label_row.label_hash]["last_edited_at"] + ): + label_row.initialise_labels() + save_annotations = True + + if save_annotations: + annotations_file = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash) + annotations_file.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8") + label_rows_info[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at} + update_pbar() + + def _download_label_rows( project: Project, data_dir: Path, @@ -202,27 +243,18 @@ def _download_label_rows( if tqdm_desc is None: tqdm_desc = f"Downloading data from Encord project `{project.title}`" - for label_row in tqdm(label_rows, desc=tqdm_desc): - if label_row.data_type not in {DataType.IMAGE, DataType.IMG_GROUP}: - continue - save_annotations = False - # Trigger the images download if the label hash is not found or is None (never downloaded). - if label_row.label_hash not in label_rows_info.keys(): - _download_label_row_image_data(data_dir, project, label_row) - save_annotations = True - # Overwrite annotations only if `last_edited_at` values differ between the existing and new annotations. - elif ( - overwrite_annotations - and label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT) - != label_rows_info[label_row.label_hash]["last_edited_at"] - ): - label_row.initialise_labels() - save_annotations = True - - if save_annotations: - annotations_file = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash) - annotations_file.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8") - label_rows_info[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at} + pbar = tqdm(total=len(label_rows), desc=tqdm_desc) + _do_download = partial( + _download_label_row, + project=project, + data_dir=data_dir, + overwrite_annotations=overwrite_annotations, + label_rows_info=label_rows_info, + update_pbar=lambda: pbar.update(1), + ) + + with ThreadPoolExecutor(min(multiprocessing.cpu_count(), 24)) as exe: + exe.map(_do_download, label_rows) def download_data_from_project( @@ -293,7 +325,11 @@ def get_frame_file(data_dir: Path, project_hash: str, label_row: LabelRowV2, fra def get_frame_file_raw( - data_dir: Path, project_hash: str, label_row_hash: str, frame_hash: str, frame_title: str + data_dir: Path, + project_hash: str, + label_row_hash: str, + frame_hash: str, + frame_title: str, ) -> Path: return get_label_row_dir(data_dir, project_hash, label_row_hash) / get_frame_name(frame_hash, frame_title) diff --git a/tti_eval/evaluation/image_retrieval.py b/tti_eval/evaluation/image_retrieval.py index 78308ff..b8937f0 100644 --- a/tti_eval/evaluation/image_retrieval.py +++ b/tti_eval/evaluation/image_retrieval.py @@ -56,9 +56,7 @@ def evaluate(self) -> float: # To compute retrieval accuracy, we ensure that a maximum of Q elements per sample are retrieved, # where Q represents the size of the respective class in the validation embeddings - top_nearest_per_class = np.where( - self._class_counts < self.k, self._class_counts, self.k - ) + top_nearest_per_class = np.where(self._class_counts < self.k, self._class_counts, self.k) top_nearest_per_sample = top_nearest_per_class[self._train_embeddings.labels] # Add a placeholder value for indices outside the retrieval scope diff --git a/tti_eval/evaluation/knn.py b/tti_eval/evaluation/knn.py index f441d39..b6c0aee 100644 --- a/tti_eval/evaluation/knn.py +++ b/tti_eval/evaluation/knn.py @@ -59,9 +59,7 @@ def predict(self) -> tuple[ProbabilityArray, ClassArray]: # Calculate class votes from the distances (avoiding division by zero) # Note: Values stored in `dists` are the squared 2-norm values of the respective distance vectors max_value = np.finfo(np.float32).max - scores = np.divide( - 1, dists, out=np.full_like(dists, max_value), where=dists != 0 - ) + scores = np.divide(1, dists, out=np.full_like(dists, max_value), where=dists != 0) # NOTE: if self.k and self.num_classes are both large, this might become a big one. # We can shape of a factor self.k if we count differently here. n = len(self._val_embeddings.images)