Added ability to save and load from huggingface #83
Diff: neural network predictor (NeuralNetPredictor)
@@ -4,15 +4,14 @@
"""
import copy
import json
import time
from pathlib import Path
import time

import joblib
import numpy as np
import pandas as pd
import joblib
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
@@ -112,15 +111,20 @@ def __init__(self, model_config: dict):
    @classmethod
    def load(cls, path: str) -> "NeuralNetPredictor":
        """
        Loads a model from a given folder containing a config.json, model.pt, and scaler.joblib.
        Loads a model from a given folder.
        :param path: path to folder containing model files.
        """
        if isinstance(path, str):
            load_path = Path(path)
        else:
            load_path = path
        if not load_path.exists():
        if not load_path.exists() or not load_path.is_dir():
            raise FileNotFoundError(f"Path {path} does not exist.")
        if not (load_path / "config.json").exists() or \
           not (load_path / "model.pt").exists() or \
           not (load_path / "scaler.joblib").exists():
            raise FileNotFoundError("Model files not found in path.")
Review comment: Check to see all the files were downloaded properly before loading.
        # Initialize model with config
        with open(load_path / "config.json", "r", encoding="utf-8") as file:

@@ -135,7 +139,6 @@ def load(cls, path: str) -> "NeuralNetPredictor":
        nnp.scaler = joblib.load(load_path / "scaler.joblib")
        return nnp

    def save(self, path: str):
        """
        Saves model, config, and scaler into format for loading.
Diff: abstract Predictor base class
@@ -2,7 +2,9 @@
Abstract class for predictors to inherit from.
"""
from abc import ABC, abstractmethod
from pathlib import Path

from huggingface_hub import snapshot_download
import pandas as pd
@@ -25,7 +27,7 @@ def fit(self, X_train: pd.DataFrame, y_train: pd.Series):
        It is up to the model to decide which columns to use.
        :param y_train: series with target data
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, X_test: pd.DataFrame) -> pd.DataFrame:
@@ -36,19 +38,41 @@ def predict(self, X_test: pd.DataFrame) -> pd.DataFrame:
        :param X_test: DataFrame with input data
        :return: DataFrame with predictions
        """
        raise NotImplementedError

    @abstractmethod
    def save(self, path: str):
        """
        Saves the model to a path.
        Saves the model to a local path.
        :param path: path to save the model
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def load(cls, path: str) -> "Predictor":
    def from_pretrained(cls, path_or_url: str, **hf_args) -> "Predictor":
        """
        Loads a model from a path or if it is not found, from a huggingface repo.
        :param path_or_url: path to the model or url to the huggingface repo.
        :param hf_args: arguments to pass to the snapshot_download function from huggingface.
        """
        path = Path(path_or_url)
        if path.exists() and path.is_dir():
            return cls.load(path)
        else:
            # TODO: Need a try except block to catch download errors
            url_path = path_or_url.replace("/", "--")
            local_dir = hf_args.get("local_dir", f"predictors/trained_models/{url_path}")

            if not Path(local_dir).exists() or not Path(local_dir).is_dir():
                hf_args["local_dir"] = local_dir
                snapshot_download(repo_id=path_or_url, **hf_args)

            return cls.load(Path(local_dir))
Review comment: Implementation of from_pretrained.
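A rough usage sketch of the new entry point (the repo id and folder names are placeholders, not part of this PR):

# Hypothetical usage of from_pretrained with one of the concrete predictors added in this PR.
# An existing local folder is loaded directly from disk...
model = RandomForestPredictor.from_pretrained("predictors/trained_models/my_model")
# ...while a repo id is fetched with snapshot_download into
# predictors/trained_models/<org>--<repo> (unless local_dir is passed via hf_args) and then loaded.
model = RandomForestPredictor.from_pretrained("some-org/example-model")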
    @classmethod
    def load(cls, path: Path) -> "Predictor":
        """
        Loads a model from a path.
        Loads a model from the path on disk.
        :param path: path to the model
        """
        raise NotImplementedError
Diff: scikit-learn predictors (SKLearnPredictor, LinearRegressionPredictor, RandomForestPredictor)
@@ -24,10 +24,9 @@ def __init__(self, model_config: dict):
        Model config contains the following:
        features: list of features to use for prediction (optional, defaults to all features)
        label: name of the label to predict (optional, defaults to passed label during fit)
        Any other parameters are passed to the model.
        """
        self.features = model_config.get("features", None)
        self.label = model_config.get("label", None)
        self.config = model_config
Review comment: Some config refactoring so we can save our training arguments for reproducibility.
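For instance (values are illustrative, not from the PR), model hyperparameters now ride along in the same dict as features and label, using one of the subclasses defined later in this diff:

# Illustrative config: inference-time keys plus training arguments, all kept on self.config.
config = {
    "features": ["feature_a", "feature_b"],  # hypothetical column names
    "label": "target",                       # hypothetical label name
    "n_estimators": 100,                     # passed through to RandomForestRegressor
    "max_depth": 8
}
predictor = RandomForestPredictor(config)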
        self.model = None

    def save(self, path: str):

@@ -41,21 +40,22 @@ def save(self, path: str):
        else:
            save_path = path
        save_path.mkdir(parents=True, exist_ok=True)
        config = {
            "features": self.features,
            "label": self.label
        }
        with open(save_path / "config.json", "w", encoding="utf-8") as file:
            json.dump(config, file)
            json.dump(self.config, file)
Review comment: We now dump all the arguments we used to create the model instead of just the ones we use at inference time.
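A save/load round trip therefore preserves the full training configuration, e.g. (paths and values are illustrative):

# Sketch of the round trip: config.json now holds everything passed at construction.
predictor = RandomForestPredictor({"features": ["a", "b"], "label": "y", "n_estimators": 100})
predictor.fit(X_train, y_train)                       # X_train/y_train assumed to exist
predictor.save("predictors/trained_models/example")
reloaded = RandomForestPredictor.load("predictors/trained_models/example")
print(reloaded.config)  # {'features': ['a', 'b'], 'label': 'y', 'n_estimators': 100}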
        joblib.dump(self.model, save_path / "model.joblib")

    @classmethod
    def load(cls, path) -> "SKLearnPredictor":
        """
        Loads saved model and features from a folder.
        Loads saved model and config from a local folder.
        :param path: path to folder to load model files from.
        """
        load_path = Path(path)
        if not load_path.exists() or not load_path.is_dir():
            raise FileNotFoundError(f"Path {path} does not exist.")
        if not (load_path / "config.json").exists() or not (load_path / "model.joblib").exists():
            raise FileNotFoundError("Model files not found in path.")
Review comment: Check if all the files exist before we load.
        with open(load_path / "config.json", "r", encoding="utf-8") as file:
            config = json.load(file)
        sklearn_predictor = cls(config)

@@ -69,11 +69,11 @@ def fit(self, X_train: pd.DataFrame, y_train: pd.Series):
        :param X_train: DataFrame with input data
        :param y_train: series with target data
        """
        if self.features:
            X_train = X_train[self.features]
        if "features" in self.config:
            X_train = X_train[self.config["features"]]
        else:
            self.features = list(X_train.columns)
            self.label = y_train.name
            self.config["features"] = list(X_train.columns)
            self.config["label"] = y_train.name
        self.model.fit(X_train, y_train)

    def predict(self, X_test: pd.DataFrame) -> pd.DataFrame:
@@ -83,21 +83,21 @@ def predict(self, X_test: pd.DataFrame) -> pd.DataFrame:
        :param X_test: DataFrame with input data
        :return: properly labeled DataFrame with predictions and matching index.
        """
        if self.features:
            X_test = X_test[self.features]
        X_test = X_test[self.config["features"]]
        y_pred = self.model.predict(X_test)
        return pd.DataFrame(y_pred, index=X_test.index, columns=[self.label])
        return pd.DataFrame(y_pred, index=X_test.index, columns=[self.config["label"]])


class LinearRegressionPredictor(SKLearnPredictor):
    """
    Simple linear regression predictor.
    See SKLearnPredictor for more details.
    """
    def __init__(self, model_config: dict):
        if not model_config:
            model_config = {}
        super().__init__(model_config)
        model_config.pop("features", None)
        model_config.pop("label", None)
        self.model = LinearRegression(**model_config)
        lr_config = {key: value for key, value in model_config.items() if key not in ["features", "label"]}
        self.model = LinearRegression(**lr_config)
Review comment: Copy instead of referencing config so we don't remove features and label from our actual stored config.
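A minimal standalone illustration of why the copy matters (not code from the PR): popping keys off the passed-in dict would also strip them from self.config, since both names point at the same object.

# The old pop() approach mutates the shared dict; the comprehension leaves it intact.
stored = {"features": ["a"], "label": "y", "fit_intercept": True}   # stands in for self.config
same_object = stored
same_object.pop("features")          # old style: "features" disappears from stored too
assert "features" not in stored

filtered = {k: v for k, v in stored.items() if k not in ["features", "label"]}
assert "label" in stored             # new style: the stored config keeps its keys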
class RandomForestPredictor(SKLearnPredictor):
    """

@@ -107,9 +107,8 @@ class RandomForestPredictor(SKLearnPredictor):
    """
    def __init__(self, model_config: dict):
        super().__init__(model_config)
        model_config.pop("features", None)
        model_config.pop("label", None)
        self.model = RandomForestRegressor(**model_config)
        rf_config = {key: value for key, value in model_config.items() if key not in ["features", "label"]}
        self.model = RandomForestRegressor(**rf_config)

    def save(self, path: str, compression=0):
        """

@@ -118,11 +117,7 @@ def save(self, path: str, compression=0):
        """
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)
        config = {
            "features": self.features,
            "label": self.label
        }
        with open(save_path / "config.json", "w", encoding="utf-8") as file:
            json.dump(config, file)
            json.dump(self.config, file)
        joblib.dump(self.model, save_path / "model.joblib", compress=compression)
Review comment: Script to upload a model. Still have to create a readme template for the models. Takes in a token since only specified users can push to the Project Resilience repo.
Diff: Hugging Face upload script (new file)
@@ -0,0 +1,49 @@
""" | ||
Script to upload a model to huggingface hub. | ||
""" | ||
from argparse import ArgumentParser | ||
from pathlib import Path | ||
|
||
from huggingface_hub import HfApi | ||
|
||
def write_readme(model_path: str): | ||
""" | ||
Writes readme to model save path to upload. | ||
TODO: Need to add more info to the readme and make it a proper template. | ||
""" | ||
model_path = Path(model_path) | ||
with open(model_path / "README.md", "w", encoding="utf-8") as f: | ||
f.write("This is a demo model created for project resilience") | ||
|
||
def upload_to_repo(model_path: str, repo_id: str, token: str=None): | ||
""" | ||
Uses huggingface hub to upload the model to a repo. | ||
""" | ||
model_path = Path(model_path) | ||
api = HfApi() | ||
api.create_repo( | ||
repo_id=repo_id, | ||
repo_type="model", | ||
exist_ok=True, | ||
token=token | ||
) | ||
|
||
api.upload_folder( | ||
folder_path=model_path, | ||
repo_id=repo_id, | ||
repo_type="model", | ||
token=token | ||
) | ||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser() | ||
parser.add_argument("--model_path", type=str, required=True) | ||
parser.add_argument("--repo_id", type=str, required=True) | ||
parser.add_argument("--token", type=str, required=False) | ||
args = parser.parse_args() | ||
|
||
write_readme(args.model_path) | ||
upload_args = {"model_path": args.model_path, "repo_id": args.repo_id} | ||
if args.token: | ||
upload_args["token"] = args.token | ||
upload_to_repo(**upload_args) |
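As a usage sketch, running the script is equivalent to calling its two helpers directly (repo id, folder, and token are placeholders):

# Hypothetical direct use of the upload helpers defined above.
write_readme("predictors/trained_models/example")
upload_to_repo(
    model_path="predictors/trained_models/example",
    repo_id="some-org/example-model",
    token="hf_..."   # omit to fall back to locally cached Hugging Face credentials
)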
Review comment: Default cache dir for our huggingface models.
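That is, from_pretrained maps a repo id onto the cache directory roughly like this (repo id is illustrative):

# How a repo id becomes its default local cache folder (see from_pretrained above).
repo_id = "some-org/example-model"
local_dir = f"predictors/trained_models/{repo_id.replace('/', '--')}"
# -> "predictors/trained_models/some-org--example-model"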