Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loss functions module #640

Merged
merged 59 commits into from
Aug 26, 2024
Merged

Conversation

thibaultdvx
Copy link
Collaborator

@thibaultdvx thibaultdvx commented Jul 22, 2024

In this PR, I suggest a new module for loss functions. The module is very simple and is actually only an intermediary between ClinicaDL and PyTorch's nn module. More precisely, there is:

  • a LossConfig class, with all the parameters useful for PyTorch's loss functions,
  • a factory function get_loss_function that will get a PyTorch loss function and parametrize it thanks to a LossConfig object. This function also returns a config dict with only the parameters relevant to the wanted loss function (whereas a LossFunction object contains the parameters for all loss functions available), because one only wants to save the information related to the wanted loss function in the MAPS.

dependabot bot and others added 30 commits April 22, 2024 10:03
Bumps [sqlparse](https://github.com/andialbrecht/sqlparse) from 0.4.4 to 0.5.0.
- [Changelog](https://github.com/andialbrecht/sqlparse/blob/master/CHANGELOG)
- [Commits](andialbrecht/sqlparse@0.4.4...0.5.0)

---
updated-dependencies:
- dependency-name: sqlparse
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [tqdm](https://github.com/tqdm/tqdm) from 4.66.1 to 4.66.3.
- [Release notes](https://github.com/tqdm/tqdm/releases)
- [Commits](tqdm/tqdm@v4.66.1...v4.66.3)

---
updated-dependencies:
- dependency-name: tqdm
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [werkzeug](https://github.com/pallets/werkzeug) from 3.0.1 to 3.0.3.
- [Release notes](https://github.com/pallets/werkzeug/releases)
- [Changelog](https://github.com/pallets/werkzeug/blob/main/CHANGES.rst)
- [Commits](pallets/werkzeug@3.0.1...3.0.3)

---
updated-dependencies:
- dependency-name: werkzeug
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4.
- [Release notes](https://github.com/pallets/jinja/releases)
- [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst)
- [Commits](pallets/jinja@3.1.3...3.1.4)

---
updated-dependencies:
- dependency-name: jinja2
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [mlflow](https://github.com/mlflow/mlflow) from 2.10.1 to 2.12.1.
- [Release notes](https://github.com/mlflow/mlflow/releases)
- [Changelog](https://github.com/mlflow/mlflow/blob/master/CHANGELOG.md)
- [Commits](mlflow/mlflow@v2.10.1...v2.12.1)

---
updated-dependencies:
- dependency-name: mlflow
  dependency-type: direct:development
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [gunicorn](https://github.com/benoitc/gunicorn) from 21.2.0 to 22.0.0.
- [Release notes](https://github.com/benoitc/gunicorn/releases)
- [Commits](benoitc/gunicorn@21.2.0...22.0.0)

---
updated-dependencies:
- dependency-name: gunicorn
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
updated-dependencies:
- dependency-name: requests
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
* try a simple workflow first

* try running on new ubuntu VM

* fixes

* bump poetry version to 1.8.3

* try removing caching..

* add workflow for testing tsv tools
* try skipping test_tsvtools when PR is in draft mode

* trigger CI

* add a cpu tag to avoid running cpu tests on gpu machines

* run also on refactoring branch
* add test workflow on GPU for train

* fix conda path

* fix conflicting workdir

* only run on non-draft PRs

* run also on refactoring branch
* add workflow for testing interpretation task

* add workflow for testing random search task

* add workflow for testing resume task

* add workflow for testing transfer learning task

* trigger CI

* trigger CI
* add cleaning step to test_tsvtools pipeline

* add test_generate pipeline

* add test_predict pipeline

* add test_prepare_data pipeline

* add test_quality_checks pipeline

* add refactoring target branch, cpu tag, and draft PR filter

* trigger CI
…nt in README (aramis-lab#600)

* update python version used for creating conda env in README

* investigate

* fix
* add no-gpu and adapt-base-dir flag
* Update quality_check.py
* add FileNotFound error in tree
@thibaultdvx thibaultdvx added the refactoring ClinicaDL refactoring 2024 label Jul 22, 2024
@thibaultdvx thibaultdvx changed the title Clinicadl models Loss functions module Jul 22, 2024
@thibaultdvx
Copy link
Collaborator Author

thibaultdvx commented Jul 23, 2024

Update: I added an option "DefaultFromLibrary" in the config class to use by default the default values from PyTorch's loss functions, instead of manually copy/paste these defaults in our config class.

Copy link
Collaborator

@camillebrianceau camillebrianceau left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM !

return loss, config_dict


# TODO : what about them?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

??

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you talking about the other losses?

I've just added more losses from PyTorch in my last commit.

if "weight" in config_dict and config_dict["weight"] is not None:
config_dict_["weight"] = torch.Tensor(config_dict_["weight"])
loss = loss_class(**config_dict_)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the loss is already initialized here, do we need to return a dictionary? Because I thought we wanted to get rid of all the dictionaries.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @camillebrianceau, thanks for the review. You're right, I changed that to put the arguments of the loss function in a config class, rather than a dict.

Even if the loss is initialized, we need the config class to store the config in the MAPS.

The difference between the input and the output config class is that the values set to "DefaultFromLibrary" have been changed with their effective values. It is not essential but better for reproducibility I think (for example in the - unlikely - event that PyTorch changes some default values).

Tell me what do you think!

@thibaultdvx thibaultdvx mentioned this pull request Aug 12, 2024
@thibaultdvx thibaultdvx merged commit 66524a3 into aramis-lab:refactoring Aug 26, 2024
21 checks passed
@thibaultdvx thibaultdvx mentioned this pull request Aug 29, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
refactoring ClinicaDL refactoring 2024
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants