diff --git a/src/create_train_test.py b/src/create_train_test.py index 84a45f6..495c5c0 100644 --- a/src/create_train_test.py +++ b/src/create_train_test.py @@ -2,6 +2,7 @@ import os import polars as pl + from utils.functions import load_pickle parser = argparse.ArgumentParser(description="Create train/val/test split.") diff --git a/src/datasets.py b/src/datasets.py index 5e365e4..14c4f3f 100644 --- a/src/datasets.py +++ b/src/datasets.py @@ -4,6 +4,7 @@ import torch from sklearn.model_selection import train_test_split from torch.utils.data import Dataset + from utils.functions import load_pickle, preview_data diff --git a/src/evaluate.py b/src/evaluate.py index f3b4ac8..06d242c 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -7,7 +7,6 @@ import polars as pl import shap import toml -from datasets import CollateTimeSeries, MIMIC4Dataset from fairlearn.metrics import ( MetricFrame, count, @@ -19,7 +18,6 @@ selection_rate, ) from lightning.pytorch import Trainer -from models import MMModel from sklearn.metrics import ( accuracy_score, average_precision_score, @@ -28,6 +26,9 @@ ) from torch import concat from torch.utils.data import DataLoader + +from datasets import CollateTimeSeries, MIMIC4Dataset +from models import MMModel from utils.functions import load_pickle, read_from_txt from utils.preprocessing import transform_race diff --git a/src/postprocess.py b/src/postprocess.py index ce2712d..f47ef24 100644 --- a/src/postprocess.py +++ b/src/postprocess.py @@ -5,13 +5,14 @@ import numpy as np import polars as pl import toml -from datasets import MIMIC4Dataset from fairlearn.postprocessing import ThresholdOptimizer, plot_threshold_optimizer from sklearn.metrics import ( accuracy_score, balanced_accuracy_score, confusion_matrix, ) + +from datasets import MIMIC4Dataset from utils.functions import load_pickle, read_from_txt if __name__ == "__main__": diff --git a/src/prepare_data.py b/src/prepare_data.py index 5558a05..86ee39c 100644 --- a/src/prepare_data.py +++ b/src/prepare_data.py @@ -6,6 +6,7 @@ import polars as pl from tqdm import tqdm + from utils.functions import scale_numeric_features from utils.preprocessing import ( add_time_elapsed_to_events, diff --git a/src/train.py b/src/train.py index ca03dc5..0cdd021 100644 --- a/src/train.py +++ b/src/train.py @@ -2,15 +2,16 @@ import lightning as L import toml -from datasets import CollateFn, CollateTimeSeries, MIMIC4Dataset from lightning.pytorch.callbacks import ( EarlyStopping, LearningRateMonitor, ModelCheckpoint, ) from lightning.pytorch.loggers import CSVLogger, WandbLogger -from models import MMModel from torch.utils.data import DataLoader + +from datasets import CollateFn, CollateTimeSeries, MIMIC4Dataset +from models import MMModel from utils.functions import read_from_txt if __name__ == "__main__": diff --git a/src/train_rf.py b/src/train_rf.py index 896d0b0..8440d6a 100644 --- a/src/train_rf.py +++ b/src/train_rf.py @@ -4,7 +4,6 @@ import numpy as np import toml -from datasets import MIMIC4Dataset from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import ( accuracy_score, @@ -13,6 +12,8 @@ roc_auc_score, ) from sklearn.model_selection import GridSearchCV + +from datasets import MIMIC4Dataset from utils.functions import read_from_txt if __name__ == "__main__":