
Model.fit VS Trainer.fit #19

Open
ahmedmohammed107 opened this issue Mar 18, 2024 · 1 comment
Assignees
Labels
enhancement New feature or request

Comments

@ahmedmohammed107
Contributor

Didact Scenario:

While testing the new DIDACT workflow, Diana registered the model to MLflow with 1 epoch to verify that the code ran. She then loaded the model from MLflow and increased the number of epochs in the configuration file to actually train it. But because the number of training epochs is registered as part of the model, she had to manually override model['training_config']['num_epochs'].

Problems:

  1. The model carries a lot of information that has nothing to do with the model itself. If the model is used only for inference, for example, it makes no sense to provide the learning rate in the constructor. The learning rate could instead be passed as an argument to the model.fit method, but many other training parameters would have to be passed the same way.
  2. The fit method is also where logging is performed and snapshots are saved. Again, this adds behavior to the Model class that is irrelevant to the model itself.
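To make problem 1 concrete, here is a minimal, framework-agnostic sketch of the coupling being described. All names (CoupledModel, snapshot_dir, the constructor arguments) are hypothetical illustrations, not the actual DIDACT code:

```python
# Hypothetical illustration of the problem: because training settings live
# inside the model, even an inference-only user must supply them, and a
# registered model carries a stale num_epochs that has to be overridden.
class CoupledModel:
    def __init__(self, hidden_size, learning_rate, num_epochs, snapshot_dir):
        # hidden_size is genuinely a model property; the rest is training-only
        self.hidden_size = hidden_size
        self.training_config = {
            "learning_rate": learning_rate,
            "num_epochs": num_epochs,
            "snapshot_dir": snapshot_dir,
        }

# The model was registered after a 1-epoch smoke test, so loading it for
# real training forces the manual override from the scenario above:
model = CoupledModel(hidden_size=128, learning_rate=1e-3,
                     num_epochs=1, snapshot_dir="/tmp/snapshots")
model.training_config["num_epochs"] = 50  # manual override after loading
```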

Proposed Solution ===> trainer.fit instead of model.fit:

It is a widely adopted pattern in PyTorch to separate the model from the training loop. The model would only have __init__ and forward (or __call__) methods. The Trainer class would do all the heavy lifting: it takes the model, the training/validation data, and all training configs (learning rate, optimizer_config, scheduler_config, early-stopping config). It also performs the logging and knows where to save snapshots. It can serve as the interface with MLflow by saving: 1) parameters before training starts, and 2) metrics and artifacts after training is complete.
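The proposed split could look roughly like the sketch below. This is only an illustrative outline under the assumptions stated in this issue; the class and field names (Model, Trainer, TrainingConfig) and the toy update rule are hypothetical, not the actual DIDACT or MLflow API:

```python
# Sketch of the proposed separation: the Model knows only how to construct
# itself and run a forward pass; the Trainer owns every training-time
# concern (configs, logging before/after training, snapshot locations).
from dataclasses import dataclass, field

@dataclass
class TrainingConfig:
    num_epochs: int = 1
    learning_rate: float = 1e-3
    optimizer_config: dict = field(default_factory=dict)
    snapshot_dir: str = "snapshots"

class Model:
    """Only construction and the forward pass -- no training state."""
    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size
        self.weights = [0.0] * hidden_size

    def forward(self, x):
        return sum(w * x for w in self.weights)

class Trainer:
    """Owns the training configs, logging, and snapshotting."""
    def __init__(self, config: TrainingConfig):
        self.config = config
        self.logged_params = {}
        self.logged_metrics = {}

    def fit(self, model: Model, train_data):
        # 1) parameters would be logged (e.g. to MLflow) before training
        self.logged_params = {
            "num_epochs": self.config.num_epochs,
            "learning_rate": self.config.learning_rate,
        }
        for _ in range(self.config.num_epochs):
            for x, y in train_data:
                pred = model.forward(x)
                # toy gradient step, purely illustrative
                for i in range(model.hidden_size):
                    model.weights[i] -= self.config.learning_rate * (pred - y) * x
        # 2) metrics/artifacts would be logged after training completes
        self.logged_metrics = {"final_error": abs(model.forward(x) - y)}
        return model
```

With this split, re-training with more epochs is just a new TrainingConfig; the registered model artifact no longer needs to be touched:

```python
trainer = Trainer(TrainingConfig(num_epochs=50))
trainer.fit(Model(hidden_size=4), train_data=[(1.0, 2.0)])
```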

@ahmedmohammed107 ahmedmohammed107 self-assigned this Mar 18, 2024
@schr476
Contributor

schr476 commented Apr 11, 2024

@ahmedmohammed107, please focus on the parser first.

@schr476 schr476 added the enhancement New feature or request label Apr 11, 2024