Skip to content

Commit

Permalink
Fix issue with standardization
Browse files Browse the repository at this point in the history
  • Loading branch information
Danfoa committed Apr 3, 2024
1 parent bb8f52d commit 82a3573
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 50 deletions.
102 changes: 52 additions & 50 deletions morpho_symm/data/DynamicsRecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,27 @@ class DynamicsRecording:
# Map from observation name to the observation moments (mean, var) of the recordings.
obs_moments: Dict[str, tuple] = field(default_factory=dict)

def save_to_file(self, file_path: Path):
    """Pickle this recording to ``file_path`` with its suffix replaced by ``.pkl``.

    Representation and group objects are not pickled directly. When
    ``obs_representations`` is non-empty, every representation is reduced to its
    ``irreps``, ``name`` and ``change_of_basis`` (``None`` entries stay ``None``)
    and the group of the first state observation is reduced to its ``_keys`` and
    class name. These values are stored on the pickled object under the
    ``_obs_rep_*`` and ``_group_*`` attributes, while ``obs_representations`` and
    the ``'group'`` entry of ``dynamics_parameters`` are left out of the pickle.

    The stripping is done on a shallow copy, so this instance keeps its
    ``obs_representations`` and ``dynamics_parameters`` untouched and can still
    be used (or saved again) after the call.

    Args:
        file_path: Destination path; the data is written to ``file_path.with_suffix(".pkl")``.
    """
    import copy  # Local import: only needed here to snapshot the instance.

    # Shallow copy: attribute values (e.g. the recordings) are shared, not duplicated.
    snapshot = copy.copy(self)
    reps = self.obs_representations
    if len(reps) > 0:
        # Store representations and groups without serializing
        snapshot._obs_rep_irreps = {k: (rep.irreps if rep is not None else None) for k, rep in reps.items()}
        snapshot._obs_rep_names = {k: (rep.name if rep is not None else None) for k, rep in reps.items()}
        snapshot._obs_rep_Q = {k: (rep.change_of_basis if rep is not None else None) for k, rep in reps.items()}
        group = reps[self.state_obs[0]].group
        snapshot._group_keys = group._keys
        snapshot._group_name = group.__class__.__name__
        # Remove non-serializable objects from the snapshot only.
        del snapshot.obs_representations
        # Copy the parameters before popping so the caller's mapping keeps its 'group' entry.
        params = copy.copy(self.dynamics_parameters)
        params.pop('group', None)
        snapshot.dynamics_parameters = params

    with file_path.with_suffix(".pkl").open('wb') as file:
        pickle.dump(snapshot, file, protocol=pickle.HIGHEST_PROTOCOL)

@property
def obs_dims(self):
    """Map every recorded observation name to the size of the last axis of its recording."""
    dims = {}
    for obs_name, traj in self.recordings.items():
        dims[obs_name] = traj.shape[-1]
    return dims

def state_representations(self) -> list[Representation]:
    """Collect the representations of the state observations, in the order given by ``state_obs``."""
    reps = []
    for obs_name in self.state_obs:
        reps.append(self.obs_representations[obs_name])
    return reps

def state_moments(self) -> tuple[np.ndarray, np.ndarray]:
    """Return the mean and variance of the state, concatenated over the state observations.

    Per-observation moments are read from ``obs_moments``; an observation without
    an entry there is first processed with ``compute_obs_moments``. The second
    output is the *variance*, not the standard deviation: take ``np.sqrt`` of it
    when a standard deviation is needed (e.g. for standardization).

    Returns:
        A ``(mean, var)`` pair of arrays, each being the concatenation of the
        per-observation moments in the order given by ``state_obs``.
    """
    means, variances = [], []
    for obs_name in self.state_obs:
        if obs_name not in self.obs_moments:
            self.compute_obs_moments(obs_name)
        obs_mean, obs_var = self.obs_moments[obs_name]
        means.append(obs_mean)
        variances.append(obs_var)
    return np.concatenate(means), np.concatenate(variances)

def compute_obs_moments(self, obs_name: str) -> [np.ndarray, np.ndarray]:
"""Compute the mean and standard deviation of observations."""
assert obs_name in self.recordings.keys(), f"Observation {obs_name} not found in recordings"
Expand Down Expand Up @@ -115,19 +111,19 @@ def compute_obs_moments(self, obs_name: str) -> [np.ndarray, np.ndarray]:
Cov = Q @ np.diag(var_irrep_basis) @ Q_inv
var = np.diagonal(Cov)

# # TODO: Move this check to Unit test as it is computationally demanding to check this at runtime.
# # Ensure the mean is equivalent to computing the mean of the orbit of the recording under the group action
# aug_obs = []
# for g in G.elements:
# g_obs = np.einsum('...ij,...j->...i', rep_obs(g), obs_original_basis)
# aug_obs.append(g_obs)
#
# aug_obs = np.concatenate(aug_obs, axis=0) # Append over the trajectory dimension
# mean_emp = np.mean(aug_obs, axis=(0, 1))
# assert np.allclose(mean, mean_emp, rtol=1e-3, atol=1e-3), f"Mean {mean} != {mean_emp}"
#
# var_emp = np.var(aug_obs, axis=(0, 1))
# assert np.allclose(var, var_emp, rtol=1e-2, atol=1e-2), f"Var {var} != {var_emp}"
# TODO: Move this check to Unit test as it is computationally demanding to check this at runtime.
# Ensure the mean is equivalent to computing the mean of the orbit of the recording under the group action
aug_obs = []
for g in G.elements:
g_obs = np.einsum('...ij,...j->...i', rep_obs(g), obs_original_basis)
aug_obs.append(g_obs)

