Commit dc6d8f7

Fix EMLP hidden dim bug

Danfoa committed Oct 27, 2023
1 parent ca1eead commit dc6d8f7
Showing 12 changed files with 168 additions and 242 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -15,4 +15,5 @@ output/
 .ipynb_checkpoints/
 *.log
 
-morpho_symm/robot_harmonic_decomposition.py
\ No newline at end of file
+morpho_symm/robot_harmonic_decomposition.py
+/morpho_symm/experiments/
8 changes: 6 additions & 2 deletions morpho_symm/cfg/config.yaml
@@ -4,6 +4,8 @@ defaults:
   - dataset: com_momentum
   - model: emlp
   - robot: solo
+  - override hydra/launcher: joblib
+
 
 # TODO: Make distinctions between. Trainer Args, Model Psecific Args, Program Args
 #robot: 'solo12'
@@ -18,12 +20,14 @@ debug_loops: False
 use_volatile: False
 
 # Hydra configuration _________
+
 hydra:
   run:
-    dir: ./experiments/${hydra.job.name}/${hydra.job.override_dirname}
+    dir: ./morpho_symm/experiments/${hydra.job.name}/${hydra.job.override_dirname}
 
   job:
     # TODO: Reorganize output dir
+    chdir: True # Create a new directory for each run
     name: ${dataset.job_name}
     num: ${seed}
     env_set:
@@ -56,7 +60,7 @@ hydra:
       - use_volatile
 
   sweep:
-    dir: ./experiments/${hydra.job.name}/
+    dir: ./morpho_symm/experiments/${hydra.job.name}/
    subdir: ${hydra.job.override_dirname}
 
  job_logging:
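The two Hydra additions change where and how runs execute: the joblib launcher parallelizes `--multirun` sweeps across local processes, and `chdir: True` makes each run execute inside its own directory under `./morpho_symm/experiments/`. A minimal sketch of the resulting behavior, assuming a hypothetical entry point `train.py` (not a file shown in this commit):

```python
import os

import hydra
from omegaconf import DictConfig


@hydra.main(config_path="morpho_symm/cfg", config_name="config", version_base="1.1")
def main(cfg: DictConfig) -> None:
    # With `hydra.job.chdir: True`, Hydra chdirs into
    # ./morpho_symm/experiments/<job_name>/<override_dirname> before calling
    # main(), so relative paths (checkpoints, logs) land in the per-run folder.
    print("Working directory:", os.getcwd())


if __name__ == "__main__":
    # A sweep such as `python train.py --multirun seed=0,1,2` is now dispatched
    # by the joblib launcher, i.e. the runs execute in parallel processes.
    main()
```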
3 changes: 2 additions & 1 deletion morpho_symm/cfg/dataset/com_momentum.yaml
@@ -5,7 +5,7 @@ name: 'com_momentum'
 
 data_folder: "dataset/com_momentum"
 
-augment: true
+augment: False
 angular_momentum: True
 
 standarize: True
@@ -14,6 +14,7 @@ batch_size: 256
 max_epochs: 600
 log_every_n_epochs: 0.5
 
 #samples: 100000 # Dataset size.
+samples: 100000 # Dataset size.
 train_ratio: 0.7
 test_ratio: 0.15
3 changes: 1 addition & 2 deletions morpho_symm/cfg/model/emlp.yaml
@@ -5,6 +5,5 @@ defaults:
 model_type: 'EMLP'
 lr: 2.4e-3
 num_layers: 4
-activation: 'elu'
-num_channels: 128
 inv_dims_scale: 1.0
+fine_tune_num_layers: 1
3 changes: 2 additions & 1 deletion morpho_symm/cfg/model/mlp.yaml
@@ -4,5 +4,6 @@ defaults:
 
 model_type: 'MLP'
 lr: 2.4e-3
-num_layers: 2
+num_layers: 4
+activation: 'elu'
 num_channels: 128
59 changes: 26 additions & 33 deletions morpho_symm/datasets/com_momentum/com_momentum.py
@@ -1,22 +1,21 @@
 import copy
 import logging
 import pathlib
-import random
 import time
 from typing import Optional, Union
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 from escnn import gspaces
-from escnn.group import Representation
+from escnn.group import Group, Representation
 from escnn.nn import FieldType
 from omegaconf import DictConfig
-from scipy.sparse import issparse
 from torch.utils.data import Dataset
 from torch.utils.data._utils.collate import default_collate
 
 import morpho_symm.utils.pybullet_visual_utils
+from morpho_symm.robots.PinBulletWrapper import PinBulletWrapper
 
 # from morpho_symm.utils.algebra_utils import dense
 from morpho_symm.utils.robot_utils import load_symmetric_system
@@ -75,10 +74,8 @@ def __init__(self, robot_cfg, type='train',
         self.angular_momentum = angular_momentum
         self.normalize = True if isinstance(standarizer, Standarizer) else standarizer
 
-        robot, gspace, in_feature_type, out_feature_type = self.define_input_output_field_types(robot_cfg)
-        self.robot, self.gspace = robot, gspace
-        self.in_feature_type, self.out_feature_type = in_feature_type, out_feature_type
-        self.G = self.gspace.fibergroup
+        # Load robot, symmetry group and input-output field types/representations
+        self.robot, self.G, self.in_type, self.out_type = self.define_input_output_field_types(robot_cfg)
 
         self._pb = None  # GUI debug
         self.augment = augment if isinstance(augment, bool) else False
@@ -126,8 +123,8 @@ def __init__(self, robot_cfg, type='train',
         if isinstance(augment, str) and augment.lower() == "hard":
             for g in self.G.elements[1:]:
-                rep_X = self.in_feature_type.fiber_representation(g).to(self.X.device)
-                rep_Y = self.out_feature_type.fiber_representation(g).to(self.Y.device)
+                rep_X = self.in_type.fiber_representation(g).to(self.X.device)
+                rep_Y = self.out_type.fiber_representation(g).to(self.Y.device)
                 gX = (rep_X @ self.X.T).T
                 gY = (rep_Y @ self.Y.T).T
                 self.X = torch.vstack([self.X, gX])
Expand Down Expand Up @@ -157,8 +154,6 @@ def compute_normalization(self, data_matrix, rep_data: Optional[Representation]
mean: Empirical expected value over the data matrix (dim(x),)
std: Empirical variance over the data matrix (dim(x),)
"""
X = data_matrix

if rep_data is None:
X_mean = np.mean(data_matrix, axis=0)
X_std = np.std(data_matrix, axis=0)
@@ -168,7 +163,7 @@ def compute_normalization(self, data_matrix, rep_data: Optional[Representation]
             # X_iso = Q_inv @ X[..., None]
             # X_mean =
 
-        return X_mean, X_std, #Y_mean, Y_std
+        return X_mean, X_std, # Y_mean, Y_std
 
     def test_equivariance(self):
         trials = 10
@@ -184,11 +179,11 @@ def test_equivariance(self):
             y_true = [y]
 
             non_equivariance_detected = False
+            rep_X = self.in_type.representation
+            rep_Y = self.out_type.representation
             # Get all possible group actions
-            for g_in, g_out in zip(self.in_feature_type.G.discrete_actions[1:],
-                                    self.out_feature_type.G.discrete_actions[1:]):
-                g_in, g_out = (g_in.todense(), g_out.todense()) if issparse(g_in) else (g_in, g_out)
-                gx, gy = np.asarray(g_in) @ x, np.asarray(g_out) @ y
+            for g in self.G.elements:
+                gx, gy = rep_X(g) @ x, rep_Y(g) @ y
                 x_orbit.append(gx)
                 y_orbit.append(gy)
                 assert gx.dtype == x.dtype, (gx.dtype, x.dtype)
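The rewritten test iterates over the group elements and applies the input/output representations directly, checking rep_Y(g) · f(x) = f(rep_X(g) · x) for every g ∈ G, instead of juggling pre-materialized (possibly sparse) action matrices. A self-contained sketch of that check, assuming a stand-in C2 group and a trivially equivariant map `f` (all names here are illustrative, not the dataset's actual members):

```python
import numpy as np
from escnn.group import cyclic_group

G = cyclic_group(2)                 # stand-in for the robot's symmetry group (e.g. C2)
rep_X = G.regular_representation    # stand-ins for the in_type/out_type representations
rep_Y = G.regular_representation

def f(x):
    return x                        # identity map: trivially equivariant

x = np.random.rand(rep_X.size)
for g in G.elements:
    gx = rep_X(g) @ x               # group action on the input
    gy = rep_Y(g) @ f(x)            # group action on the output
    assert np.allclose(f(gx), gy, atol=1e-6), f"Equivariance broken for {g}"
```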
@@ -253,11 +248,10 @@ def collate_fn(self, batch):
         x_batch, y_batch = default_collate(batch)
 
         if self.augment: # Sample uniformly among symmetry actions including identity
-            g_in, g_out = random.choice(self.t_group_actions)
-            g_x_batch = torch.matmul(x_batch.unsqueeze(1), g_in.unsqueeze(0).to(x_batch.dtype)).squeeze()
-            g_y_batch = torch.matmul(y_batch.unsqueeze(1), g_out.unsqueeze(0).to(y_batch.dtype)).squeeze()
-            # x, xx = x_batch[0], g_x_batch[0]
-            # y, yy = y_batch[0], g_y_batch[0]
+            g = self.G.sample()
+
+            g_x_batch = self.in_type.transform_fibers(x_batch, g)
+            g_y_batch = self.out_type.transform_fibers(y_batch, g)
             x_batch, y_batch = g_x_batch, g_y_batch
         return x_batch.to(self.dtype), y_batch.to(self.dtype)

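The new augmentation path samples one group element per batch and lets escnn's field types act on the feature (fiber) dimensions, replacing the cached dense action matrices and the `random.choice` call. A runnable sketch with stand-in field types; the real `in_type`/`out_type` are built in `define_input_output_field_types`:

```python
import torch
from escnn import gspaces
from escnn.group import cyclic_group
from escnn.nn import FieldType

G = cyclic_group(2)
gspace = gspaces.no_base_space(G)
in_type = FieldType(gspace, [G.regular_representation])   # stand-in field types
out_type = FieldType(gspace, [G.regular_representation])

x_batch = torch.randn(256, in_type.size)
y_batch = torch.randn(256, out_type.size)

g = G.sample()                              # one random group element per batch
gx = in_type.transform_fibers(x_batch, g)   # acts on the feature (fiber) dims only
gy = out_type.transform_fibers(y_batch, g)  # same element applied to the outputs
```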
@@ -352,10 +346,10 @@ def ensure_dataset_existance(self):
         np.random.seed(29081995)
         # Get joint limits.
         dq_max = np.asarray(self.robot.velocity_limits)
-        dq_max = np.minimum(dq_max, 2 * np.pi)
+        dq_max = np.minimum(dq_max, np.pi)
         q_min, q_max = self.robot.joint_pos_limits
-        np.minimum(q_min, -2 * np.pi)
-        q_max = np.minimum(q_max, 2 * np.pi)
+        q_min = np.maximum(q_min, -np.pi)
+        q_max = np.minimum(q_max, np.pi)
 
         x = np.zeros((self._samples, self.robot.n_js * 2))
         y = np.zeros((self._samples, 6))
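Besides tightening the sampling ranges from 2π to π, this hunk fixes a silent bug: the old `np.minimum(q_min, -2 * np.pi)` discarded its return value and clamped in the wrong direction, since lower bounds must be raised with `np.maximum`. A small numpy illustration with made-up limits:

```python
import numpy as np

q_min = np.array([-10.0, -0.5])   # made-up joint lower limits
q_max = np.array([10.0, 0.5])     # made-up joint upper limits

q_min = np.maximum(q_min, -np.pi)  # clamp lower bounds upward (old code used minimum)
q_max = np.minimum(q_max, np.pi)   # clamp upper bounds downward

q = np.random.uniform(q_min, q_max)  # samples now stay within [-pi, pi]
```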
@@ -370,13 +364,13 @@ def ensure_dataset_existance(self):
             # and dynamics are completely equivariant. So we make the gt the avg of the augmented predictions.
             ys_pin = [y]
             for g in self.G.elements[1:]:
-                gx = np.squeeze(self.in_feature_type.representation(g) @ x.T).T
+                gx = np.squeeze(self.in_type.representation(g) @ x.T).T
                 # gy = np.squeeze(y @ g_out)
                 gy_pin = np.zeros((self._samples, 6))
                 # Generate random configuration samples.
                 for i, x_sample in enumerate(gx):
                     gy_pin[i, :] = self.get_hg(x_sample[:self.robot.nq - 7], x_sample[self.robot.nq - 7:])
-                ys_pin.append((self.out_feature_type.representation(
+                ys_pin.append((self.out_type.representation(
                     ~g) @ gy_pin.T).T)  # inverse is not needed for the groups we use (C2, V4).
 
             y_pin_avg = np.mean(ys_pin, axis=0)
Expand Down Expand Up @@ -425,7 +419,7 @@ def to(self, device):
self.X.to(device)
self.Y.to(device)

def define_input_output_field_types(self, robot_cfg: DictConfig):
def define_input_output_field_types(self, robot_cfg: DictConfig) -> (PinBulletWrapper, Group, FieldType, FieldType):
"""Define the input-output symmetry representations for the CoM function g·y = f(g·x) | g ∈ G.
Define the symmetry representations for the Center of Mass momentum y := (l, k) ∈ R^3 x R^3, where l is the
@@ -444,22 +438,21 @@ def define_input_output_field_types(self, robot_cfg: DictConfig):
             output_type (FieldType): The output feature field type, describing the system CoM (hg) and its symmetry
                 transformations.
         """
-        robot, symmetry_space = load_symmetric_system(robot_cfg=robot_cfg, debug=False)
-        G = symmetry_space.fibergroup
+        robot, G = load_symmetric_system(robot_cfg=robot_cfg, debug=False)
         # For this application we compute the CoM w.r.t base frame, meaning that we ignore the fiber group Ed in which
         # the system evolves in:
         gspace = gspaces.no_base_space(G)
         # Get the relevant representations.
         rep_Q_js = G.representations["Q_js"]
         rep_TqQ_js = G.representations["TqQ_js"]
-        rep_O3 = G.representations["Od"]
-        rep_O3_pseudo = G.representations["Od_pseudo"]
+        rep_R3 = G.representations["Rd"]
+        rep_R3_pseudo = G.representations["Rd_pseudo"]
 
         # Rep for x := [q, dq] ∈ Q_js x TqQ_js => ρ_Q_js(g) ⊕ ρ_TqQ_js(g) | g ∈ G
         in_type = FieldType(gspace, [rep_Q_js, rep_TqQ_js])
 
         # Rep for center of mass momentum y := [l, k] ∈ R3 x R3 => ρ_R3(g) ⊕ ρ_R3pseudo(g) | g ∈ G
-        out_type = FieldType(gspace, [rep_O3, rep_O3_pseudo])
+        out_type = FieldType(gspace, [rep_R3, rep_R3_pseudo])
 
         # TODO: handle subgroup cases.
         # if robot_cfg.gens_ids is not None and self.dataset_type in ["test", "val"]:
@@ -474,7 +467,7 @@ def define_input_output_field_types(self, robot_cfg: DictConfig):
         # return robot, rep_data_in, rep_data_out
 
         log.info(f"[{self.dataset_type}] Dataset using the symmetry group {type(G)}")
-        return robot, gspace, in_type, out_type
+        return robot, G, in_type, out_type
 
     def __repr__(self):
         msg = f"CoM Dataset: [{self.robot.robot_name}]-{self.dataset_type}-Aug:{self.augment}" \
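The refactor returns the group `G` itself rather than the `gspace`, and builds the output type from the `Rd`/`Rd_pseudo` vector and pseudo-vector representations, matching the comment ρ_R3(g) ⊕ ρ_R3pseudo(g) for linear and angular momentum. A sketch of how `FieldType` composes representations by direct sum, using a stand-in group (the real representations come from `load_symmetric_system`):

```python
from escnn import gspaces
from escnn.group import cyclic_group
from escnn.nn import FieldType

G = cyclic_group(2)
gspace = gspaces.no_base_space(G)

rep_a = G.regular_representation   # stand-in for rep_Q_js
rep_b = G.trivial_representation   # stand-in for rep_TqQ_js

# FieldType concatenates the representations: rho(g) = rho_a(g) ⊕ rho_b(g),
# so the field size is the sum of the representation dimensions.
in_type = FieldType(gspace, [rep_a, rep_b])
assert in_type.size == rep_a.size + rep_b.size
```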
(file name not loaded in this view)
@@ -221,7 +221,7 @@ def log_cm_img(label, classes, figsize=(4, 3), annot=True):
         metrics.update(balanced_acc_dir)
         metrics.update(individual_state_metrics)
 
-        model.log_metrics(metrics, prefix=prefix)
+        model.log_metrics(metrics, suffix=prefix)
         model.train()
 
     def decimal2binary(self, x):
27 changes: 15 additions & 12 deletions morpho_symm/nn/EMLP.py
@@ -4,7 +4,7 @@
 import escnn
 import numpy as np
 import torch
-from escnn.nn import EquivariantModule, FieldType
+from escnn.nn import EquivariantModule, FieldType, GeometricTensor
 
 from morpho_symm.nn.EquivariantModules import IsotypicBasis
 
@@ -93,7 +93,8 @@ def __init__(self,
             layer_out_type = hidden_type
 
             block = escnn.nn.SequentialModule()
-            block.add_module(f"linear_{n}", escnn.nn.Linear(layer_in_type, layer_out_type, bias=bias))
+            block.add_module(f"linear_{n}: in={layer_in_type.size}-out={layer_out_type.size}",
+                             escnn.nn.Linear(layer_in_type, layer_out_type, bias=bias))
             if batch_norm:
                 block.add_module(f"batchnorm_{n}", escnn.nn.IIDBatchNorm1d(layer_out_type)),
             block.add_module(f"act_{n}", activation)
@@ -126,7 +127,7 @@ def __init__(self,
         # Test the entire model is equivariant.
         # self.net.check_equivariance()
 
-    def forward(self, x):
+    def forward(self, x: GeometricTensor) -> GeometricTensor:
         """Forward pass of the EMLP model."""
         equivariant_features = self.net(x)
         if self.invariant_fn:
@@ -143,26 +144,28 @@ def reset_parameters(self, init_mode=None):
             raise NotImplementedError()
 
     @staticmethod
-    def get_activation(activation, in_type: FieldType, desired_hidden_units: int):
+    def get_activation(activation, in_type: FieldType, desired_hidden_units: int) -> EquivariantModule:
         gspace = in_type.gspace
         group = gspace.fibergroup
 
         grid_kwargs = EMLP.get_group_kwargs(group)
 
         unique_irreps = set(in_type.irreps)
         unique_irreps_dim = sum([group.irrep(*id).size for id in set(in_type.irreps)])
-        scale = in_type.size // unique_irreps_dim
-        channels = int(np.ceil(desired_hidden_units // unique_irreps_dim // scale))
+        channels = int(np.ceil(desired_hidden_units // unique_irreps_dim))
         if "identity" in activation.lower():
             raise NotImplementedError("Identity activation not implemented yet")
             # return escnn.nn.IdentityModule()
         else:
-            return escnn.nn.FourierPointwise(gspace,
-                                             channels=channels,
-                                             irreps=list(unique_irreps),
-                                             function=f"p_{activation.lower()}",
-                                             inplace=True,
-                                             **grid_kwargs)
+            act = escnn.nn.FourierPointwise(gspace,
+                                            channels=channels,
+                                            irreps=list(unique_irreps),
+                                            function=f"p_{activation.lower()}",
+                                            inplace=True,
+                                            **grid_kwargs)
+            assert (act.out_type.size - desired_hidden_units) <= unique_irreps_dim, \
+                f"out_type.size {act.out_type.size} - des_hidden_units {desired_hidden_units} > {unique_irreps_dim}"
+            return act
 
     @staticmethod
     def get_group_kwargs(group: escnn.group.Group):
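This hunk is the bug the commit title refers to: the old channel count divided by `unique_irreps_dim` and then again by `scale`, so the Fourier activation came out far narrower than `desired_hidden_units`; the new code sizes channels from the unique irreps alone and asserts the result is within one irrep block of the request. A worked example with made-up sizes (one-dimensional irreps, as in C2):

```python
import numpy as np

desired_hidden_units = 128
unique_irreps_dim = 2   # sum of unique irrep sizes, e.g. two 1-dim irreps of C2
in_type_size = 24       # example width of the layer's input field type

# Old (buggy): a second division by `scale` shrinks the layer.
scale = in_type_size // unique_irreps_dim                                        # 12
old_channels = int(np.ceil(desired_hidden_units // unique_irreps_dim // scale))  # 5
# -> activation output ~ 5 * 2 = 10 hidden units instead of 128.

# New: channels depend only on the unique irreps.
new_channels = int(np.ceil(desired_hidden_units // unique_irreps_dim))          # 64
# -> activation output ~ 64 * 2 = 128 hidden units, as requested.
```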
(The remaining changed files in this commit were not loaded in this view.)