Skip to content

Commit

Permalink
backup
Browse files Browse the repository at this point in the history
  • Loading branch information
rileydrizzy committed Jan 4, 2024
1 parent b256c0b commit 87823c8
Show file tree
Hide file tree
Showing 12 changed files with 167 additions and 160 deletions.
5 changes: 2 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,8 @@ mlruns
logs
Tensorbord_logs

#
model_checkpoints
model_artifact
#model_artifact
artifacts

# Environment variables and keys
environ_variables.sh
Binary file not shown.
Binary file removed development/clean_dev.gzip
Binary file not shown.
Binary file removed development/dev.gzip
Binary file not shown.
125 changes: 49 additions & 76 deletions development/new_dev.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions environ_variables_temp.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
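#!/bin/sh
# Replace the placeholders below and rename this file to `environ_variables.sh`.
# Usage: ./environ_variables.sh  (NOTE: the exports only persist in your shell if the file is sourced, e.g. `. ./environ_variables.sh`)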
# Set Kaggle Username and Key
export KAGGLE_USERNAME=username
export KAGGLE_KEY=xxxxxxxxxxxxxx

# Set MLFLOW_TRACKING_URI, MLFLOW_TRACKING_USERNAME and MLFLOW_TRACKING_PASSWORD
export MLFLOW_TRACKING_URI=xxxxxxxxxxxxxx
export MLFLOW_TRACKING_USERNAME=xxxxxxxxxxxxxx
export MLFLOW_TRACKING_PASSWORD=xxxxxxxxxxxxxx
8 changes: 7 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ hydra-core==1.3.2
identify==2.5.33
idna==3.6
importlib-metadata==7.0.1
iniconfig==2.0.0
iterative-telemetry==0.0.8
itsdangerous==2.1.2
Jinja2==3.1.2
Expand Down Expand Up @@ -116,9 +117,12 @@ pathspec==0.12.1
pathvalidate==3.0.0
Pillow==10.1.0
platformdirs==3.11.0
pluggy==1.3.0
polars==0.20.2
pre-commit==3.6.0
protobuf>=4.20.1
prompt-toolkit==3.0.43
protobuf==4.23.4
psutil==5.9.7
pyarrow==14.0.2
pyasn1==0.5.1
pyasn1-modules==0.3.0
Expand All @@ -131,6 +135,7 @@ Pygments==2.17.2
pygtrie==2.5.0
PyJWT==2.8.0
pyparsing==3.1.1
pytest==7.4.4
python-dateutil==2.8.2
python-slugify==8.0.1
pytz==2023.3.post1
Expand Down Expand Up @@ -178,6 +183,7 @@ urllib3==1.26.18
vine==5.1.0
virtualenv==20.25.0
voluptuous==0.14.1
wcwidth==0.2.12
webencodings==0.5.1
websocket-client==1.7.0
Werkzeug==3.0.1
Expand Down
5 changes: 4 additions & 1 deletion src/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,7 @@ params:
cm_threshold: 0.5

model_name:
"1DCNN"
"11DCNN"

save_to_mlflow:
True
77 changes: 47 additions & 30 deletions src/evalute.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,71 @@
"""doc
"""Evaluation Script for Trained Models
This script performs evaluation on a trained TensorFlow model using a test dataset.\
It retrieves the model either from local storage or MLflow based on the configuration.
Functions:
- retrieve_trained_model(model_name: str, from_mlflow: bool = False) -> tf.keras.Model:\
Retrieves a trained model either from local storage or MLflow.
- main(cfg: DictConfig): Main function for the evaluation script.
Parameters:
- `cfg` (DictConfig): Configuration parameters loaded using Hydra.
Notes:
- The script evaluates the Precision-Recall Curve and Confusion Matrix on a trained model\
using a test dataset.
- It uses MLflow for model tracking if specified in the configuration.
Example:
```bash
python src/evalute.py
```
"""

import hydra
import mlflow
import tensorflow as tf
from omegaconf import DictConfig


from dataset_loader import get_dataset
from utils.common_utils import (
set_seed,
plot_confusion_matrix,
plot_precision_recall_curve,
)
from utils.common_utils import (plot_confusion_matrix,
plot_precision_recall_curve, set_seed)
from utils.logging import logger

# Directory to save plots.
plots_dir = "docs/plots/evaluation"


def retrieve_trained_model(model_name, from_mlflow=False):
    """Retrieve a trained model either from local storage or from MLflow.

    Parameters:
    - `model_name` (str): Name of the model. Used both as the registered
      model name in MLflow and as the folder name under `artifacts/`.
    - `from_mlflow` (bool): When True, load the model from the MLflow model
      registry; otherwise load the local checkpoint. Default is False.

    Returns:
    - The object returned by `mlflow.keras.load_model` or
      `tf.keras.models.load_model` (expected to be a `tf.keras.Model`).

    Raises:
    - `IndexError`: If `from_mlflow` is True and the registry returns no
      versions for `model_name`.
    """
    if not from_mlflow:
        # Local path mirrors the checkpoint location used at training time:
        # artifacts/<model_name>/model_checkpoints
        model_checkpoint = f"artifacts/{model_name}/model_checkpoints"
        return tf.keras.models.load_model(model_checkpoint)

    client = mlflow.MlflowClient()
    latest_versions = client.get_latest_versions(name=model_name)
    if not latest_versions:
        # Keep the exception type the bare `[0]` lookup used to raise, but
        # say what actually went wrong instead of "list index out of range".
        raise IndexError(
            f"No registered versions found in MLflow for model '{model_name}'"
        )

    # NOTE(review): the first entry is used as-is; if the registry holds
    # versions in several stages this may not be the newest one - confirm.
    version = latest_versions[0].version
    model_uri = f"models:/{model_name}/{version}"
    return mlflow.keras.load_model(model_uri)


# Backward-compatible alias for the previous (misspelled) function name.
retrive_trained_model = retrieve_trained_model


@hydra.main(config_name="config", config_path="config", version_base="1.2")
def main(cfg: DictConfig):
"""
Parameters
----------
cfg : DictConfig
_description_
Main function for the evaluation script.
Parameters:
- `cfg` (DictConfig): Configuration parameters loaded using Hydra.
"""
try:
logger.info("Commencing evaluation process with test dataset")
logger.info("Commencing evaluation process with the test dataset")
set_seed()

f1_score_metrics = tf.keras.metrics.F1Score(
Expand All @@ -63,30 +79,31 @@ def main(cfg: DictConfig):
)

logger.info(f"Retrieving the model: {cfg.model_name}")
model = retrive_trained_model(cfg.model_name)
model = retrieve_trained_model(cfg.model_name, from_mlflow=False)

logger.info(
f"Starting evaluation of Precision-Recall Curve on trained {cfg.model_name}"
)
plot_precision_recall_curve(
model, eval_dataset=test_data, save_path=f"{plots_dir}/{cfg.model_name},"
model, model_name=cfg.model_name, eval_dataset=test_data, plot_label="Test"
)
logger.info("Precision-Recall Curve evaluation completed.")

logger.info(
f"Starting evaluation of Confusion Matrix on trained {cfg.model_name}"
f"Starting evaluation of the Confusion Matrix on trained {cfg.model_name}"
)
plot_confusion_matrix(
model,
model_name=cfg.model_name,
eval_dataset=test_data,
threshold=cfg.params.cm_threshold,
save_path=f"",
plot_label="Test",
)
logger.info("Confusion Matrix evaluation completed.")

except ValueError:
logger.error(
"Plotting of Confusion Matrix and Precision-Recall Curve failed due to empty data"
"Plotting of the Confusion Matrix and Precision-Recall Curve failed due to empty data"
)
except Exception as error:
logger.exception(f"Evaluation failed due to -> {error}.")
Expand Down
57 changes: 26 additions & 31 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
python src/main.py
```
"""
# TODO Add F1_Score


import hydra
Expand All @@ -33,32 +34,11 @@
# Importing functions and classes from other modules
from dataset_loader import get_dataset, get_vectorization_layer
from models.model_loader import ModelLoader
from utils.common_utils import (
get_device_strategy,
set_mlflow_tracking,
set_seed,
tensorboard_dir,
plot_confusion_matrix,
plot_precision_recall_curve,
)
from utils.common_utils import (get_device_strategy, plot_confusion_matrix,
plot_precision_recall_curve,
set_mlflow_tracking, set_seed, tensorboard_dir)
from utils.logging import logger

# Callbacks for model training
checkpoints_cb = tf.keras.callbacks.ModelCheckpoint(
"model_checkpoints",
save_best_only=True,
)
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
patience=10, restore_best_weights=True
)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
monitor="val_loss", factor=0.2, patience=5, verbose=1
)
callbacks_list = [checkpoints_cb, early_stopping_cb, reduce_lr]

# Directory to save plots.
PLOTS_DIR = "docs/plots/training"


@hydra.main(config_name="config", config_path="config", version_base="1.2")
def main(cfg: DictConfig):
Expand All @@ -78,6 +58,19 @@ def main(cfg: DictConfig):

# Set up MLflow tracking for the experiment
experiment_id = set_mlflow_tracking(cfg.model_name)
# Callbacks for model training
checkpoint_path = f"artifacts/{cfg.model_name}/model_checkpoints"

checkpoints_cb = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_best_only=True
)
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
patience=10, restore_best_weights=True
)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
monitor="val_loss", factor=0.5, patience=5, verbose=1
)
callbacks_list = [checkpoints_cb, early_stopping_cb, reduce_lr]

# Set up TensorBoard logging directory
model_tensorb_dir = tensorboard_dir(cfg.model_name)
Expand Down Expand Up @@ -129,7 +122,6 @@ def main(cfg: DictConfig):
model.compile(
loss=loss_func,
optimizer=optim,
metrics=[f1_score_metrics()],
)
logger.info(
f" Training {cfg.model_name} for {cfg.params.total_epochs} epochs"
Expand All @@ -155,8 +147,10 @@ def main(cfg: DictConfig):
)
plot_precision_recall_curve(
model,
model_name=cfg.model_name,
eval_dataset=valid_data,
save_path=f"{PLOTS_DIR}/{cfg.model_name}_PR.png",
save_path=True,
save_to_mlflow=cfg.save_to_mlflow,
)
logger.info("Precision-Recall Curve evaluation completed.")

Expand All @@ -165,15 +159,16 @@ def main(cfg: DictConfig):
)
plot_confusion_matrix(
model,
model_name=cfg.model_name,
eval_dataset=valid_data,
threshold=cfg.params.cm_threshold,
save_path=f"{PLOTS_DIR}/{cfg.model_name}_CM.png",
save_path=True,
save_to_mlflow=cfg.save_to_mlflow,
)
logger.info("Confusion Matrix evaluation completed.")
except ValueError:
logger.critical(
"Plotting of Confusion Matrix and Precision-Recall Curve failed due to empty data"
)
logger.success("All jobs completed")
except mlflow.exceptions.MlflowException:
logger.exception("Erorr due to Mlflow login details")

except Exception as error:
logger.exception(f"Training failed due to -> {error}.")
Expand Down
6 changes: 3 additions & 3 deletions src/models/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
"""

from models import baseline_cnn #, temporal_cn
from models import baseline_cnn # , temporal_cn


class ModelLoader:
"""Class for Loading TensorFlow Models"""

def __init__(self):
self.models = {"1DCNN": baseline_cnn.build_model}
self.models = {"11DCNN": baseline_cnn.build_model}

def get_model(self, model_name: str) -> object:
"""Build and Retrieve a TensorFlow Model Instance.
Expand All @@ -48,7 +48,7 @@ def get_model(self, model_name: str) -> object:
- ValueError: If the specified model is not in the model list.
"""
if model_name in self.models:
return self.models[model_name]()
return self.models[model_name]
raise ValueError(
f"Model '{model_name}' is not in the model list. Available models:\
{list(self.models.keys())}"
Expand Down
Loading

0 comments on commit 87823c8

Please sign in to comment.