Minor changes in optim and networks #652

Merged
merged 2 commits on Sep 23, 2024
2 changes: 1 addition & 1 deletion clinicadl/monai_networks/__init__.py
@@ -1,2 +1,2 @@
from .config import create_network_config
from .config import ImplementedNetworks, NetworkConfig, create_network_config
from .factory import get_network
1 change: 1 addition & 0 deletions clinicadl/monai_networks/config/__init__.py
@@ -1,2 +1,3 @@
from .base import NetworkConfig
from .factory import create_network_config
from .utils.enum import ImplementedNetworks
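Note: with these re-exports, the enum and the base config class can be imported directly from the package root. A minimal smoke test, assuming `ImplementedNetworks` is a standard `Enum` (which its `utils.enum` module suggests, but this is an assumption, not confirmed by the diff):

```python
from clinicadl.monai_networks import (
    ImplementedNetworks,
    NetworkConfig,
    create_network_config,
)

# List the supported architectures; the member names are whatever clinicadl defines.
for member in ImplementedNetworks:
    print(member.name)
```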
5 changes: 5 additions & 0 deletions clinicadl/optim/lr_scheduler/factory.py
@@ -49,6 +49,11 @@ def get_lr_scheduler(
config_dict_["min_lr"].append(
config_dict["min_lr"]["ELSE"]
) # ELSE must be the last group
else:
default_min_lr = get_args_and_defaults(scheduler_class.__init__)[1][
"min_lr"
]
config_dict_["min_lr"].append(default_min_lr)
scheduler = scheduler_class(optimizer, **config_dict_)

updated_config = LRSchedulerConfig(scheduler=config.scheduler, **config_dict)
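Note: the new `else` branch falls back to the scheduler class's own default `min_lr` for any parameter group not listed in the config. As an illustration only (not clinicadl's actual helper), a `get_args_and_defaults`-style lookup can be sketched with `inspect`:

```python
import inspect
from typing import Any, Callable, Dict, List, Tuple

import torch.optim as optim


def args_and_defaults(func: Callable) -> Tuple[List[str], Dict[str, Any]]:
    """Return the argument names and the default values of a callable."""
    signature = inspect.signature(func)
    args = [name for name in signature.parameters if name != "self"]
    defaults = {
        name: param.default
        for name, param in signature.parameters.items()
        if param.default is not inspect.Parameter.empty
    }
    return args, defaults


_, defaults = args_and_defaults(optim.lr_scheduler.ReduceLROnPlateau.__init__)
print(defaults["min_lr"])  # 0 with current PyTorch defaults
```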
98 changes: 7 additions & 91 deletions clinicadl/optim/optimizer/factory.py
@@ -1,12 +1,12 @@
from typing import Any, Dict, Iterable, Iterator, List, Tuple
from typing import Any, Dict, Tuple

import torch
import torch.nn as nn
import torch.optim as optim

from clinicadl.utils.factories import DefaultFromLibrary, get_args_and_defaults

from .config import OptimizerConfig
from .utils import get_params_in_groups, get_params_not_in_groups


def get_optimizer(
@@ -45,23 +45,16 @@ def get_optimizer(
list_args_groups = network.parameters()
else:
list_args_groups = []
union_groups = set()
args_groups = sorted(args_groups.items()) # order in the list is important
for group, args in args_groups:
params, params_names = _get_params_in_group(network, group)
params, _ = get_params_in_groups(network, group)
args.update({"params": params})
list_args_groups.append(args)
union_groups.update(set(params_names))

other_params = _get_params_not_in_group(network, union_groups)
try:
next(other_params)
except StopIteration: # there is no other param in the network
pass
else:
other_params = _get_params_not_in_group(
network, union_groups
) # reset the generator
other_params, params_names = get_params_not_in_groups(
network, [group for group, _ in args_groups]
)
if len(params_names) > 0:
list_args_groups.append({"params": other_params})

optimizer = optimizer_class(list_args_groups, **args_global)
@@ -126,80 +119,3 @@ def _regroup_args(
args_global[arg] = value

return args_groups, args_global


def _get_params_in_group(
network: nn.Module, group: str
) -> Tuple[Iterator[torch.Tensor], List[str]]:
"""
Gets the parameters of a specific group of a neural network.

Parameters
----------
network : nn.Module
The neural network.
group : str
The name of the group, e.g. a layer or a block.
If it is a sub-block, the hierarchy should be
specified with "." (see examples).
Will work even if the group is reduced to a base layer
(e.g. group = "dense.weight" or "dense.bias").

Returns
-------
Iterator[torch.Tensor]
A generator that contains the parameters of the group.
List[str]
The name of all the parameters in the group.

Examples
--------
>>> net = nn.Sequential(
OrderedDict(
[
("conv1", nn.Conv2d(1, 1, kernel_size=3)),
("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
]
)
)
>>> generator, params_names = _get_params_in_group(network, "final.dense1")
>>> params_names
["final.dense1.weight", "final.dense1.bias"]
"""
group_hierarchy = group.split(".")
for name in group_hierarchy:
network = getattr(network, name)

try:
params = network.parameters()
params_names = [
".".join([group, name]) for name, _ in network.named_parameters()
]
except AttributeError: # we already reached params
params = (param for param in [network])
params_names = [group]

return params, params_names


def _get_params_not_in_group(
network: nn.Module, group: Iterable[str]
) -> Iterator[torch.Tensor]:
"""
Finds the parameters of a neural networks that
are not in a group.

Parameters
----------
network : nn.Module
The neural network.
group : List[str]
The group of parameters.

Returns
-------
Iterator[torch.Tensor]
A generator of all the parameters that are not in the input
group.
"""
return (param[1] for param in network.named_parameters() if param[0] not in group)
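Note: to make the new flow in `get_optimizer` concrete, here is a small self-contained sketch of how the per-group arguments are assembled with the new helpers. The module names ("encoder", "head") and hyperparameters are made up, and the surrounding config handling (`_regroup_args`, `OptimizerConfig`) is omitted; it only assumes clinicadl is installed with this PR applied:

```python
from collections import OrderedDict

import torch.nn as nn
import torch.optim as optim

from clinicadl.optim.optimizer.utils import (
    get_params_in_groups,
    get_params_not_in_groups,
)

network = nn.Sequential(
    OrderedDict(
        [
            ("encoder", nn.Linear(8, 4)),
            ("head", nn.Linear(4, 1)),
        ]
    )
)

# Per-group arguments, as _regroup_args would produce them from the user config.
args_groups = {"encoder": {"lr": 1e-4}}

list_args_groups = []
for group, args in sorted(args_groups.items()):  # keep the group order stable
    params, _ = get_params_in_groups(network, group)
    list_args_groups.append({"params": params, **args})

# All remaining parameters form a last group that uses the global arguments.
other_params, other_names = get_params_not_in_groups(network, list(args_groups))
if len(other_names) > 0:
    list_args_groups.append({"params": other_params})

optimizer = optim.Adam(list_args_groups, lr=1e-3)
print([len(g["params"]) for g in optimizer.param_groups])  # [2, 2]
```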
120 changes: 120 additions & 0 deletions clinicadl/optim/optimizer/utils.py
@@ -0,0 +1,120 @@
from itertools import chain
from typing import Iterator, List, Tuple, Union

import torch
import torch.nn as nn


def get_params_in_groups(
network: nn.Module, groups: Union[str, List[str]]
) -> Tuple[Iterator[torch.Tensor], List[str]]:
"""
Gets the parameters of specific groups of a neural network.

Parameters
----------
network : nn.Module
The neural network.
groups : Union[str, List[str]]
The name of the group(s), e.g. a layer or a block.
If the user refers to a sub-block, the hierarchy should be
specified with "." (see examples).
If a list is passed, the function will output the parameters
of all groups mentioned together.

Returns
-------
Iterator[torch.Tensor]
An iterator that contains the parameters of the group(s).
List[str]
The name of all the parameters in the group(s).

Examples
--------
>>> net = nn.Sequential(
OrderedDict(
[
("conv1", nn.Conv2d(1, 1, kernel_size=3)),
("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
]
)
)
>>> params, params_names = get_params_in_groups(network, "final.dense1")
>>> params_names
["final.dense1.weight", "final.dense1.bias"]
>>> params, params_names = get_params_in_groups(network, ["conv1.weight", "final"])
>>> params_names
["conv1.weight", "final.dense1.weight", "final.dense1.bias"]
"""
if isinstance(groups, str):
groups = [groups]

params = iter(())
params_names = []
for group in groups:
network_ = network
group_hierarchy = group.split(".")
for name in group_hierarchy:
network_ = getattr(network_, name)

try:
params = chain(params, network_.parameters())
params_names += [
".".join([group, name]) for name, _ in network_.named_parameters()
]
except AttributeError: # we already reached params
params = chain(params, (param for param in [network_]))
params_names += [group]

return params, params_names


def get_params_not_in_groups(
network: nn.Module, groups: Union[str, List[str]]
) -> Tuple[Iterator[torch.Tensor], List[str]]:
"""
Gets the parameters not in specific groups of a neural network.

Parameters
----------
network : nn.Module
The neural network.
groups : Union[str, List[str]]
The name of the group(s), e.g. a layer or a block.
If the user refers to a sub-block, the hierarchy should be
specified with "." (see examples).
If a list is passed, the function will output the parameters
that are not in any group of that list.

Returns
-------
Iterator[torch.Tensor]
An iterator that contains the parameters not in the group(s).
List[str]
The name of all the parameters not in the group(s).

Examples
--------
>>> net = nn.Sequential(
OrderedDict(
[
("conv1", nn.Conv2d(1, 1, kernel_size=3)),
("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
]
)
)
>>> params, params_names = get_params_in_groups(network, "final")
>>> params_names
["conv1.weight", "conv1.bias"]
>>> params, params_names = get_params_in_groups(network, ["conv1.bias", "final"])
>>> params_names
["conv1.weight"]
"""
_, in_groups = get_params_in_groups(network, groups)
params = (
param[1] for param in network.named_parameters() if param[0] not in in_groups
)
params_names = list(
param[0] for param in network.named_parameters() if param[0] not in in_groups
)
return params, params_names
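Note: a quick usage check of the two helpers above, mirroring the docstring examples (the toy network comes from the docstrings, not from a clinicadl model):

```python
from collections import OrderedDict

import torch.nn as nn

from clinicadl.optim.optimizer.utils import (
    get_params_in_groups,
    get_params_not_in_groups,
)

net = nn.Sequential(
    OrderedDict(
        [
            ("conv1", nn.Conv2d(1, 1, kernel_size=3)),
            ("final", nn.Sequential(OrderedDict([("dense1", nn.Linear(10, 10))]))),
        ]
    )
)

# A whole sub-block...
_, names = get_params_in_groups(net, "final.dense1")
assert names == ["final.dense1.weight", "final.dense1.bias"]

# ...or a mix of a single leaf parameter (handled by the AttributeError branch)
# and a block.
_, names = get_params_in_groups(net, ["conv1.weight", "final"])
assert names == ["conv1.weight", "final.dense1.weight", "final.dense1.bias"]

# Complement of the groups.
_, names = get_params_not_in_groups(net, ["conv1.bias", "final"])
assert names == ["conv1.weight"]
```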
12 changes: 9 additions & 3 deletions tests/unittests/optim/lr_scheduler/test_factory.py
@@ -17,6 +17,7 @@ def test_get_lr_scheduler():
[
("linear1", nn.Linear(4, 3)),
("linear2", nn.Linear(3, 2)),
("linear3", nn.Linear(2, 1)),
]
)
)
@@ -29,6 +30,9 @@
{
"params": network.linear2.parameters(),
},
{
"params": network.linear3.parameters(),
},
],
lr=10.0,
)
@@ -58,7 +62,7 @@ def test_get_lr_scheduler():
assert scheduler.threshold == 1e-1
assert scheduler.threshold_mode == "rel"
assert scheduler.cooldown == 3
assert scheduler.min_lrs == [0.1, 0.01]
assert scheduler.min_lrs == [0.1, 0.01, 0.0]
assert scheduler.eps == 1e-8

assert updated_config.scheduler == "ReduceLROnPlateau"
@@ -71,12 +75,14 @@
assert updated_config.min_lr == {"linear2": 0.01, "linear1": 0.1}
assert updated_config.eps == 1e-8

network.add_module("linear3", nn.Linear(3, 2))
optimizer.add_param_group({"params": network.linear3.parameters()})
config.min_lr = {"ELSE": 1, "linear2": 0.01, "linear1": 0.1}
scheduler, updated_config = get_lr_scheduler(optimizer, config)
assert scheduler.min_lrs == [0.1, 0.01, 1]

config.min_lr = 1
scheduler, updated_config = get_lr_scheduler(optimizer, config)
assert scheduler.min_lrs == [1.0, 1.0, 1.0]

config = LRSchedulerConfig()
scheduler, updated_config = get_lr_scheduler(optimizer, config)
assert isinstance(scheduler, LambdaLR)
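Note: as a plain-PyTorch sanity check of the behaviour asserted above (independent of clinicadl), `ReduceLROnPlateau` accepts one `min_lr` per parameter group and exposes them as `min_lrs`:

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

linear1, linear2, linear3 = nn.Linear(4, 3), nn.Linear(3, 2), nn.Linear(2, 1)
optimizer = optim.SGD(
    [
        {"params": linear1.parameters()},
        {"params": linear2.parameters()},
        {"params": linear3.parameters()},
    ],
    lr=10.0,
)

# One min_lr per group...
scheduler = ReduceLROnPlateau(optimizer, min_lr=[0.1, 0.01, 0.0])
assert scheduler.min_lrs == [0.1, 0.01, 0.0]

# ...or a single value broadcast to every group.
scheduler = ReduceLROnPlateau(optimizer, min_lr=1)
assert scheduler.min_lrs == [1.0, 1.0, 1.0]
```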