diff --git a/aeon/io/reader.py b/aeon/io/reader.py
index abb6b97e..d44c2995 100644
--- a/aeon/io/reader.py
+++ b/aeon/io/reader.py
@@ -304,18 +304,37 @@ class (int): Int ID of a subject in the environment.
     """

     def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed"):
-        """Pose reader constructor."""
-        # `pattern` for this reader should typically be '_*'
+        """Pose reader constructor.
+
+        The pattern for this reader should typically be `__*`.
+        If a register prefix is required, the pattern should end with a trailing
+        underscore, e.g. `Camera_202_*`. Otherwise, the pattern should include a
+        common prefix for the pose model folder excluding the trailing underscore,
+        e.g. `Camera_model-dir*`.
+        """
         super().__init__(pattern, columns=None)
         self._model_root = model_root
+        self._pattern_offset = pattern.rfind("_") + 1

     def read(self, file: Path) -> pd.DataFrame:
         """Reads data from the Harp-binarized tracking file."""
         # Get config file from `file`, then bodyparts from config file.
-        model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[-4:])
-        config_file_dir = Path(self._model_root) / model_dir
-        if not config_file_dir.exists():
-            raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
+        model_dir = Path(file.stem[self._pattern_offset :].replace("_", "/")).parent
+
+        # Check if model directory exists in local or shared directories.
+        # Local directory is prioritized over shared directory.
+        local_config_file_dir = file.parent / model_dir
+        shared_config_file_dir = Path(self._model_root) / model_dir
+        if local_config_file_dir.exists():
+            config_file_dir = local_config_file_dir
+        elif shared_config_file_dir.exists():
+            config_file_dir = shared_config_file_dir
+        else:
+            raise FileNotFoundError(
+                f"""Cannot find model dir in either local ({local_config_file_dir}) \
+                or shared ({shared_config_file_dir}) directories"""
+            )
+
         config_file = self.get_config_file(config_file_dir)
         identities = self.get_class_names(config_file)
         parts = self.get_bodyparts(config_file)
@@ -350,7 +369,7 @@ def read(self, file: Path) -> pd.DataFrame:
             parts = unique_parts

         # Set new columns, and reformat `data`.
-        data = self.class_int2str(data, config_file)
+        data = self.class_int2str(data, identities)
         n_parts = len(parts)
         part_data_list = [pd.DataFrame()] * n_parts
         new_columns = pd.Series(["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"])
@@ -407,18 +426,12 @@ def get_bodyparts(config_file: Path) -> list[str]:
         return parts

     @staticmethod
-    def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame:
+    def class_int2str(data: pd.DataFrame, classes: list[str]) -> pd.DataFrame:
         """Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
-        if config_file.stem == "confmap_config":  # SLEAP
-            with open(config_file) as f:
-                config = json.load(f)
-            try:
-                heads = config["model"]["heads"]
-                classes = util.find_nested_key(heads, "classes")
-            except KeyError as err:
-                raise KeyError(f"Cannot find classes in {config_file}.") from err
-            for i, subj in enumerate(classes):
-                data.loc[data["identity"] == i, "identity"] = subj
+        if not classes:
+            raise ValueError("Classes list cannot be None or empty.")
+        identity_mapping = dict(enumerate(classes))
+        data["identity"] = data["identity"].replace(identity_mapping)
         return data

     @classmethod
diff --git a/aeon/util.py b/aeon/util.py
index 2251eaad..f3e91b7a 100644
--- a/aeon/util.py
+++ b/aeon/util.py
@@ -14,7 +14,7 @@ def find_nested_key(obj: dict | list, key: str) -> Any:
             found = find_nested_key(v, key)
             if found:
                 return found
-    else:
+    elif obj is not None:
         for item in obj:
             found = find_nested_key(item, key)
             if found:
diff --git a/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_202_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin b/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_202_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin
new file mode 100644
index 00000000..55f13c0f
Binary files /dev/null and b/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_202_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin differ
diff --git a/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin b/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin
new file mode 100644
index 00000000..806424a8
Binary files /dev/null and b/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin differ
diff --git a/tests/data/pose/2024-03-01T16-46-12/CameraTop/test-node1/topdown-multianimal-id-133/confmap_config.json b/tests/data/pose/2024-03-01T16-46-12/CameraTop/test-node1/topdown-multianimal-id-133/confmap_config.json
new file mode 100644
index 00000000..5a2084b0
--- /dev/null
+++ b/tests/data/pose/2024-03-01T16-46-12/CameraTop/test-node1/topdown-multianimal-id-133/confmap_config.json
@@ -0,0 +1,202 @@
+{
+    "data": {
+        "labels": {
+            "training_labels": "social_dev_b5350ff/aeon3_social_dev_b5350ff_ceph.slp",
+            "validation_labels": null,
+            "validation_fraction": 0.1,
+            "test_labels": null,
+            "split_by_inds": false,
+            "training_inds": null,
+            "validation_inds": null,
+            "test_inds": null,
+            "search_path_hints": [],
+            "skeletons": [
+                {
+                    "directed": true,
+                    "graph": {
+                        "name": "Skeleton-1",
+                        "num_edges_inserted": 0
+                    },
+                    "links": [],
+                    "multigraph": true,
+                    "nodes": [
+                        {
+                            "id": {
+                                "py/object": "sleap.skeleton.Node",
+                                "py/state": {
+                                    "py/tuple": [
+                                        "centroid",
+                                        1.0
+                                    ]
+                                }
+                            }
+                        }
+                    ]
+                }
+            ]
+        },
+        "preprocessing": {
+            "ensure_rgb": false,
+            "ensure_grayscale": false,
+            "imagenet_mode": null,
+            "input_scaling": 1.0,
+            "pad_to_stride": 16,
+            "resize_and_pad_to_target": true,
+            "target_height": 1080,
+            "target_width": 1440
+        },
+        "instance_cropping": {
+            "center_on_part": "centroid",
+            "crop_size": 96,
+            "crop_size_detection_padding": 16
+        }
+    },
+    "model": {
+        "backbone": {
+            "leap": null,
+            "unet": {
+                "stem_stride": null,
+                "max_stride": 16,
+                "output_stride": 2,
+                "filters": 16,
+                "filters_rate": 1.5,
+                "middle_block": true,
+                "up_interpolate": false,
+                "stacks": 1
+            },
+            "hourglass": null,
+            "resnet": null,
+            "pretrained_encoder": null
+        },
+        "heads": {
+            "single_instance": null,
+            "centroid": null,
+            "centered_instance": null,
+            "multi_instance": null,
+            "multi_class_bottomup": null,
+            "multi_class_topdown": {
+                "confmaps": {
+                    "anchor_part": "centroid",
+                    "part_names": [
+                        "centroid"
+                    ],
+                    "sigma": 1.5,
+                    "output_stride": 2,
+                    "loss_weight": 1.0,
+                    "offset_refinement": false
+                },
+                "class_vectors": {
+                    "classes": [
+                        "BAA-1104045",
+                        "BAA-1104047"
+                    ],
+                    "num_fc_layers": 3,
+                    "num_fc_units": 256,
+                    "global_pool": true,
+                    "output_stride": 2,
+                    "loss_weight": 0.01
+                }
+            }
+        },
+        "base_checkpoint": null
+    },
+    "optimization": {
+        "preload_data": true,
+        "augmentation_config": {
+            "rotate": true,
+            "rotation_min_angle": -180.0,
+            "rotation_max_angle": 180.0,
+            "translate": false,
+            "translate_min": -5,
+            "translate_max": 5,
+            "scale": false,
+            "scale_min": 0.9,
+            "scale_max": 1.1,
+            "uniform_noise": false,
+            "uniform_noise_min_val": 0.0,
+            "uniform_noise_max_val": 10.0,
+            "gaussian_noise": false,
+            "gaussian_noise_mean": 5.0,
+            "gaussian_noise_stddev": 1.0,
+            "contrast": false,
+            "contrast_min_gamma": 0.5,
+            "contrast_max_gamma": 2.0,
+            "brightness": false,
+            "brightness_min_val": 0.0,
+            "brightness_max_val": 10.0,
+            "random_crop": false,
+            "random_crop_height": 256,
+            "random_crop_width": 256,
+            "random_flip": false,
+            "flip_horizontal": true
+        },
+        "online_shuffling": true,
+        "shuffle_buffer_size": 128,
+        "prefetch": true,
+        "batch_size": 4,
+        "batches_per_epoch": 469,
+        "min_batches_per_epoch": 200,
+        "val_batches_per_epoch": 54,
+        "min_val_batches_per_epoch": 10,
+        "epochs": 200,
+        "optimizer": "adam",
+        "initial_learning_rate": 0.0001,
+        "learning_rate_schedule": {
+            "reduce_on_plateau": true,
+            "reduction_factor": 0.1,
+            "plateau_min_delta": 1e-08,
+            "plateau_patience": 20,
+            "plateau_cooldown": 3,
+            "min_learning_rate": 1e-08
+        },
+        "hard_keypoint_mining": {
+            "online_mining": false,
+            "hard_to_easy_ratio": 2.0,
+            "min_hard_keypoints": 2,
+            "max_hard_keypoints": null,
+            "loss_scale": 5.0
+        },
+        "early_stopping": {
+            "stop_training_on_plateau": true,
+            "plateau_min_delta": 1e-08,
+            "plateau_patience": 20
+        }
+    },
+    "outputs": {
+        "save_outputs": true,
+        "run_name": "aeon3_social_dev_b5350ff_ceph_topdown_top.centered_instance_multiclass",
+        "run_name_prefix": "",
+        "run_name_suffix": "",
+        "runs_folder": "social_dev_b5350ff/models",
+        "tags": [],
+        "save_visualizations": true,
+        "delete_viz_images": true,
+        "zip_outputs": false,
+        "log_to_csv": true,
+        "checkpointing": {
+            "initial_model": true,
+            "best_model": true,
+            "every_epoch": false,
+            "latest_model": false,
+            "final_model": false
+        },
+        "tensorboard": {
+            "write_logs": false,
+            "loss_frequency": "epoch",
+            "architecture_graph": false,
+            "profile_graph": false,
+            "visualizations": true
+        },
+        "zmq": {
+            "subscribe_to_controller": false,
+            "controller_address": "tcp://127.0.0.1:9000",
+            "controller_polling_timeout": 10,
+            "publish_updates": false,
+            "publish_address": "tcp://127.0.0.1:9001"
+        }
+    },
+    "name": "",
+    "description": "",
+    "sleap_version": "1.3.1",
+    "filename": "Z:/aeon/data/processed/test-node1/4310907/2024-01-12T19-00-00/topdown-multianimal-id-133/confmap_config.json"
+}
\ No newline at end of file
diff --git a/tests/io/test_reader.py b/tests/io/test_reader.py
new file mode 100644
index 00000000..640768ab
--- /dev/null
+++ b/tests/io/test_reader.py
@@ -0,0 +1,25 @@
+from pathlib import Path
+
+import pytest
+from pytest import mark
+
+import aeon
+from aeon.schema.schemas import social02, social03
+
+pose_path = Path(__file__).parent.parent / "data" / "pose"
+
+
+@mark.api
+def test_Pose_read_local_model_dir():
+    data = aeon.load(pose_path, social02.CameraTop.Pose)
+    assert len(data) > 0
+
+
+@mark.api
+def test_Pose_read_local_model_dir_with_register_prefix():
+    data = aeon.load(pose_path, social03.CameraTop.Pose)
+    assert len(data) > 0
+
+
+if __name__ == "__main__":
+    pytest.main()
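
Note on the model-directory resolution added to `Pose.read` above: the offset past the pattern's last underscore determines how much of the binarized file's stem is interpreted as the pose-model folder. Below is a minimal illustrative sketch of that behaviour using the test fixture paths from this diff; the `pattern` string is an assumption inferred from the docstring convention and the fixture filenames (the actual pattern comes from the `social03` schema, which is not shown here).

```python
from pathlib import Path

# Assumed example values, mirroring the register-prefixed test fixture above.
pattern = "CameraTop_202_*"
file = Path(
    "tests/data/pose/2024-03-01T16-46-12/CameraTop/"
    "CameraTop_202_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin"
)
model_root = Path("/ceph/aeon/aeon/data/processed")

# As in Pose.__init__: offset just past the last underscore of the pattern.
pattern_offset = pattern.rfind("_") + 1  # 14 == len("CameraTop_202_")

# As in Pose.read: strip the prefix, turn underscores into path separators,
# then drop the trailing acquisition-timestamp component.
model_dir = Path(file.stem[pattern_offset:].replace("_", "/")).parent
print(model_dir)  # test-node1/topdown-multianimal-id-133

# The local model dir (next to the .bin file) takes precedence over the shared root.
local_config_file_dir = file.parent / model_dir
shared_config_file_dir = model_root / model_dir
config_file_dir = (
    local_config_file_dir if local_config_file_dir.exists() else shared_config_file_dir
)
print(config_file_dir)
```

This mirrors the `if/elif/else` added in `Pose.read`: the directory alongside the tracking file is checked first, falling back to the shared `model_root`, and a `FileNotFoundError` is raised only when neither exists.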