Skip to content

Commit

Permalink
Update
Browse files: browse the repository at this point in the history
  • Loading branch information
xin-huang committed Jan 8, 2024
1 parent 908d921 commit 9359ec3
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 10 deletions.
3 changes: 1 addition & 2 deletions sstar/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _run_training(args):

def _run_inference(args):
from sstar.infer import infer
infer(feature_file=args.features, model_file=args.model_file, cutoff=args.cutoff,
infer(feature_file=args.features, model_file=args.model_file,
prediction_dir=args.prediction_dir, prediction_prefix=args.prediction_prefix, algorithm=args.model)


Expand Down Expand Up @@ -249,7 +249,6 @@ def _s_star_cli_parser():
parser.add_argument('--features', type=str, required=True, help="Name of the file storing input features.")
parser.add_argument('--model-file', type=str, required=True, help="Name of the file storing the trained model.", dest='model_file')
parser.add_argument('--model', type=str, default=None, help="Name of the statistical/machine learning model for the training. Implemented models: extra_trees, logistic_regression, sstar.")
parser.add_argument('--cutoff', type=float, default=0.5, help="Probability cutoff for classifying a fragment as introgressed. Fragments with probabilities above this threshold will be considered introgressed. Default: 0.5.")
parser.add_argument('--prediction-prefix', type=str, required=True, help="Prefix of the prediction file name.", dest='prediction_prefix')
parser.add_argument('--prediction-dir', type=str, required=True, help="Directory of the prediction files.", dest='prediction_dir')
parser.set_defaults(runner=_run_inference)
Expand Down
35 changes: 30 additions & 5 deletions sstar/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def evaluate(truth_tract_file, inferred_tract_file, output):
truth_tracts_samples = truth_tracts['sample'].unique()
inferred_tracts_samples = inferred_tracts['sample'].unique()

res = pd.DataFrame(columns=['sample', 'precision', 'recall'])
res = pd.DataFrame(columns=['sample', 'precision', 'recall', 'true_positive', 'inferred_tracts_length', 'truth_tracts_length'])

sum_ntruth_tracts = 0
sum_ninferred_tracts = 0
sum_ntrue_positives = 0

for s in np.intersect1d(truth_tracts_samples, inferred_tracts_samples):
ind_truth_tracts = truth_tracts[truth_tracts['sample'] == s][['chrom', 'start', 'end']]
Expand All @@ -47,17 +51,38 @@ def evaluate(truth_tract_file, inferred_tract_file, output):

precision, recall = cal_pr(ntruth_tracts, ninferred_tracts, ntrue_positives)

res.loc[len(res.index)] = [s, precision, recall]
res.loc[len(res.index)] = [s, precision, recall, ntrue_positives, ninferred_tracts, ntruth_tracts]

sum_ntruth_tracts += ntruth_tracts
sum_ninferred_tracts += ninferred_tracts
sum_ntrue_positives += ntrue_positives

for s in np.setdiff1d(truth_tracts_samples, inferred_tracts_samples):
# ninferred_tracts = 0
res.loc[len(res.index)] = [s, np.nan, 0]
ind_truth_tracts = truth_tracts[truth_tracts['sample'] == s][['chrom', 'start', 'end']]
ind_truth_tracts = pybedtools.BedTool.from_dataframe(ind_truth_tracts).sort().merge()
ntruth_tracts = sum([x.stop - x.start for x in (ind_truth_tracts)])

res.loc[len(res.index)] = [s, 'NA', 0, 0, 0, ntruth_tracts]

sum_ntruth_tracts += ntruth_tracts

for s in np.setdiff1d(inferred_tracts_samples, truth_tracts_samples):
# ntruth_tracts = 0
res.loc[len(res.index)] = [s, 0, np.nan]
ind_inferred_tracts = inferred_tracts[inferred_tracts['sample'] == s][['chrom', 'start', 'end']]
ind_inferred_tracts = pybedtools.BedTool.from_dataframe(ind_inferred_tracts).sort().merge()
ninferred_tracts = sum([x.stop - x.start for x in (ind_inferred_tracts)])

res.loc[len(res.index)] = [s, 0, 'NA', 0, ninferred_tracts, 0]

sum_ninferred_tracts += ninferred_tracts

res = res.sort_values(by=['sample'])

total_precision, total_recall = cal_pr(sum_ntruth_tracts, sum_ninferred_tracts, sum_ntrue_positives)
res.loc[len(res.index)] = ['Summary', total_precision, total_recall, sum_ntrue_positives, sum_ninferred_tracts, sum_ntruth_tracts]

res.sort_values(by=['sample']).to_csv(output, sep="\t", index=False)
res.to_csv(output, sep="\t", index=False)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions sstar/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
pd.options.mode.chained_assignment = None


def infer(feature_file, model_file, prediction_dir, prediction_prefix, cutoff, algorithm=None):
def infer(feature_file, model_file, prediction_dir, prediction_prefix, algorithm=None):
"""
"""
with open(model_file, 'rb') as f:
Expand Down Expand Up @@ -57,4 +57,4 @@ def infer(feature_file, model_file, prediction_dir, prediction_prefix, cutoff, a


if __name__ == '__main__':
infer(feature_file="/scratch/admixlab/xinhuang/projects/sstar2-analysis-dev/results/test_data/ArchIE_3D19/nref_50/ntgt_50/956714/0/sim.test.0.archie.features", model_file="/scratch/admixlab/xinhuang/projects/sstar2-analysis-dev/tmp/archie.imbalanced.logistic_regression.model", prediction_dir="./sstar/test", prediction_prefix="test.imbalanced.test", cutoff=0.5, algorithm="logistic_regression")
infer(feature_file="/scratch/admixlab/xinhuang/projects/sstar2-analysis-dev/results/test_data/ArchIE_3D19/nref_50/ntgt_50/956714/0/sim.test.0.archie.features", model_file="/scratch/admixlab/xinhuang/projects/sstar2-analysis-dev/tmp/archie.imbalanced.logistic_regression.model", prediction_dir="./sstar/test", prediction_prefix="test.imbalanced.test", algorithm="logistic_regression")
2 changes: 1 addition & 1 deletion sstar/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,4 @@ def _add_label(row, intro_prop, not_intro_prop):
simulate(demo_model_file="./examples/models/ArchIE_3D19.yaml", nrep=1, nref=50, ntgt=50,
ref_id='Ref', tgt_id='Tgt', src_id='Ghost', ploidy=2, seq_len=50000, mut_rate=1.25e-8, rec_rate=1e-8, thread=2,
feature_config=None, is_phased=True, intro_prop=0.7, not_intro_prop=0.3, keep_sim_data=True,
output_prefix='test', output_dir='./sstar/test7', seed=913)
output_prefix='test', output_dir='./sstar/test7', seed=555)

0 comments on commit 9359ec3

Please sign in to comment.