aug_obs = np.concatenate(aug_obs, axis=0) # Append over the trajectory dimension
mean_emp = np.mean(aug_obs, axis=(0, 1))
assert np.allclose(mean, mean_emp, rtol=1e-3, atol=1e-3), f"Mean {mean} != {mean_emp}"

var_emp = np.var(aug_obs, axis=(0, 1))
assert np.allclose(var, var_emp, rtol=1e-2, atol=1e-2), f"Var {var} != {var_emp}"
else:
mean = np.mean(np.asarray(self.recordings[obs_name]), axis=(0, 1))
var = np.var(np.asarray(self.recordings[obs_name]), axis=(0, 1))
Expand All @@ -136,22 +132,6 @@ def compute_obs_moments(self, obs_name: str) -> [np.ndarray, np.ndarray]:

self.obs_moments[obs_name] = mean, var

def state_moments(self) -> tuple[np.ndarray, np.ndarray]:
    """Return the mean and variance of the state, concatenated over the state observations.

    Per-observation moments are read from ``obs_moments``; an observation without
    an entry there is first processed with ``compute_obs_moments``. The second
    output is the *variance*, not the standard deviation: take ``np.sqrt`` of it
    when a standard deviation is needed (e.g. for standardization).

    Returns:
        A ``(mean, var)`` pair of arrays, each being the concatenation of the
        per-observation moments in the order given by ``state_obs``.
    """
    means, variances = [], []
    for obs_name in self.state_obs:
        if obs_name not in self.obs_moments:
            self.compute_obs_moments(obs_name)
        obs_mean, obs_var = self.obs_moments[obs_name]
        means.append(obs_mean)
        variances.append(obs_var)
    return np.concatenate(means), np.concatenate(variances)

def state_representations(self) -> list[Representation]:
    """Collect the representations of the state observations, in the order given by ``state_obs``."""
    reps = []
    for obs_name in self.state_obs:
        reps.append(self.obs_representations[obs_name])
    return reps

def get_state_trajs(self, standardize: bool = False):
"""Returns a single array containing the concatenated state observations trajectories.
Expand All @@ -165,11 +145,32 @@ def get_state_trajs(self, standardize: bool = False):
obs = [self.recordings[obs_name] for obs_name in self.state_obs]
state_trajs = np.concatenate(obs, axis=-1)
if standardize:
state_mean, state_std = self.state_moments()
state_mean, var = self.state_moments()
state_std = np.sqrt(var)
state_trajs = (state_trajs - state_mean) / state_std

return state_trajs

def save_to_file(self, file_path: Path):
    """Pickle this recording to ``file_path`` with its suffix replaced by ``.pkl``.

    Representation and group objects are not pickled directly. When
    ``obs_representations`` is non-empty, every representation is reduced to its
    ``irreps``, ``name`` and ``change_of_basis`` (``None`` entries stay ``None``)
    and the group of the first state observation is reduced to its ``_keys`` and
    class name. These values are stored on the pickled object under the
    ``_obs_rep_*`` and ``_group_*`` attributes, while ``obs_representations`` and
    the ``'group'`` entry of ``dynamics_parameters`` are left out of the pickle.

    The stripping is done on a shallow copy, so this instance keeps its
    ``obs_representations`` and ``dynamics_parameters`` untouched and can still
    be used (or saved again) after the call.

    Args:
        file_path: Destination path; the data is written to ``file_path.with_suffix(".pkl")``.
    """
    import copy  # Local import: only needed here to snapshot the instance.

    # Shallow copy: attribute values (e.g. the recordings) are shared, not duplicated.
    snapshot = copy.copy(self)
    reps = self.obs_representations
    if len(reps) > 0:
        # Store representations and groups without serializing
        snapshot._obs_rep_irreps = {k: (rep.irreps if rep is not None else None) for k, rep in reps.items()}
        snapshot._obs_rep_names = {k: (rep.name if rep is not None else None) for k, rep in reps.items()}
        snapshot._obs_rep_Q = {k: (rep.change_of_basis if rep is not None else None) for k, rep in reps.items()}
        group = reps[self.state_obs[0]].group
        snapshot._group_keys = group._keys
        snapshot._group_name = group.__class__.__name__
        # Remove non-serializable objects from the snapshot only.
        del snapshot.obs_representations
        # Copy the parameters before popping so the caller's mapping keeps its 'group' entry.
        params = copy.copy(self.dynamics_parameters)
        params.pop('group', None)
        snapshot.dynamics_parameters = params

    with file_path.with_suffix(".pkl").open('wb') as file:
        pickle.dump(snapshot, file, protocol=pickle.HIGHEST_PROTOCOL)

@staticmethod
def load_from_file(file_path: Path, only_metadata=False,
obs_names: Optional[Iterable[str]] = None) -> 'DynamicsRecording':
Expand All @@ -180,6 +181,7 @@ def load_from_file(file_path: Path, only_metadata=False,
else:
data_obs_names = list(data.recordings.keys())
if obs_names is not None:
data.state_obs = tuple(obs_names)
irrelevant_obs = [k for k in data_obs_names if k not in obs_names and obs_names is not None]
for k in irrelevant_obs:
del data.recordings[k]
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 82a3573

Please sign in to comment.