From 8e7bcd8745c309f11e27fc2012a91eb54a3da43c Mon Sep 17 00:00:00 2001 From: Thibault de Varax <154365476+thibaultdvx@users.noreply.github.com> Date: Mon, 23 Sep 2024 17:41:17 +0200 Subject: [PATCH] Minor changes in optim and networks (#652) * correction of special cases in optimizer * add Implemented Networks in init --- clinicadl/monai_networks/__init__.py | 2 +- clinicadl/monai_networks/config/__init__.py | 1 + clinicadl/optim/lr_scheduler/factory.py | 5 + clinicadl/optim/optimizer/factory.py | 98 +------------- clinicadl/optim/optimizer/utils.py | 120 ++++++++++++++++++ .../optim/lr_scheduler/test_factory.py | 12 +- .../unittests/optim/optimizer/test_factory.py | 66 ++-------- tests/unittests/optim/optimizer/test_utils.py | 109 ++++++++++++++++ 8 files changed, 264 insertions(+), 149 deletions(-) create mode 100644 clinicadl/optim/optimizer/utils.py create mode 100644 tests/unittests/optim/optimizer/test_utils.py diff --git a/clinicadl/monai_networks/__init__.py b/clinicadl/monai_networks/__init__.py index 95c87507a..1d74473d4 100644 --- a/clinicadl/monai_networks/__init__.py +++ b/clinicadl/monai_networks/__init__.py @@ -1,2 +1,2 @@ -from .config import create_network_config +from .config import ImplementedNetworks, NetworkConfig, create_network_config from .factory import get_network diff --git a/clinicadl/monai_networks/config/__init__.py b/clinicadl/monai_networks/config/__init__.py index ed03a909f..10b8795dc 100644 --- a/clinicadl/monai_networks/config/__init__.py +++ b/clinicadl/monai_networks/config/__init__.py @@ -1,2 +1,3 @@ +from .base import NetworkConfig from .factory import create_network_config from .utils.enum import ImplementedNetworks diff --git a/clinicadl/optim/lr_scheduler/factory.py b/clinicadl/optim/lr_scheduler/factory.py index f07b07f32..a26948deb 100644 --- a/clinicadl/optim/lr_scheduler/factory.py +++ b/clinicadl/optim/lr_scheduler/factory.py @@ -49,6 +49,11 @@ def get_lr_scheduler( config_dict_["min_lr"].append( config_dict["min_lr"]["ELSE"] ) # ELSE must be the last group + else: + default_min_lr = get_args_and_defaults(scheduler_class.__init__)[1][ + "min_lr" + ] + config_dict_["min_lr"].append(default_min_lr) scheduler = scheduler_class(optimizer, **config_dict_) updated_config = LRSchedulerConfig(scheduler=config.scheduler, **config_dict) diff --git a/clinicadl/optim/optimizer/factory.py b/clinicadl/optim/optimizer/factory.py index 15ffd8a52..3afd6a848 100644 --- a/clinicadl/optim/optimizer/factory.py +++ b/clinicadl/optim/optimizer/factory.py @@ -1,12 +1,12 @@ -from typing import Any, Dict, Iterable, Iterator, List, Tuple +from typing import Any, Dict, Tuple -import torch import torch.nn as nn import torch.optim as optim from clinicadl.utils.factories import DefaultFromLibrary, get_args_and_defaults from .config import OptimizerConfig +from .utils import get_params_in_groups, get_params_not_in_groups def get_optimizer( @@ -45,23 +45,16 @@ def get_optimizer( list_args_groups = network.parameters() else: list_args_groups = [] - union_groups = set() args_groups = sorted(args_groups.items()) # order in the list is important for group, args in args_groups: - params, params_names = _get_params_in_group(network, group) + params, _ = get_params_in_groups(network, group) args.update({"params": params}) list_args_groups.append(args) - union_groups.update(set(params_names)) - other_params = _get_params_not_in_group(network, union_groups) - try: - next(other_params) - except StopIteration: # there is no other param in the network - pass - else: - other_params = 
_get_params_not_in_group( - network, union_groups - ) # reset the generator + other_params, params_names = get_params_not_in_groups( + network, [group for group, _ in args_groups] + ) + if len(params_names) > 0: list_args_groups.append({"params": other_params}) optimizer = optimizer_class(list_args_groups, **args_global) @@ -126,80 +119,3 @@ def _regroup_args( args_global[arg] = value return args_groups, args_global - - -def _get_params_in_group( - network: nn.Module, group: str -) -> Tuple[Iterator[torch.Tensor], List[str]]: - """ - Gets the parameters of a specific group of a neural network. - - Parameters - ---------- - network : nn.Module - The neural network. - group : str - The name of the group, e.g. a layer or a block. - If it is a sub-block, the hierarchy should be - specified with "." (see examples). - Will work even if the group is reduced to a base layer - (e.g. group = "dense.weight" or "dense.bias"). - - Returns - ------- - Iterator[torch.Tensor] - A generator that contains the parameters of the group. - List[str] - The name of all the parameters in the group. - - Examples - -------- - >>> net = nn.Sequential( - OrderedDict( - [ - ("conv1", nn.Conv2d(1, 1, kernel_size=3)), - ("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))), - ] - ) - ) - >>> generator, params_names = _get_params_in_group(network, "final.dense1") - >>> params_names - ["final.dense1.weight", "final.dense1.bias"] - """ - group_hierarchy = group.split(".") - for name in group_hierarchy: - network = getattr(network, name) - - try: - params = network.parameters() - params_names = [ - ".".join([group, name]) for name, _ in network.named_parameters() - ] - except AttributeError: # we already reached params - params = (param for param in [network]) - params_names = [group] - - return params, params_names - - -def _get_params_not_in_group( - network: nn.Module, group: Iterable[str] -) -> Iterator[torch.Tensor]: - """ - Finds the parameters of a neural networks that - are not in a group. - - Parameters - ---------- - network : nn.Module - The neural network. - group : List[str] - The group of parameters. - - Returns - ------- - Iterator[torch.Tensor] - A generator of all the parameters that are not in the input - group. - """ - return (param[1] for param in network.named_parameters() if param[0] not in group) diff --git a/clinicadl/optim/optimizer/utils.py b/clinicadl/optim/optimizer/utils.py new file mode 100644 index 000000000..669f9a8e4 --- /dev/null +++ b/clinicadl/optim/optimizer/utils.py @@ -0,0 +1,120 @@ +from itertools import chain +from typing import Iterator, List, Tuple, Union + +import torch +import torch.nn as nn + + +def get_params_in_groups( + network: nn.Module, groups: Union[str, List[str]] +) -> Tuple[Iterator[torch.Tensor], List[str]]: + """ + Gets the parameters of specific groups of a neural network. + + Parameters + ---------- + network : nn.Module + The neural network. + groups : Union[str, List[str]] + The name of the group(s), e.g. a layer or a block. + If the user refers to a sub-block, the hierarchy should be + specified with "." (see examples). + If a list is passed, the function will output the parameters + of all groups mentioned together. + + Returns + ------- + Iterator[torch.Tensor] + An iterator that contains the parameters of the group(s). + List[str] + The name of all the parameters in the group(s). 
+
+    Examples
+    --------
+    >>> net = nn.Sequential(
+            OrderedDict(
+                [
+                    ("conv1", nn.Conv2d(1, 1, kernel_size=3)),
+                    ("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
+                ]
+            )
+        )
+    >>> params, params_names = get_params_in_groups(net, "final.dense1")
+    >>> params_names
+    ["final.dense1.weight", "final.dense1.bias"]
+    >>> params, params_names = get_params_in_groups(net, ["conv1.weight", "final"])
+    >>> params_names
+    ["conv1.weight", "final.dense1.weight", "final.dense1.bias"]
+    """
+    if isinstance(groups, str):
+        groups = [groups]
+
+    params = iter(())
+    params_names = []
+    for group in groups:
+        network_ = network
+        group_hierarchy = group.split(".")
+        for name in group_hierarchy:
+            network_ = getattr(network_, name)
+
+        try:
+            params = chain(params, network_.parameters())
+            params_names += [
+                ".".join([group, name]) for name, _ in network_.named_parameters()
+            ]
+        except AttributeError:  # we already reached params
+            params = chain(params, (param for param in [network_]))
+            params_names += [group]
+
+    return params, params_names
+
+
+def get_params_not_in_groups(
+    network: nn.Module, groups: Union[str, List[str]]
+) -> Tuple[Iterator[torch.Tensor], List[str]]:
+    """
+    Gets the parameters of a neural network that are not in specific groups.
+
+    Parameters
+    ----------
+    network : nn.Module
+        The neural network.
+    groups : Union[str, List[str]]
+        The name of the group(s), e.g. a layer or a block.
+        If the user refers to a sub-block, the hierarchy should be
+        specified with "." (see examples).
+        If a list is passed, the function will output the parameters
+        that are not in any group of that list.
+
+    Returns
+    -------
+    Iterator[torch.Tensor]
+        An iterator that contains the parameters not in the group(s).
+    List[str]
+        The name of all the parameters not in the group(s).
+
+    Examples
+    --------
+    >>> net = nn.Sequential(
+            OrderedDict(
+                [
+                    ("conv1", nn.Conv2d(1, 1, kernel_size=3)),
+                    ("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
+                ]
+            )
+        )
+    >>> params, params_names = get_params_not_in_groups(net, "final")
+    >>> params_names
+    ["conv1.weight", "conv1.bias"]
+    >>> params, params_names = get_params_not_in_groups(net, ["conv1.bias", "final"])
+    >>> params_names
+    ["conv1.weight"]
+    """
+    _, in_groups = get_params_in_groups(network, groups)
+    params = (
+        param[1] for param in network.named_parameters() if param[0] not in in_groups
+    )
+    params_names = list(
+        param[0] for param in network.named_parameters() if param[0] not in in_groups
+    )
+    return params, params_names
diff --git a/tests/unittests/optim/lr_scheduler/test_factory.py b/tests/unittests/optim/lr_scheduler/test_factory.py
index c59759fcd..76df845cd 100644
--- a/tests/unittests/optim/lr_scheduler/test_factory.py
+++ b/tests/unittests/optim/lr_scheduler/test_factory.py
@@ -17,6 +17,7 @@ def test_get_lr_scheduler():
         [
             ("linear1", nn.Linear(4, 3)),
             ("linear2", nn.Linear(3, 2)),
+            ("linear3", nn.Linear(2, 1)),
         ]
     )
 )
@@ -29,6 +30,9 @@ def test_get_lr_scheduler():
         {
             "params": network.linear2.parameters(),
         },
+        {
+            "params": network.linear3.parameters(),
+        },
     ],
     lr=10.0,
 )
@@ -58,7 +62,7 @@ def test_get_lr_scheduler():
     assert scheduler.threshold == 1e-1
     assert scheduler.threshold_mode == "rel"
     assert scheduler.cooldown == 3
-    assert scheduler.min_lrs == [0.1, 0.01]
+    assert scheduler.min_lrs == [0.1, 0.01, 0.0]
     assert scheduler.eps == 1e-8

     assert updated_config.scheduler == "ReduceLROnPlateau"
@@ -71,12 +75,14 @@ def test_get_lr_scheduler():
     assert updated_config.min_lr == {"linear2": 0.01, "linear1": 0.1}
     assert updated_config.eps == 1e-8

-    network.add_module("linear3", nn.Linear(3, 2))
-    optimizer.add_param_group({"params": network.linear3.parameters()})
     config.min_lr = {"ELSE": 1, "linear2": 0.01, "linear1": 0.1}
     scheduler, updated_config = get_lr_scheduler(optimizer, config)
     assert scheduler.min_lrs == [0.1, 0.01, 1]

+    config.min_lr = 1
+    scheduler, updated_config = get_lr_scheduler(optimizer, config)
+    assert scheduler.min_lrs == [1.0, 1.0, 1.0]
+
     config = LRSchedulerConfig()
     scheduler, updated_config = get_lr_scheduler(optimizer, config)
     assert isinstance(scheduler, LambdaLR)
diff --git a/tests/unittests/optim/optimizer/test_factory.py b/tests/unittests/optim/optimizer/test_factory.py
index 6387b7f5a..7dbf149e9 100644
--- a/tests/unittests/optim/optimizer/test_factory.py
+++ b/tests/unittests/optim/optimizer/test_factory.py
@@ -82,7 +82,7 @@ def test_get_optimizer(network):
     assert not updated_config.maximize
     assert not updated_config.differentiable

-    # special cases
+    # special case : only ELSE
     config = OptimizerConfig(
         optimizer="Adagrad",
         lr_decay={"ELSE": 100},
@@ -91,6 +91,7 @@ def test_get_optimizer(network):
     assert len(optimizer.param_groups) == 1
     assert optimizer.param_groups[0]["lr_decay"] == 100

+    # special case : the groups mentioned cover the whole network
     config = OptimizerConfig(
         optimizer="Adagrad",
         lr_decay={"conv1": 100, "dense1": 10, "final": 1},
@@ -98,6 +99,16 @@ def test_get_optimizer(network):
     optimizer, _ = get_optimizer(network, config)
     assert len(optimizer.param_groups) == 3

+    # special case : no ELSE mentioned
+    config = OptimizerConfig(
+        optimizer="Adagrad",
+        lr_decay={"conv1": 100},
+    )
+    optimizer, _ = get_optimizer(network, config)
+    assert len(optimizer.param_groups) == 2
+    assert optimizer.param_groups[0]["lr_decay"] == 100
+    assert optimizer.param_groups[1]["lr_decay"] == 0
+

 def test_regroup_args():
     from clinicadl.optim.optimizer.factory import _regroup_args
@@ -123,56 +134,3 @@ def test_regroup_args():
         {"weight_decay": {"params_0": 0.0, "params_1": 1.0}}
     )
     assert len(args_global) == 0
-
-
-def test_get_params_in_block(network):
-    import torch
-
-    from clinicadl.optim.optimizer.factory import _get_params_in_group
-
-    generator, list_layers = _get_params_in_group(network, "dense1")
-    assert next(iter(generator)).shape == torch.Size((10, 10))
-    assert next(iter(generator)).shape == torch.Size((10,))
-    assert sorted(list_layers) == sorted(["dense1.weight", "dense1.bias"])
-
-    generator, list_layers = _get_params_in_group(network, "dense1.weight")
-    assert next(iter(generator)).shape == torch.Size((10, 10))
-    assert sum(1 for _ in generator) == 0
-    assert sorted(list_layers) == sorted(["dense1.weight"])
-
-    generator, list_layers = _get_params_in_group(network, "final.dense3")
-    assert next(iter(generator)).shape == torch.Size((3, 5))
-    assert next(iter(generator)).shape == torch.Size((3,))
-    assert sorted(list_layers) == sorted(["final.dense3.weight", "final.dense3.bias"])
-
-    generator, list_layers = _get_params_in_group(network, "final")
-    assert sum(1 for _ in generator) == 4
-    assert sorted(list_layers) == sorted(
-        [
-            "final.dense2.weight",
-            "final.dense2.bias",
-            "final.dense3.weight",
-            "final.dense3.bias",
-        ]
-    )
-
-
-def test_find_params_not_in_group(network):
-    import torch
-
-    from clinicadl.optim.optimizer.factory import _get_params_not_in_group
-
-    params = _get_params_not_in_group(
-        network,
-        [
-            "final.dense2.weight",
-            "final.dense2.bias",
-            "conv1.bias",
-            "final.dense3.weight",
-            "dense1.weight",
-            "dense1.bias",
-        ],
-    )
-    assert next(iter(params)).shape == torch.Size((1, 1, 3, 3))
-    assert next(iter(params)).shape == torch.Size((3,))
-    assert sum(1 for _ in params) == 0  # no more params
diff --git a/tests/unittests/optim/optimizer/test_utils.py b/tests/unittests/optim/optimizer/test_utils.py
new file mode 100644
index 000000000..afa06a5d0
--- /dev/null
+++ b/tests/unittests/optim/optimizer/test_utils.py
@@ -0,0 +1,109 @@
+from collections import OrderedDict
+
+import pytest
+import torch
+import torch.nn as nn
+
+
+@pytest.fixture
+def network():
+    network = nn.Sequential(
+        OrderedDict(
+            [
+                ("conv1", nn.Conv2d(1, 1, kernel_size=3)),
+                ("dense1", nn.Linear(10, 10)),
+            ]
+        )
+    )
+    network.add_module(
+        "final",
+        nn.Sequential(
+            OrderedDict([("dense2", nn.Linear(10, 5)), ("dense3", nn.Linear(5, 3))])
+        ),
+    )
+    return network
+
+
+def test_get_params_in_groups(network):
+    import torch
+
+    from clinicadl.optim.optimizer.utils import get_params_in_groups
+
+    iterator, list_layers = get_params_in_groups(network, "dense1")
+    assert next(iter(iterator)).shape == torch.Size((10, 10))
+    assert next(iter(iterator)).shape == torch.Size((10,))
+    assert sorted(list_layers) == sorted(["dense1.weight", "dense1.bias"])
+
+    iterator, list_layers = get_params_in_groups(network, "dense1.weight")
+    assert next(iter(iterator)).shape == torch.Size((10, 10))
+    assert sum(1 for _ in iterator) == 0
+    assert sorted(list_layers) == sorted(["dense1.weight"])
+
+    iterator, list_layers = get_params_in_groups(network, "final.dense3")
+    assert next(iter(iterator)).shape == torch.Size((3, 5))
+    assert next(iter(iterator)).shape == torch.Size((3,))
+    assert sorted(list_layers) == sorted(["final.dense3.weight", "final.dense3.bias"])
+
+    iterator, list_layers = get_params_in_groups(network, "final")
+    assert sum(1 for _ in iterator) == 4
+    assert sorted(list_layers) == sorted(
+        [
+            "final.dense2.weight",
+            "final.dense2.bias",
+            "final.dense3.weight",
+            "final.dense3.bias",
+        ]
+    )
+
+    iterator, list_layers = get_params_in_groups(network, ["dense1.weight", "final"])
+    assert sum(1 for _ in iterator) == 5
+    assert sorted(list_layers) == sorted(
+        [
+            "dense1.weight",
+            "final.dense2.weight",
+            "final.dense2.bias",
+            "final.dense3.weight",
+            "final.dense3.bias",
+        ]
+    )
+
+    # check with numbered modules
+    network_bis = nn.Sequential(nn.Linear(10, 2), nn.Linear(2, 1))
+    iterator, list_layers = get_params_in_groups(network_bis, "0.bias")
+    assert next(iter(iterator)).shape == torch.Size((2,))
+    assert sorted(list_layers) == sorted(["0.bias"])
+
+
+def test_find_params_not_in_group(network):
+    import torch
+
+    from clinicadl.optim.optimizer.utils import get_params_not_in_groups
+
+    iterator, list_layers = get_params_not_in_groups(
+        network,
+        [
+            "final",
+            "conv1.weight",
+        ],
+    )
+    assert next(iter(iterator)).shape == torch.Size((1,))
+    assert next(iter(iterator)).shape == torch.Size((10, 10))
+    assert sum(1 for _ in iterator) == 1
+    assert sorted(list_layers) == sorted(
+        [
+            "conv1.bias",
+            "dense1.weight",
+            "dense1.bias",
+        ]
+    )
+
+    iterator, list_layers = get_params_not_in_groups(network, "final")
+    assert sum(1 for _ in iterator) == 4
+    assert sorted(list_layers) == sorted(
+        [
+            "conv1.weight",
+            "conv1.bias",
+            "dense1.weight",
+            "dense1.bias",
+        ]
+    )
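
Below is a minimal usage sketch of the two helpers introduced in this patch (get_params_in_groups and get_params_not_in_groups). It assumes clinicaDL with this change is importable; the network layout and the lr_decay / lr values are purely illustrative and not library defaults.

    from collections import OrderedDict

    import torch.nn as nn
    import torch.optim as optim

    from clinicadl.optim.optimizer.utils import (
        get_params_in_groups,
        get_params_not_in_groups,
    )

    # Illustrative network: a named conv layer plus a named sub-block.
    net = nn.Sequential(
        OrderedDict(
            [
                ("conv1", nn.Conv2d(1, 1, kernel_size=3)),
                ("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
            ]
        )
    )

    # Give "conv1" its own lr_decay; everything else goes in a default group,
    # mirroring what get_optimizer builds when no "ELSE" entry is provided.
    conv_params, _ = get_params_in_groups(net, "conv1")
    other_params, other_names = get_params_not_in_groups(net, "conv1")

    param_groups = [{"params": conv_params, "lr_decay": 100}]
    if len(other_names) > 0:  # add a default group only if parameters are left over
        param_groups.append({"params": other_params})

    optimizer = optim.Adagrad(param_groups, lr=0.01)
    print([len(group["params"]) for group in optimizer.param_groups])  # [2, 2]

Returning the parameter names alongside the iterator is what lets the factory skip the default group when the named groups already cover the whole network, instead of consuming the iterator to test whether it is empty.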