diff --git a/aeon/io/reader.py b/aeon/io/reader.py index fda0c8af..53927bf4 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -353,7 +353,7 @@ def read(self, file: Path) -> pd.DataFrame: parts = unique_parts # Set new columns, and reformat `data`. - data = self.class_int2str(data, config_file) + data = self.class_int2str(data, identities) n_parts = len(parts) part_data_list = [pd.DataFrame()] * n_parts new_columns = pd.Series(["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]) @@ -410,18 +410,10 @@ def get_bodyparts(config_file: Path) -> list[str]: return parts @staticmethod - def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame: + def class_int2str(data: pd.DataFrame, classes: list[str]) -> pd.DataFrame: """Converts a class integer in a tracking data dataframe to its associated string (subject id).""" - if config_file.stem == "confmap_config": # SLEAP - with open(config_file) as f: - config = json.load(f) - try: - heads = config["model"]["heads"] - classes = util.find_nested_key(heads, "classes") - except KeyError as err: - raise KeyError(f"Cannot find classes in {config_file}.") from err - for i, subj in enumerate(classes): - data.loc[data["identity"] == i, "identity"] = subj + for i, subj in enumerate(classes): + data.loc[data["identity"] == i, "identity"] = subj return data @classmethod