Skip to content

Commit

Permalink
Updates to DynamicsRecordings
Browse files Browse the repository at this point in the history
  • Loading branch information
Danfoa committed Apr 2, 2024
1 parent 024d666 commit bb8f52d
Showing 1 changed file with 39 additions and 17 deletions.
56 changes: 39 additions & 17 deletions morpho_symm/data/DynamicsRecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,20 +148,27 @@ def state_moments(self) -> [np.ndarray, np.ndarray]:
mean, var = np.concatenate(mean), np.concatenate(var)
return mean, var

def action_moments(self) -> [np.ndarray, np.ndarray]:
"""Compute the mean and standard deviation of the action observations."""
if len(self.action_obs) == 0:
return None, None
def state_representations(self) -> list[Representation]:
"""Return the ordered list of representations of the state vector."""
return [self.obs_representations[m] for m in self.state_obs]

mean, var = [], []
for obs_name in self.action_obs:
if obs_name not in self.obs_moments.keys():
self.compute_obs_moments(obs_name)
obs_mean, obs_var = self.obs_moments[obs_name]
mean.append(obs_mean)
var.append(obs_var)
mean, var = np.concatenate(mean), np.concatenate(var)
return mean, var
def get_state_trajs(self, standardize: bool = False):
"""Returns a single array containing the concatenated state observations trajectories.
Given the state observations `self.state_obs` this method concatenates the trajectories of each observation
into a single array of shape [traj, time, state_dim]. If standardize is set to True, the state observations
are standardized to have zero mean and unit variance.
Returns:
A single array containing the concatenated state observations trajectories of shape [traj, time, state_dim].
"""
obs = [self.recordings[obs_name] for obs_name in self.state_obs]
state_trajs = np.concatenate(obs, axis=-1)
if standardize:
state_mean, state_std = self.state_moments()
state_trajs = (state_trajs - state_mean) / state_std

return state_trajs

@staticmethod
def load_from_file(file_path: Path, only_metadata=False,
Expand Down Expand Up @@ -385,18 +392,33 @@ def reduce_dataset_size(recordings: Iterable[DynamicsRecording], train_ratio: fl
new_recordings = {k: v[idx_to_keep] for k, v in r.recordings.items()}
r.recordings = new_recordings


def split_train_val_test(
dyn_recording: DynamicsRecording, partition_sizes=(0.70, 0.15, 0.15)) -> tuple[DynamicsRecording]:
dyn_recording: DynamicsRecording,
partition_sizes=(0.70, 0.15, 0.15),
split_time: bool = True
) -> [DynamicsRecording, DynamicsRecording, DynamicsRecording]:
"""Split the recordings into training, validation and test sets.
Args:
dyn_recording: (DynamicsRecording): The recordings to split.
partition_sizes: (tuple): The sizes of the training, validation and test sets.
split_time: (bool): If True, the split is done along the time dimension. Otherwise, the split is done along the
trajectory dimension.
Returns:
(DynamicsRecording, DynamicsRecording, DynamicsRecording): The training, validation and test sets.
"""

assert np.isclose(np.sum(partition_sizes), 1.0), f"Invalid partition sizes {partition_sizes}"
partitions_names = ["train", "val", "test"]

log.info(f"Partitioning {dyn_recording.description} into train/val/test of sizes {partition_sizes}[%]")
# Ensure all training seeds use the same training data partitions
from morpho_symm.utils.mysc import TemporaryNumpySeed
with TemporaryNumpySeed(10): # Ensure deterministic behavior
with TemporaryNumpySeed(10): # Ensure deterministic behavior
# Decide to keep a ratio of the original trajectories
num_trajs = int(dyn_recording.info['num_traj'])
if num_trajs < 10: # Do not discard entire trajectories, but rather parts of the trajectories
if split_time: # Do not discard entire trajectories, but rather parts of the trajectories
# Take the time horizon from the first observation
sample_obs = dyn_recording.recordings[dyn_recording.state_obs[0]]
if len(sample_obs.shape) == 3: # [traj, time, obs_dim]
Expand Down Expand Up @@ -454,7 +476,7 @@ def get_dynamics_dataset(train_shards: list[Path],
val_shards = [] if val_shards is None else val_shards

if len(test_shards) > 0:
from utils.mysc import compare_dictionaries
from morpho_symm.utils.mysc import compare_dictionaries
test_metadata = DynamicsRecording.load_from_file(test_shards[0], only_metadata=True)
dyn_params_diff = compare_dictionaries(metadata.dynamics_parameters, test_metadata.dynamics_parameters)
assert len(dyn_params_diff) == 0, "Different dynamical systems loaded in train/test sets"
Expand Down

0 comments on commit bb8f52d

Please sign in to comment.