Added ability to save and load from huggingface #83
Conversation
…gs to work with this
@@ -1,5 +1,6 @@
# Ignores saved predictors
predictors/*/trained_models/
predictors/trained_models
Default cache dir for our huggingface models
if not (load_path / "config.json").exists() or \
    not (load_path / "model.pt").exists() or \
    not (load_path / "scaler.joblib").exists():
    raise FileNotFoundError("Model files not found in path.")
Check that all the files were downloaded properly before loading.
hf_args["local_dir"] = local_dir
snapshot_download(repo_id=path_or_url, **hf_args)

return cls.load(Path(local_dir))
Implementation of from_pretrained:
- We check disk for the model.
- If it doesn't exist, we download it from the hub. The local save dir defaults to predictors/trained_models.
- We load the model from the local files.
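The check-then-download flow described above might look roughly like the sketch below. The names `CACHE_DIR`, `REQUIRED_FILES`, `is_cached`, and `resolve_model_dir` are illustrative assumptions, not the PR's actual identifiers; only `snapshot_download` is the real huggingface_hub call the diff uses.

```python
from pathlib import Path

# Hypothetical constants for illustration; the PR's real defaults may differ.
CACHE_DIR = Path("predictors/trained_models")
REQUIRED_FILES = ("config.json", "model.pt", "scaler.joblib")


def is_cached(local_dir: Path) -> bool:
    """True only if every required model file is already on disk."""
    return all((local_dir / name).exists() for name in REQUIRED_FILES)


def resolve_model_dir(repo_id: str, **hf_args) -> Path:
    """Return a local dir holding the model, downloading on a cache miss."""
    local_dir = CACHE_DIR / repo_id.replace("/", "--")
    if not is_cached(local_dir):
        # Cache miss: pull the full snapshot from the Hugging Face Hub.
        from huggingface_hub import snapshot_download
        snapshot_download(repo_id=repo_id, local_dir=str(local_dir), **hf_args)
    return local_dir  # the real method would then call cls.load(local_dir)
```

Keeping the existence check separate from the download makes the cache logic trivially testable without network access.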
self.features = model_config.get("features", None)
self.label = model_config.get("label", None)

self.config = model_config
Some config refactoring so we can save our training arguments for reproducibility
with open(save_path / "config.json", "w", encoding="utf-8") as file:
    json.dump(self.config, file)
We now dump all the arguments used to create the model, instead of just the ones used at inference time.
raise FileNotFoundError(f"Path {path} does not exist.")
if not (load_path / "config.json").exists() or not (load_path / "model.joblib").exists():
    raise FileNotFoundError("Model files not found in path.")
Check if all the files exist before we load
lr_config = {key: value for key, value in model_config.items() if key not in ["features", "label"]}
self.model = LinearRegression(**lr_config)
Copy instead of mutating the config so we don't remove features and label from our actual stored config.
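A minimal illustration of why this fix matters: filtering into a new dict leaves the stored config intact, whereas the old `pop` calls mutated it in place. The config values here are hypothetical.

```python
# Hypothetical config for illustration only.
model_config = {"features": ["x1", "x2"], "label": "y", "fit_intercept": True}

# New approach: build a filtered copy. model_config keeps its
# "features" and "label" keys for later saving.
lr_config = {key: value for key, value in model_config.items()
             if key not in ("features", "label")}

# Old approach (buggy): model_config.pop("features", None) and
# model_config.pop("label", None) would strip them from the stored
# config, breaking the reproducibility guarantee of save().
```

After this change, `lr_config` holds only the estimator kwargs while `model_config` is unchanged.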
Script to upload the model. Still have to create a README template for the models. Takes a token, since only specified users can push to the Project Resilience repo.
lgtm
For #82. New method from_pretrained is used to load predictors from Hugging Face. A separate script is used to save to Hugging Face; only those with the access token can push.