Skip to content

Commit

Permalink
Avoid iterating over the config file twice
Browse files Browse the repository at this point in the history
  • Loading branch information
glopesdev committed Sep 26, 2024
1 parent 028ffc5 commit 25b7195
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions aeon/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def read(self, file: Path) -> pd.DataFrame:
parts = unique_parts

# Set new columns, and reformat `data`.
data = self.class_int2str(data, config_file)
data = self.class_int2str(data, identities)
n_parts = len(parts)
part_data_list = [pd.DataFrame()] * n_parts
new_columns = pd.Series(["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"])
Expand Down Expand Up @@ -410,18 +410,10 @@ def get_bodyparts(config_file: Path) -> list[str]:
return parts

@staticmethod
def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame:
def class_int2str(data: pd.DataFrame, classes: list[str]) -> pd.DataFrame:
"""Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
if config_file.stem == "confmap_config": # SLEAP
with open(config_file) as f:
config = json.load(f)
try:
heads = config["model"]["heads"]
classes = util.find_nested_key(heads, "classes")
except KeyError as err:
raise KeyError(f"Cannot find classes in {config_file}.") from err
for i, subj in enumerate(classes):
data.loc[data["identity"] == i, "identity"] = subj
for i, subj in enumerate(classes):
data.loc[data["identity"] == i, "identity"] = subj
return data

@classmethod
Expand Down

0 comments on commit 25b7195

Please sign in to comment.