From d21e7955160dbb6e59bb86d0ef85d41ddd4cfcee Mon Sep 17 00:00:00 2001
From: Daniel Ordonez
Date: Thu, 2 Nov 2023 15:13:57 +0100
Subject: [PATCH] Disabled default batch norm on equivariant modules.

---
 morpho_symm/cfg/dataset/com_momentum.yaml | 2 +-
 morpho_symm/cfg/model/emlp.yaml           | 1 +
 morpho_symm/cfg/model/mlp.yaml            | 1 +
 morpho_symm/nn/EMLP.py                    | 6 ++++--
 morpho_symm/nn/MLP.py                     | 2 +-
 paper/train_supervised.py                 | 3 +++
 pyproject.toml                            | 2 +-
 7 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/morpho_symm/cfg/dataset/com_momentum.yaml b/morpho_symm/cfg/dataset/com_momentum.yaml
index ffba92f..dc59397 100644
--- a/morpho_symm/cfg/dataset/com_momentum.yaml
+++ b/morpho_symm/cfg/dataset/com_momentum.yaml
@@ -11,7 +11,7 @@ angular_momentum: True
 standarize: True
 
 batch_size: 256
-max_epochs: 600
+max_epochs: 300
 log_every_n_epochs: 0.5
 
 #samples: 100000 # Dataset size.
diff --git a/morpho_symm/cfg/model/emlp.yaml b/morpho_symm/cfg/model/emlp.yaml
index 8656185..92f560e 100644
--- a/morpho_symm/cfg/model/emlp.yaml
+++ b/morpho_symm/cfg/model/emlp.yaml
@@ -5,5 +5,6 @@ defaults:
 model_type: 'EMLP'
 lr: 2.4e-3
 num_layers: 4
+batch_norm: False
 activation: 'elu'
 num_channels: 128
diff --git a/morpho_symm/cfg/model/mlp.yaml b/morpho_symm/cfg/model/mlp.yaml
index ad6e651..b86b80e 100644
--- a/morpho_symm/cfg/model/mlp.yaml
+++ b/morpho_symm/cfg/model/mlp.yaml
@@ -5,5 +5,6 @@ defaults:
 model_type: 'MLP'
 lr: 2.4e-3
 num_layers: 4
+batch_norm: False
 activation: 'elu'
 num_channels: 128
\ No newline at end of file
diff --git a/morpho_symm/nn/EMLP.py b/morpho_symm/nn/EMLP.py
index 5ab9858..919458b 100644
--- a/morpho_symm/nn/EMLP.py
+++ b/morpho_symm/nn/EMLP.py
@@ -22,7 +22,7 @@ def __init__(self,
                  bias: bool = True,
                  activation: Union[str, EquivariantModule] = "ELU",
                  head_with_activation: bool = False,
-                 batch_norm: bool = True):
+                 batch_norm: bool = False):
         """Constructor of an Equivariant Multi-Layer Perceptron (EMLP) model.
 
         This utility class allows to easily instanciate a G-equivariant MLP architecture. As a convention, we assume
@@ -61,6 +61,8 @@ def __init__(self,
         self.group = self.gspace.fibergroup
         self.num_layers = num_layers
 
+        if batch_norm:
+            log.warning("Equivariant batch norm degrades the performance of the model. Don't use it for now.")
         # Check if the network is a G-invariant function (i.e., out rep is composed only of the trivial representation)
         out_irreps = set(out_type.representation.irreps)
         if len(out_irreps) == 1 and self.group.trivial_representation.id == list(out_irreps)[0]:
@@ -96,7 +98,7 @@ def __init__(self,
             block.add_module(f"linear_{n}: in={layer_in_type.size}-out={layer_out_type.size}",
                              escnn.nn.Linear(layer_in_type, layer_out_type, bias=bias))
             if batch_norm:
-                block.add_module(f"batchnorm_{n}", escnn.nn.IIDBatchNorm1d(layer_out_type)),
+                block.add_module(f"batchnorm_{n}", escnn.nn.IIDBatchNorm1d(layer_out_type))
             block.add_module(f"act_{n}", activation)
 
             self.net.add_module(f"block_{n}", block)
diff --git a/morpho_symm/nn/MLP.py b/morpho_symm/nn/MLP.py
index 028b234..8b78abc 100644
--- a/morpho_symm/nn/MLP.py
+++ b/morpho_symm/nn/MLP.py
@@ -15,7 +15,7 @@ def __init__(self,
                  num_hidden_units: int = 64,
                  num_layers: int = 3,
                  bias: bool = True,
-                 batch_norm: bool = True,
+                 batch_norm: bool = False,
                  head_with_activation: bool = False,
                  activation: Union[torch.nn.Module, List[torch.nn.Module]] = torch.nn.ReLU,
                  init_mode="fan_in"):
diff --git a/paper/train_supervised.py b/paper/train_supervised.py
index 661861f..abd7fa9 100644
--- a/paper/train_supervised.py
+++ b/paper/train_supervised.py
@@ -50,6 +50,7 @@ def get_model(cfg: DictConfig, in_field_type=None, out_field_type=None):
         model = EMLP(in_type=in_field_type,
                      out_type=out_field_type,
                      num_layers=cfg.num_layers,
+                     batch_norm=cfg.batch_norm,
                      num_hidden_units=cfg.num_channels,
                      activation=cfg.activation,
                      bias=cfg.bias)
@@ -59,6 +60,7 @@ def get_model(cfg: DictConfig, in_field_type=None, out_field_type=None):
         model = MLP(in_dim=in_field_type.size,
                     out_dim=out_field_type.size,
                     num_layers=cfg.num_layers,
+                    batch_norm=cfg.batch_norm,
                     init_mode=cfg.init_mode,
                     num_hidden_units=cfg.num_channels,
                     bias=cfg.bias,
@@ -214,6 +216,7 @@ def main(cfg: DictConfig):
                                name=run_name,
                                group=f'{cfg.exp_name}',
                                job_type='debug' if (cfg.debug or cfg.debug_loops) else None)
+    wandb_logger.watch(pl_model)
 
     log.info("\n\nInitiating Training\n\n")
     trainer = Trainer(accelerator='cuda' if torch.cuda.is_available() and cfg.device != 'cpu' else 'cpu',
diff --git a/pyproject.toml b/pyproject.toml
index bebdb27..f0d23ac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ allow-direct-references = true # We use custom branch of escnn with some devel
 
 [project]
 name = "morpho_symm"
-version = "0.1.2"
+version = "0.1.3"
 keywords = ["morphological symmetry", "locomotion", "dynamical systems", "robot symmetries", "symmetry"]
 description = "Tools for the identification, study, and exploitation of morphological symmetries in locomoting dynamical systems"
 readme = "README.md"
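
For reference (not part of the patch), a minimal usage sketch of the new `batch_norm` flag. It assumes `morpho_symm` is installed and that the class is importable from the file path shown above (`morpho_symm/nn/MLP.py`); the input/output dimensions and batch size are hypothetical. The EMLP variant additionally requires escnn field types, so the plain MLP is shown here.

    # Sketch only: exercises the batch_norm flag that this patch threads through
    # the model constructors and Hydra configs (cfg/model/mlp.yaml sets it to False).
    import torch
    from morpho_symm.nn.MLP import MLP

    model = MLP(in_dim=12,               # hypothetical input dimension
                out_dim=6,               # hypothetical output dimension
                num_layers=4,
                num_hidden_units=128,
                bias=True,
                batch_norm=False)        # new default; the patch warns against enabling it

    y = model(torch.randn(8, 12))        # forward pass on a dummy batch of 8 samples
    print(y.shape)                       # expected: torch.Size([8, 6])

In the training script, the flag is read from the Hydra model config (`cfg.batch_norm`), which both `cfg/model/emlp.yaml` and `cfg/model/mlp.yaml` now set to `False`.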