From 961facd23a71b15338b739da12df5199a98ded9e Mon Sep 17 00:00:00 2001 From: xin-huang Date: Fri, 22 Sep 2023 21:07:07 +0200 Subject: [PATCH] Update train.py and model.py --- sstar/models.py | 17 ++++++++++++----- sstar/train.py | 9 +++++++-- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/sstar/models.py b/sstar/models.py index ac78f95..7c1209f 100644 --- a/sstar/models.py +++ b/sstar/models.py @@ -19,7 +19,7 @@ import numpy as np -def train_logistic_regression(df, model_file): +def train_logistic_regression(train_df, model_file): """ Description: Function for training of the statsmodels logistic classification. @@ -29,13 +29,20 @@ def train_logistic_regression(df, model_file): save_filename str: filename for output model """ sm_data_exog = train_df.copy() - sm_data_exog.drop(["label"], axis=1, inplace=True) - sm_data_exog.replace(np.nan, 0, inplace=True) + sm_data_exog.drop(['label'], axis=1, inplace=True) sm_data_exog = sm.add_constant(sm_data_exog, prepend=False) - sm_data_endog = train_df["label"] + sm_data_endog = train_df['label'] glm_binom = sm.GLM(sm_data_endog.astype(int), sm_data_exog.astype(float),family=sm.families.Binomial()) result = glm_binom.fit() - result.save(save_filename) + result.save(model_file) + + +def train_sstar(): + pass + + +def train_extra_trees(): + pass diff --git a/sstar/train.py b/sstar/train.py index 62d332b..a51e68f 100644 --- a/sstar/train.py +++ b/sstar/train.py @@ -19,6 +19,7 @@ import pandas as pd from multiprocessing import Process, Queue from sstar.preprocess import process_data +from sstar.models import train_logistic_regression def train(demo_model_file, nrep, nref, ntgt, ref_id, tgt_id, src_id, seq_len, mut_rate, rec_rate, thread, output_prefix, output_dir, algorithm=None, seed=None): @@ -57,8 +58,12 @@ def _train_logistic_regression(nrep, thread, output_prefix, output_dir, seq_len, df = pd.read_csv(feature_file, sep="\t") feature_df = pd.concat([feature_df, df]) - #feature_df = feature_df.drop(columns=['chrom', 'start', 'end', 'sample', 'hap']) - feature_df.to_csv(output_dir + '/' + output_prefix + '.all.features', sep="\t", index=False) + all_feature_file = output_dir + '/' + output_prefix + '.all.features' + model_file = output_dir + '/' + output_prefix + '.logistic.regression.model' + + feature_df.to_csv(all_feature_file, sep="\t", index=False) + feature_df = feature_df.drop(columns=['chrom', 'start', 'end', 'sample', 'hap']) + train_logistic_regression(feature_df, model_file) def _train_extra_trees():