Skip to content

Commit

Permalink
Added config documentation and return type for classmethods
Browse files Browse the repository at this point in the history
  • Loading branch information
danyoungday committed May 14, 2024
1 parent 15a059d commit c4bde78
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
16 changes: 15 additions & 1 deletion use_cases/eluc/predictors/neural_network/neural_net_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ class NeuralNetPredictor(Predictor):
Data is automatically standardized and the scaler is saved with the model.
"""
def __init__(self, model_config: dict):
"""
Model config should contain the following:
features: list of features to use in the model (optional, defaults to all features)
label: name of the label column (optional, defaults to passed label in fit)
hidden_sizes: list of hidden layer sizes
linear_skip: whether to concatenate input to hidden layer output
dropout: dropout probability
device: device to run the model on
epochs: number of epochs to train for
batch_size: batch size for training
optim_params: dictionary of parameters to pass to the optimizer
train_pct: percentage of training data to use
step_lr_params: dictionary of parameters to pass to the step learning rate scheduler
"""

self.features = model_config.get("features", None)
self.label = model_config.get("label", None)
Expand All @@ -96,7 +110,7 @@ def __init__(self, model_config: dict):
self.scaler = StandardScaler()

@classmethod
def load(cls, path: str):
def load(cls, path: str) -> "NeuralNetPredictor":
"""
Loads a model from a given folder containing a config.json, model.pt, and scaler.joblib.
:param path: path to folder containing model files.
Expand Down
2 changes: 1 addition & 1 deletion use_cases/eluc/predictors/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def save(self, path: str):

@classmethod
@abstractmethod
def load(cls, path: str):
def load(cls, path: str) -> "Predictor":
"""
Loads a model from a path.
:param path: path to the model
Expand Down
7 changes: 6 additions & 1 deletion use_cases/eluc/predictors/sklearn/sklearn_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ class SKLearnPredictor(Predictor, ABC):
Keeps track of features fit on and label to predict.
"""
def __init__(self, model_config: dict):
"""
Model config contains the following:
features: list of features to use for prediction (optional, defaults to all features)
label: name of the label to predict (optional, defaults to passed label during fit)
"""
self.features = model_config.get("features", None)
self.label = model_config.get("label", None)

Expand All @@ -45,7 +50,7 @@ def save(self, path: str):
joblib.dump(self.model, save_path / "model.joblib")

@classmethod
def load(cls, path):
def load(cls, path) -> "SKLearnPredictor":
"""
Loads saved model and features from a folder.
:param path: path to folder to load model files from.
Expand Down

0 comments on commit c4bde78

Please sign in to comment.