Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow reading pose model metadata from local folder #421

Merged
merged 12 commits into from
Oct 3, 2024
25 changes: 11 additions & 14 deletions aeon/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,12 @@ def read(self, file: Path) -> pd.DataFrame:
"""Reads data from the Harp-binarized tracking file."""
# Get config file from `file`, then bodyparts from config file.
model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[-4:])
config_file_dir = Path(self._model_root) / model_dir
config_file_dir = file.parent / model_dir
if not config_file_dir.exists():
raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
config_file_dir = Path(self._model_root) / model_dir
if not config_file_dir.exists():
raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
glopesdev marked this conversation as resolved.
Show resolved Hide resolved

config_file = self.get_config_file(config_file_dir)
identities = self.get_class_names(config_file)
parts = self.get_bodyparts(config_file)
Expand Down Expand Up @@ -350,7 +353,7 @@ def read(self, file: Path) -> pd.DataFrame:
parts = unique_parts

# Set new columns, and reformat `data`.
data = self.class_int2str(data, config_file)
data = self.class_int2str(data, identities)
n_parts = len(parts)
part_data_list = [pd.DataFrame()] * n_parts
new_columns = pd.Series(["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"])
Expand Down Expand Up @@ -407,18 +410,12 @@ def get_bodyparts(config_file: Path) -> list[str]:
return parts

@staticmethod
def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame:
def class_int2str(data: pd.DataFrame, classes: list[str]) -> pd.DataFrame:
"""Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
if config_file.stem == "confmap_config": # SLEAP
with open(config_file) as f:
config = json.load(f)
try:
heads = config["model"]["heads"]
classes = util.find_nested_key(heads, "classes")
except KeyError as err:
raise KeyError(f"Cannot find classes in {config_file}.") from err
for i, subj in enumerate(classes):
data.loc[data["identity"] == i, "identity"] = subj
identity = data["identity"].astype("string")
for i, subj in enumerate(classes):
identity.loc[data[identity.name] == i] = subj
data[identity.name] = identity
glopesdev marked this conversation as resolved.
Show resolved Hide resolved
return data

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion aeon/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def find_nested_key(obj: dict | list, key: str) -> Any:
found = find_nested_key(v, key)
if found:
return found
else:
elif obj is not None:
glopesdev marked this conversation as resolved.
Show resolved Hide resolved
for item in obj:
found = find_nested_key(item, key)
if found:
Expand Down
Loading