From 64e2c3943916b839ad9d45c3df63000513580b85 Mon Sep 17 00:00:00 2001 From: xin-huang Date: Tue, 26 Sep 2023 16:24:42 +0200 Subject: [PATCH] Update models.py --- sstar/models.py | 53 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/sstar/models.py b/sstar/models.py index b3fb78a..6975081 100644 --- a/sstar/models.py +++ b/sstar/models.py @@ -17,36 +17,53 @@ import statsmodels.api as sm import statsmodels.formula.api as smf import numpy as np +from abc import ABC, abstractmethod -def train_logistic_regression(train_df, model_file): +class Model(ABC): """ - Description: - Function for training of the statsmodels logistic classification. - - Arguments: - train_df pandas.DataFrame: Training data - save_filename str: filename for output model """ - sm_data_exog = train_df.copy() - sm_data_exog.drop(['label'], axis=1, inplace=True) - sm_data_exog = sm.add_constant(sm_data_exog, prepend=False) + @abstractmethod + def train(self): + pass + + + @abstractmethod + def infer(self): + pass - sm_data_endog = train_df['label'] - glm_binom = sm.GLM(sm_data_endog.astype(int), sm_data_exog.astype(float),family=sm.families.Binomial()) - result = glm_binom.fit() +class LogisticRegression(Model): + """ + """ + def train(): + """ + Description: + Function for training of the statsmodels logistic classification. - result.save(model_file) + Arguments: + train_df pandas.DataFrame: Training data + save_filename str: filename for output model + """ + sm_data_exog = train_df.copy() + sm_data_exog.drop(['label'], axis=1, inplace=True) + sm_data_exog = sm.add_constant(sm_data_exog, prepend=False) + sm_data_endog = train_df['label'] -def infer_logistic_regression(test_df, model_file, output_file): - pass + glm_binom = sm.GLM(sm_data_endog.astype(int), sm_data_exog.astype(float),family=sm.families.Binomial()) + result = glm_binom.fit() + result.save(model_file) -def train_sstar(): + +class ExtraTrees(Model): + """ + """ pass -def train_extra_trees(): +class Sstar(Model): + """ + """ pass