Skip to content

Commit

Permalink
Update train.py and model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xin-huang committed Sep 22, 2023
1 parent bd9463a commit 961facd
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
17 changes: 12 additions & 5 deletions sstar/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np


def train_logistic_regression(df, model_file):
def train_logistic_regression(train_df, model_file):
"""
Description:
Function for training of the statsmodels logistic classification.
Expand All @@ -29,13 +29,20 @@ def train_logistic_regression(df, model_file):
save_filename str: filename for output model
"""
sm_data_exog = train_df.copy()
sm_data_exog.drop(["label"], axis=1, inplace=True)
sm_data_exog.replace(np.nan, 0, inplace=True)
sm_data_exog.drop(['label'], axis=1, inplace=True)
sm_data_exog = sm.add_constant(sm_data_exog, prepend=False)

sm_data_endog = train_df["label"]
sm_data_endog = train_df['label']

glm_binom = sm.GLM(sm_data_endog.astype(int), sm_data_exog.astype(float),family=sm.families.Binomial())
result = glm_binom.fit()

result.save(save_filename)
result.save(model_file)


def train_sstar():
pass


def train_extra_trees():
pass
9 changes: 7 additions & 2 deletions sstar/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pandas as pd
from multiprocessing import Process, Queue
from sstar.preprocess import process_data
from sstar.models import train_logistic_regression


def train(demo_model_file, nrep, nref, ntgt, ref_id, tgt_id, src_id, seq_len, mut_rate, rec_rate, thread, output_prefix, output_dir, algorithm=None, seed=None):
Expand Down Expand Up @@ -57,8 +58,12 @@ def _train_logistic_regression(nrep, thread, output_prefix, output_dir, seq_len,
df = pd.read_csv(feature_file, sep="\t")
feature_df = pd.concat([feature_df, df])

#feature_df = feature_df.drop(columns=['chrom', 'start', 'end', 'sample', 'hap'])
feature_df.to_csv(output_dir + '/' + output_prefix + '.all.features', sep="\t", index=False)
all_feature_file = output_dir + '/' + output_prefix + '.all.features'
model_file = output_dir + '/' + output_prefix + '.logistic.regression.model'

feature_df.to_csv(all_feature_file, sep="\t", index=False)
feature_df = feature_df.drop(columns=['chrom', 'start', 'end', 'sample', 'hap'])
train_logistic_regression(feature_df, model_file)


def _train_extra_trees():
Expand Down

0 comments on commit 961facd

Please sign in to comment.