From f803459a9e73662910e62a1953cb96bdd29546fa Mon Sep 17 00:00:00 2001 From: Jai Bhagat Date: Thu, 14 Sep 2023 17:03:25 +0000 Subject: [PATCH] Merge pull request #253 from SainsburyWellcomeCentre/update_pose_reader --- aeon/schema/social.py | 52 ++++++++++++++++++++++++++++--------------- pyproject.toml | 1 - 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/aeon/schema/social.py b/aeon/schema/social.py index a16ac252..e5d6efec 100644 --- a/aeon/schema/social.py +++ b/aeon/schema/social.py @@ -1,13 +1,13 @@ """Readers for data relevant to Social experiments.""" -from pathlib import Path -from typing import List, Union import json +from pathlib import Path +import numpy as np import pandas as pd -from aeon import util import aeon.io.reader as _reader +from aeon import util class Pose(_reader.Harp): @@ -25,63 +25,79 @@ def __init__(self, pattern: str, extension: str="bin"): # `pattern` for this reader should typically be '_*' super().__init__(pattern, columns=None, extension=extension) - def read(self, file: Path, ceph_proc_dir: Path=Path("/ceph/aeon/aeon/data/processed")) -> pd.DataFrame: + def read( + self, file: Path, ceph_proc_dir: str | Path = "/ceph/aeon/aeon/data/processed" + ) -> pd.DataFrame: """Reads data from the Harp-binarized tracking file.""" # Get config file from `file`, then bodyparts from config file. model_dir = Path(file.stem.replace("_", "/")).parent config_file_dir = ceph_proc_dir / model_dir - assert config_file_dir.exists(), f"Cannot find model dir {config_file_dir}" + if not config_file_dir.exists(): + raise FileNotFoundError(f"Cannot find model dir {config_file_dir}") config_file = get_config_file(config_file_dir) parts = self.get_bodyparts(config_file) - + # Using bodyparts, assign column names to Harp register values, and read data in default format. columns = ["class", "class_likelihood"] for part in parts: columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"]) self.columns = columns data = super().read(file) - + + # Drop any repeat parts. + unique_parts, unique_idxs = np.unique(parts, return_index=True) + repeat_idxs = np.setdiff1d(np.arange(len(parts)), unique_idxs) + if repeat_idxs: # drop x, y, and likelihood cols for repeat parts (skip first 5 cols) + init_rep_part_col_idx = (repeat_idxs - 1) * 3 + 5 + rep_part_col_idxs = np.concatenate([np.arange(i, i + 3) for i in init_rep_part_col_idx]) + keep_part_col_idxs = np.setdiff1d(np.arange(len(data.columns)), rep_part_col_idxs) + data = data.iloc[:, keep_part_col_idxs] + parts = unique_parts + # Set new columns, and reformat `data`. n_parts = len(parts) - part_data_list = [None] * n_parts + part_data_list = [pd.DataFrame()] * n_parts new_columns = ["class", "class_likelihood", "part", "x", "y", "part_likelihood"] new_data = pd.DataFrame(columns=new_columns) for i, part in enumerate(parts): part_columns = ["class", "class_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"] - part_data = data[part_columns] + part_data = pd.DataFrame(data[part_columns]) part_data.insert(2, "part", part) part_data.columns = new_columns part_data_list[i] = part_data new_data = pd.concat(part_data_list) return new_data.sort_index() - def get_bodyparts(self, file: Path) -> Union[None, List[str]]: + def get_bodyparts(self, file: Path) -> list[str]: """Returns a list of bodyparts from a model's config file.""" - parts = None + parts = [] with open(file) as f: config = json.load(f) if file.stem == "confmap_config": # SLEAP try: heads = config["model"]["heads"] - parts = util.find_nested_key(heads, "part_names") + parts = [util.find_nested_key(heads, "anchor_part")] + parts += util.find_nested_key(heads, "part_names") except KeyError as err: - raise KeyError(f"Cannot find bodyparts in {file}.") from err + if not parts: + raise KeyError(f"Cannot find bodyparts in {file}.") from err return parts def get_config_file( config_file_dir: Path, - config_file_names: List[str]=[ - "confmap_config.json", # SLEAP (add others for other trackers to this list) - ], -): + config_file_names: None | list[str] = None, +) -> Path: """Returns the config file from a model's config directory.""" + if config_file_names is None: + config_file_names = ["confmap_config.json"] # SLEAP (add for other trackers to this list) config_file = None for f in config_file_names: if (config_file_dir / f).exists(): config_file = config_file_dir / f break - assert config_file is not None, f"Cannot find config file in {config_file_dir}" + if config_file is None: + raise FileNotFoundError(f"Cannot find config file in {config_file_dir}") return config_file diff --git a/pyproject.toml b/pyproject.toml index 11528611..075e1ef4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,7 +130,6 @@ reportAssertAlwaysTrue = "error" reportSelfClsParameterName = "error" reportUnusedExpression = "error" reportMatchNotExhaustive = "error" -reportImplicitOverride = "error" reportShadowedImports = "error" # *Note*: we may want to set all 'ReportOptional*' rules to "none", but leaving 'em default for now venvPath = "."