-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Loss functions module #640
Loss functions module #640
Conversation
Bumps [sqlparse](https://github.com/andialbrecht/sqlparse) from 0.4.4 to 0.5.0. - [Changelog](https://github.com/andialbrecht/sqlparse/blob/master/CHANGELOG) - [Commits](andialbrecht/sqlparse@0.4.4...0.5.0) --- updated-dependencies: - dependency-name: sqlparse dependency-type: indirect ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [tqdm](https://github.com/tqdm/tqdm) from 4.66.1 to 4.66.3. - [Release notes](https://github.com/tqdm/tqdm/releases) - [Commits](tqdm/tqdm@v4.66.1...v4.66.3) --- updated-dependencies: - dependency-name: tqdm dependency-type: indirect ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [werkzeug](https://github.com/pallets/werkzeug) from 3.0.1 to 3.0.3. - [Release notes](https://github.com/pallets/werkzeug/releases) - [Changelog](https://github.com/pallets/werkzeug/blob/main/CHANGES.rst) - [Commits](pallets/werkzeug@3.0.1...3.0.3) --- updated-dependencies: - dependency-name: werkzeug dependency-type: indirect ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4. - [Release notes](https://github.com/pallets/jinja/releases) - [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst) - [Commits](pallets/jinja@3.1.3...3.1.4) --- updated-dependencies: - dependency-name: jinja2 dependency-type: indirect ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [mlflow](https://github.com/mlflow/mlflow) from 2.10.1 to 2.12.1. - [Release notes](https://github.com/mlflow/mlflow/releases) - [Changelog](https://github.com/mlflow/mlflow/blob/master/CHANGELOG.md) - [Commits](mlflow/mlflow@v2.10.1...v2.12.1) --- updated-dependencies: - dependency-name: mlflow dependency-type: direct:development ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [gunicorn](https://github.com/benoitc/gunicorn) from 21.2.0 to 22.0.0. - [Release notes](https://github.com/benoitc/gunicorn/releases) - [Commits](benoitc/gunicorn@21.2.0...22.0.0) --- updated-dependencies: - dependency-name: gunicorn dependency-type: indirect ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
updated-dependencies: - dependency-name: requests dependency-type: indirect ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
* try a simple workflow first * try running on new ubuntu VM * fixes * bump poetry version to 1.8.3 * try removing caching.. * add workflow for testing tsv tools
* try skipping test_tsvtools when PR is in draft mode * trigger CI * add a cpu tag to avoid running cpu tests on gpu machines * run also on refactoring branch
* add test workflow on GPU for train * fix conda path * fix conflicting workdir * only run on non-draft PRs * run also on refactoring branch
* add workflow for testing interpretation task * add workflow for testing random search task * add workflow for testing resume task * add workflow for testing transfer learning task * trigger CI * trigger CI
* add cleaning step to test_tsvtools pipeline * add test_generate pipeline * add test_predict pipeline * add test_prepare_data pipeline * add test_quality_checks pipeline * add refactoring target branch, cpu tag, and draft PR filter * trigger CI
…nt in README (aramis-lab#600) * update python version used for creating conda env in README * investigate * fix
* add no-gpu and adapt-base-dir flag
* Update quality_check.py
* add FileNotFound error in tree
Update: I added an option "DefaultFromLibrary" in the config class to use by default the default values from PyTorch's loss functions, instead of manually copy/paste these defaults in our config class. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, LGTM !
clinicadl/losses/factory.py
Outdated
return loss, config_dict | ||
|
||
|
||
# TODO : what about them? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
??
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you talking about the other losses?
I've just added more losses from PyTorch in my last commit.
if "weight" in config_dict and config_dict["weight"] is not None: | ||
config_dict_["weight"] = torch.Tensor(config_dict_["weight"]) | ||
loss = loss_class(**config_dict_) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As the loss is already initialized here, do we need to return a dictionary? Because I thought we wanted to get rid of all the dictionaries.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @camillebrianceau, thanks for the review. You're right, I changed that to put the arguments of the loss function in a config class, rather than a dict.
Even if the loss is initialized, we need the config class to store the config in the MAPS.
The difference between the input and the output config class is that the values set to "DefaultFromLibrary" have been changed with their effective values. It is not essential but better for reproducibility I think (for example in the - unlikely - event that PyTorch changes some default values).
Tell me what do you think!
In this PR, I suggest a new module for loss functions. The module is very simple and is actually only an intermediary between
ClinicaDL
and PyTorch'snn
module. More precisely, there is:LossConfig
class, with all the parameters useful for PyTorch's loss functions,get_loss_function
that will get a PyTorch loss function and parametrize it thanks to aLossConfig
object. This function also returns a config dict with only the parameters relevant to the wanted loss function (whereas aLossFunction
object contains the parameters for all loss functions available), because one only wants to save the information related to the wanted loss function in the MAPS.