Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic neural networks #660

Merged
merged 36 commits into from
Oct 18, 2024
Merged

Conversation

thibaultdvx
Copy link
Collaborator

@thibaultdvx thibaultdvx commented Oct 1, 2024

I this PR, I suggest the basic neural network architectures that could be trained via the commande line. @camillebrianceau @ravih18 @msolal @HuguesRoy could you please have a look at it?

The implemented networks are the following:

  • MLP: a Multi Layer Perceptron (or Fully Connected Network), where linear, activation, normalization and dropout layers can be customized.
  • ConvEncoder: a Fully Convolutional Encoder, where convolutional, pooling, activation, normalization and dropout layers can be customized.
  • ConvDecoder: a Fully Convolutional Decoder, very similar to ConvEncoder but convolutions are replaced by transposed convolutions and pooling layers by unpooling layers.
  • CNN: a regressor/classifier with first a convolutional part and then a fully connected part. It is a simple aggregation of a ConvEncoder and a MLP, so it is entirely customizable.
  • Generator: the symmetrical network of CNN, made with the aggregation of a MLP and a ConvDecoder.
  • AutoEncoder: a symmetrical autoencoder, built from the aggregation of a CNN (encoder) and the corresponding Generator (decoder).
  • VAE: a Variational AutoEncoder, built on the same idea as AutoEncoder.

Please look at the docstrings for some examples.
All the architectures are tested with unittests.

Feel free to comment the choices I made. To me, the goal is to have sufficiently flexible architectures to enable the user to tune their parameters, without going into too complicated neural networks that could be implemented via the API.
Feel also free to comment my choices on network names.

@thibaultdvx thibaultdvx changed the title Nn module Base neural networks Oct 1, 2024
@thibaultdvx thibaultdvx changed the title Base neural networks Basic neural networks Oct 1, 2024
@thibaultdvx
Copy link
Collaborator Author

thibaultdvx commented Oct 14, 2024

Since last comment, I have added well-know networks from literature:

These networks are also customizable. For example, in DenseNet, the user can specify the number of dense blocks, as well as the number of dense layers in each block.
However, if the user doesn't want to customize the network and wants to have access to the networks with the hyper-parameters used in the reference paper, the user can use factory functions, e.g. get_densenet("DenseNet-121"). Furthermore, if pre-trained weights are available via torchvision, the user can download them.

The implementations mostly rely on MONAI and torchvision. The are all tested with unittests. Please look at the docstrings for some examples.

@thibaultdvx
Copy link
Collaborator Author

thibaultdvx commented Oct 16, 2024

@camillebrianceau @ravih18
All right guys, I'm quite happy with this version of the network module. Let me know if you have any comment.
Otherwise, I will create the config classes associated to the new networks.

@thibaultdvx
Copy link
Collaborator Author

Last commits, I changed the config classes to match the new networks. I also changed the factory function get_network.
Now it can be used in the following way: get_network(name="DenseNet", spatial_dims=1, n_dense_layers=(6, 12, 24)), with first the name of the network and then kwargs to parametrize the network.

@thibaultdvx thibaultdvx marked this pull request as ready for review October 18, 2024 08:11
@thibaultdvx thibaultdvx added the refactoring ClinicaDL refactoring 2024 label Oct 18, 2024
Copy link
Collaborator

@camillebrianceau camillebrianceau left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM !

@thibaultdvx thibaultdvx merged commit 1b49939 into aramis-lab:refactoring Oct 18, 2024
32 checks passed
camillebrianceau pushed a commit to camillebrianceau/clinicadl that referenced this pull request Nov 7, 2024
* add customizable networks (MLP, ConvEncoder, ConvDecoder, CNN, Generator, AutoEncoder, VAE)

* add sota networks (ResNet, DenseNet, SE-ResNet, UNet, Attention-UNet, Vision Transformer)

*update config classes

*update factory function
camillebrianceau pushed a commit to camillebrianceau/clinicadl that referenced this pull request Nov 7, 2024
* add customizable networks (MLP, ConvEncoder, ConvDecoder, CNN, Generator, AutoEncoder, VAE)

* add sota networks (ResNet, DenseNet, SE-ResNet, UNet, Attention-UNet, Vision Transformer)

*update config classes

*update factory function
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
refactoring ClinicaDL refactoring 2024
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants