Disabled default batch norm on equivariant modules.
Danfoa committed Nov 2, 2023
1 parent dc6d8f7 commit d21e795
Showing 7 changed files with 12 additions and 5 deletions.
2 changes: 1 addition & 1 deletion morpho_symm/cfg/dataset/com_momentum.yaml
@@ -11,7 +11,7 @@ angular_momentum: True
 standarize: True

 batch_size: 256
-max_epochs: 600
+max_epochs: 300
 log_every_n_epochs: 0.5

 #samples: 100000 # Dataset size.
1 change: 1 addition & 0 deletions morpho_symm/cfg/model/emlp.yaml
@@ -5,5 +5,6 @@ defaults:
 model_type: 'EMLP'
 lr: 2.4e-3
 num_layers: 4
+batch_norm: False
 activation: 'elu'
 num_channels: 128
1 change: 1 addition & 0 deletions morpho_symm/cfg/model/mlp.yaml
@@ -5,5 +5,6 @@ defaults:
 model_type: 'MLP'
 lr: 2.4e-3
 num_layers: 4
+batch_norm: False
 activation: 'elu'
 num_channels: 128
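
Both model configs now expose the flag explicitly, so batch norm can still be enabled for a single run from the command line. A hypothetical invocation, assuming the standard Hydra override syntax and that these YAML files form the training script's model config group (not shown in this diff):

    python paper/train_supervised.py model=emlp model.batch_norm=True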
6 changes: 4 additions & 2 deletions morpho_symm/nn/EMLP.py
@@ -22,7 +22,7 @@ def __init__(self,
 bias: bool = True,
 activation: Union[str, EquivariantModule] = "ELU",
 head_with_activation: bool = False,
-batch_norm: bool = True):
+batch_norm: bool = False):
"""Constructor of an Equivariant Multi-Layer Perceptron (EMLP) model.
 This utility class allows one to easily instantiate a G-equivariant MLP architecture. As a convention, we assume
@@ -61,6 +61,8 @@ def __init__(self,
 self.group = self.gspace.fibergroup
 self.num_layers = num_layers

+if batch_norm:
+    log.warning("Equivariant batch norm affects the performance of the model. Don't use it for now!")
 # Check if the network is a G-invariant function (i.e., out rep is composed only of the trivial representation)
 out_irreps = set(out_type.representation.irreps)
 if len(out_irreps) == 1 and self.group.trivial_representation.id == list(out_irreps)[0]:
@@ -96,7 +98,7 @@ def __init__(self,
block.add_module(f"linear_{n}: in={layer_in_type.size}-out={layer_out_type.size}",
escnn.nn.Linear(layer_in_type, layer_out_type, bias=bias))
if batch_norm:
block.add_module(f"batchnorm_{n}", escnn.nn.IIDBatchNorm1d(layer_out_type)),
block.add_module(f"batchnorm_{n}", escnn.nn.IIDBatchNorm1d(layer_out_type, )),
block.add_module(f"act_{n}", activation)

self.net.add_module(f"block_{n}", block)
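
For context, a minimal sketch of the kind of block this hunk builds, with batch norm now opt-in. The group, field types, and layer sizes below are illustrative assumptions, not values from the repo:

    import torch
    import escnn
    from escnn.group import cyclic_group

    G = cyclic_group(2)                      # toy symmetry group
    gspace = escnn.gspaces.no_base_space(G)  # point-like base space, as in an MLP
    in_type = escnn.nn.FieldType(gspace, [G.regular_representation] * 4)
    out_type = escnn.nn.FieldType(gspace, [G.regular_representation] * 4)

    batch_norm = False  # the new default
    block = escnn.nn.SequentialModule(escnn.nn.Linear(in_type, out_type, bias=True))
    if batch_norm:  # guarded, mirroring the add_module call above
        block.add_module("batchnorm", escnn.nn.IIDBatchNorm1d(out_type))
    block.add_module("act", escnn.nn.ELU(out_type))

    x = escnn.nn.GeometricTensor(torch.randn(8, in_type.size), in_type)
    y = block(x)  # equivariant forward pass; y.type == out_type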
2 changes: 1 addition & 1 deletion morpho_symm/nn/MLP.py
@@ -15,7 +15,7 @@ def __init__(self,
 num_hidden_units: int = 64,
 num_layers: int = 3,
 bias: bool = True,
-batch_norm: bool = True,
+batch_norm: bool = False,
 head_with_activation: bool = False,
 activation: Union[torch.nn.Module, List[torch.nn.Module]] = torch.nn.ReLU,
 init_mode="fan_in"):
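
A hypothetical smoke test of the plain-MLP counterpart under the new default. The import path and the tensor-in/tensor-out forward are assumptions; the in_dim/out_dim keyword names match the call site in train_supervised.py below, and the dimensions are made up:

    import torch
    from morpho_symm.nn.MLP import MLP  # assumed module path

    model = MLP(in_dim=12, out_dim=6, num_layers=4, num_hidden_units=128,
                bias=True, batch_norm=False)  # batch_norm now defaults to False
    y = model(torch.randn(8, 12))  # expected shape: (8, 6)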
3 changes: 3 additions & 0 deletions paper/train_supervised.py
@@ -50,6 +50,7 @@ def get_model(cfg: DictConfig, in_field_type=None, out_field_type=None):
 model = EMLP(in_type=in_field_type,
              out_type=out_field_type,
              num_layers=cfg.num_layers,
+             batch_norm=cfg.batch_norm,
              num_hidden_units=cfg.num_channels,
              activation=cfg.activation,
              bias=cfg.bias)
@@ -59,6 +60,7 @@ def get_model(cfg: DictConfig, in_field_type=None, out_field_type=None):
 model = MLP(in_dim=in_field_type.size,
             out_dim=out_field_type.size,
             num_layers=cfg.num_layers,
+            batch_norm=cfg.batch_norm,
             init_mode=cfg.init_mode,
             num_hidden_units=cfg.num_channels,
             bias=cfg.bias,
@@ -214,6 +216,7 @@ def main(cfg: DictConfig):
 name=run_name,
 group=f'{cfg.exp_name}',
 job_type='debug' if (cfg.debug or cfg.debug_loops) else None)
+wandb_logger.watch(pl_model)

 log.info("\n\nInitiating Training\n\n")
 trainer = Trainer(accelerator='cuda' if torch.cuda.is_available() and cfg.device != 'cpu' else 'cpu',
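
Note: the added wandb_logger.watch(pl_model) call registers the Lightning module with Weights & Biases; with pytorch-lightning's WandbLogger, watch logs gradient histograms by default, which can help when comparing runs with and without batch norm.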
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ allow-direct-references = true # We use custom branch of escnn with some devel

 [project]
 name = "morpho_symm"
-version = "0.1.2"
+version = "0.1.3"
 keywords = ["morphological symmetry", "locomotion", "dynamical systems", "robot symmetries", "symmetry"]
 description = "Tools for the identification, study, and exploitation of morphological symmetries in locomoting dynamical systems"
 readme = "README.md"