diff --git a/morpho_symm/data/DynamicsRecording.py b/morpho_symm/data/DynamicsRecording.py index 582846e..6809c74 100644 --- a/morpho_symm/data/DynamicsRecording.py +++ b/morpho_symm/data/DynamicsRecording.py @@ -33,31 +33,27 @@ class DynamicsRecording: # Map from observation name to the observation moments (mean, var) of the recordings. obs_moments: Dict[str, tuple] = field(default_factory=dict) - def save_to_file(self, file_path: Path): - # Store representations and groups without serializing - if len(self.obs_representations) > 0: - self._obs_rep_irreps = {} - self._obs_rep_names = {} - self._obs_rep_Q = {} - for k, rep in self.obs_representations.items(): - self._obs_rep_irreps[k] = rep.irreps if rep is not None else None - self._obs_rep_names[k] = rep.name if rep is not None else None - self._obs_rep_Q[k] = rep.change_of_basis if rep is not None else None - group = self.obs_representations[self.state_obs[0]].group - self._group_keys = group._keys - self._group_name = group.__class__.__name__ - # Remove non-serializable objects - del self.obs_representations - self.dynamics_parameters.pop('group', None) - - with file_path.with_suffix(".pkl").open('wb') as file: - pickle.dump(self, file, protocol=pickle.HIGHEST_PROTOCOL) - @property def obs_dims(self): """Dictionary providing the map between observation name and observation dimension.""" return {k: v.shape[-1] for k, v in self.recordings.items()} + def state_representations(self) -> list[Representation]: + """Return the ordered list of representations of the state vector.""" + return [self.obs_representations[m] for m in self.state_obs] + + def state_moments(self) -> [np.ndarray, np.ndarray]: + """Compute the mean and standard deviation of the state observations.""" + mean, var = [], [] + for obs_name in self.state_obs: + if obs_name not in self.obs_moments.keys(): + self.compute_obs_moments(obs_name) + obs_mean, obs_var = self.obs_moments[obs_name] + mean.append(obs_mean) + var.append(obs_var) + mean, var = np.concatenate(mean), np.concatenate(var) + return mean, var + def compute_obs_moments(self, obs_name: str) -> [np.ndarray, np.ndarray]: """Compute the mean and standard deviation of observations.""" assert obs_name in self.recordings.keys(), f"Observation {obs_name} not found in recordings" @@ -115,19 +111,19 @@ def compute_obs_moments(self, obs_name: str) -> [np.ndarray, np.ndarray]: Cov = Q @ np.diag(var_irrep_basis) @ Q_inv var = np.diagonal(Cov) - # # TODO: Move this check to Unit test as it is computationally demanding to check this at runtime. - # # Ensure the mean is equivalent to computing the mean of the orbit of the recording under the group action - # aug_obs = [] - # for g in G.elements: - # g_obs = np.einsum('...ij,...j->...i', rep_obs(g), obs_original_basis) - # aug_obs.append(g_obs) - # - # aug_obs = np.concatenate(aug_obs, axis=0) # Append over the trajectory dimension - # mean_emp = np.mean(aug_obs, axis=(0, 1)) - # assert np.allclose(mean, mean_emp, rtol=1e-3, atol=1e-3), f"Mean {mean} != {mean_emp}" - # - # var_emp = np.var(aug_obs, axis=(0, 1)) - # assert np.allclose(var, var_emp, rtol=1e-2, atol=1e-2), f"Var {var} != {var_emp}" + # TODO: Move this check to Unit test as it is computationally demanding to check this at runtime. + # Ensure the mean is equivalent to computing the mean of the orbit of the recording under the group action + aug_obs = [] + for g in G.elements: + g_obs = np.einsum('...ij,...j->...i', rep_obs(g), obs_original_basis) + aug_obs.append(g_obs) + + aug_obs = np.concatenate(aug_obs, axis=0) # Append over the trajectory dimension + mean_emp = np.mean(aug_obs, axis=(0, 1)) + assert np.allclose(mean, mean_emp, rtol=1e-3, atol=1e-3), f"Mean {mean} != {mean_emp}" + + var_emp = np.var(aug_obs, axis=(0, 1)) + assert np.allclose(var, var_emp, rtol=1e-2, atol=1e-2), f"Var {var} != {var_emp}" else: mean = np.mean(np.asarray(self.recordings[obs_name]), axis=(0, 1)) var = np.var(np.asarray(self.recordings[obs_name]), axis=(0, 1)) @@ -136,22 +132,6 @@ def compute_obs_moments(self, obs_name: str) -> [np.ndarray, np.ndarray]: self.obs_moments[obs_name] = mean, var - def state_moments(self) -> [np.ndarray, np.ndarray]: - """Compute the mean and standard deviation of the state observations.""" - mean, var = [], [] - for obs_name in self.state_obs: - if obs_name not in self.obs_moments.keys(): - self.compute_obs_moments(obs_name) - obs_mean, obs_var = self.obs_moments[obs_name] - mean.append(obs_mean) - var.append(obs_var) - mean, var = np.concatenate(mean), np.concatenate(var) - return mean, var - - def state_representations(self) -> list[Representation]: - """Return the ordered list of representations of the state vector.""" - return [self.obs_representations[m] for m in self.state_obs] - def get_state_trajs(self, standardize: bool = False): """Returns a single array containing the concatenated state observations trajectories. @@ -165,11 +145,32 @@ def get_state_trajs(self, standardize: bool = False): obs = [self.recordings[obs_name] for obs_name in self.state_obs] state_trajs = np.concatenate(obs, axis=-1) if standardize: - state_mean, state_std = self.state_moments() + state_mean, var = self.state_moments() + state_std = np.sqrt(var) state_trajs = (state_trajs - state_mean) / state_std return state_trajs + def save_to_file(self, file_path: Path): + # Store representations and groups without serializing + if len(self.obs_representations) > 0: + self._obs_rep_irreps = {} + self._obs_rep_names = {} + self._obs_rep_Q = {} + for k, rep in self.obs_representations.items(): + self._obs_rep_irreps[k] = rep.irreps if rep is not None else None + self._obs_rep_names[k] = rep.name if rep is not None else None + self._obs_rep_Q[k] = rep.change_of_basis if rep is not None else None + group = self.obs_representations[self.state_obs[0]].group + self._group_keys = group._keys + self._group_name = group.__class__.__name__ + # Remove non-serializable objects + del self.obs_representations + self.dynamics_parameters.pop('group', None) + + with file_path.with_suffix(".pkl").open('wb') as file: + pickle.dump(self, file, protocol=pickle.HIGHEST_PROTOCOL) + @staticmethod def load_from_file(file_path: Path, only_metadata=False, obs_names: Optional[Iterable[str]] = None) -> 'DynamicsRecording': @@ -180,6 +181,7 @@ def load_from_file(file_path: Path, only_metadata=False, else: data_obs_names = list(data.recordings.keys()) if obs_names is not None: + data.state_obs = tuple(obs_names) irrelevant_obs = [k for k in data_obs_names if k not in obs_names and obs_names is not None] for k in irrelevant_obs: del data.recordings[k] diff --git a/morpho_symm/data/mini_cheetah/raysim_recordings/uneven_easy/forward_minus_0_4/n_trajs=1-frames=12453-train.pkl b/morpho_symm/data/mini_cheetah/raysim_recordings/uneven_easy/forward_minus_0_4/n_trajs=1-frames=12453-train.pkl deleted file mode 100644 index 4e5ce6f..0000000 Binary files a/morpho_symm/data/mini_cheetah/raysim_recordings/uneven_easy/forward_minus_0_4/n_trajs=1-frames=12453-train.pkl and /dev/null differ diff --git a/morpho_symm/data/mini_cheetah/raysim_recordings/uneven_easy/forward_minus_0_4/n_trajs=8-frames=2668-test.pkl b/morpho_symm/data/mini_cheetah/raysim_recordings/uneven_easy/forward_minus_0_4/n_trajs=8-frames=2668-test.pkl deleted file mode 100644 index 870ee19..0000000 Binary files a/morpho_symm/data/mini_cheetah/raysim_recordings/uneven_easy/forward_minus_0_4/n_trajs=8-frames=2668-test.pkl and /dev/null differ diff --git a/morpho_symm/data/mini_cheetah/raysim_recordings/uneven_easy/forward_minus_0_4/n_trajs=8-frames=2668-val.pkl b/morpho_symm/data/mini_cheetah/raysim_recordings/uneven_easy/forward_minus_0_4/n_trajs=8-frames=2668-val.pkl deleted file mode 100644 index dfa3197..0000000 Binary files a/morpho_symm/data/mini_cheetah/raysim_recordings/uneven_easy/forward_minus_0_4/n_trajs=8-frames=2668-val.pkl and /dev/null differ