From 9b00f5141b4a531d02c51060b0a47b5dbbdea1f3 Mon Sep 17 00:00:00 2001
From: Kushagra Taneja
Date: Tue, 29 Oct 2024 19:01:23 +0530
Subject: [PATCH] Add model_details function

---
 .../notebooks/Business_forecasting.py         | 40 ++++++++++++++++++++
 .../predict.py                                | 18 +++++++++-
 .../saved_models/evaluation_results.pkl       | Bin 0 -> 93 bytes
 page_handler.py                               | 17 +++++----
 4 files changed, 66 insertions(+), 9 deletions(-)
 create mode 100644 models/business_performance_forecasting/saved_models/evaluation_results.pkl

diff --git a/models/business_performance_forecasting/notebooks/Business_forecasting.py b/models/business_performance_forecasting/notebooks/Business_forecasting.py
index ae052b4c..36bb92c5 100644
--- a/models/business_performance_forecasting/notebooks/Business_forecasting.py
+++ b/models/business_performance_forecasting/notebooks/Business_forecasting.py
@@ -2,7 +2,9 @@
 import pandas as pd
 import pickle
 import os
+import matplotlib.pyplot as plt
+from sklearn.metrics import r2_score
 
 # Load the data
 df = pd.read_csv('50_Startups.csv')
 X = df.iloc[:, :-1].values
@@ -41,3 +43,41 @@
     pickle.dump(ct, scaler_file)
 
 print("Model and preprocessing objects saved successfully!")
+
+def save_evaluation_to_pickle(train_X, train_Y, test_X, test_Y, output_file="evaluation_results.pkl"):
+    # Recompute test predictions here instead of relying on the global y_pred
+    y_pred = model.predict(test_X)
+    # Calculate R^2 scores
+    train_r2 = r2_score(train_Y, model.predict(train_X))
+    test_r2 = r2_score(test_Y, y_pred)
+
+    # Plot actual vs predicted values for the test set
+    fig, ax = plt.subplots(figsize=(10, 6))
+    ax.scatter(test_Y, y_pred, alpha=0.6, color='blue', label='Predicted')
+    ax.plot([test_Y.min(), test_Y.max()], [test_Y.min(), test_Y.max()], 'r--', label='Perfect Prediction')
+    ax.set_xlabel("Actual")
+    ax.set_ylabel("Predicted")
+    ax.set_title("Actual vs Predicted Values (Test Set)")
+    ax.legend()
+    ax.grid(True)
+
+    # Save the plot as a PNG file
+    plot_file = "actual_vs_predicted.png"
+    fig.savefig(plot_file)
+
+    # Package results
+    results = {
+        "Train_R2": train_r2,
+        "Test_R2": test_r2,
+        "plot_file": plot_file  # Path of the saved plot
+    }
+
+    # Save results to a pickle file
+    with open(output_file, "wb") as f:
+        pickle.dump(results, f)
+
+    print(f"Evaluation and plot data saved to {output_file}")
+    print(f"Plot saved as {plot_file}")
+
+# Run this function once to generate the evaluation file
+save_evaluation_to_pickle(X_train, y_train, X_test, y_test)
diff --git a/models/business_performance_forecasting/predict.py b/models/business_performance_forecasting/predict.py
index 7298c9be..44e53e76 100644
--- a/models/business_performance_forecasting/predict.py
+++ b/models/business_performance_forecasting/predict.py
@@ -1,10 +1,11 @@
-# import os
+import os
 import numpy as np
+import pickle
 from models.business_performance_forecasting.model import load_model_and_scaler  # Import the function from model.py
 
 # Define the prediction function
 def get_prediction(RnD_Spend, Administration, Marketing_Spend, State):
     # Load the model and scaler
     model, scaler = load_model_and_scaler()
     # Prepare input features as a NumPy array
     input_data = np.array([[RnD_Spend, Administration, Marketing_Spend, State]])
@@ -18,3 +19,18 @@ def get_prediction(RnD_Spend, Administration, Marketing_Spend, State):
 
     return prediction[0]  # Return the predicted profit
 
+class ModelEvaluation:
+    def __init__(self):
+        metrics_file = os.path.join(os.path.dirname(__file__), 'saved_models', 'evaluation_results.pkl')
+        # Load evaluation metrics from a pickle file
+        with open(metrics_file, "rb") as f:
+            self.metrics = pickle.load(f)
+        print("Loaded metrics:", self.metrics)
+
+    def evaluate(self):
+        # No figures are produced here; the page handler skips the None plots
+        return self.metrics, None, None, None
+
+def model_details():
+    evaluator = ModelEvaluation()
+    return evaluator
diff --git a/models/business_performance_forecasting/saved_models/evaluation_results.pkl b/models/business_performance_forecasting/saved_models/evaluation_results.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..a2f0a591d8223c68271e25360be22b7e6037b1b1
GIT binary patch
literal 93
zcmZo*nHt0Z0ku;!dN@Lg5;ODSgN&xQ+rLYH`9L*5pocvqwYUT#^!BA(=;UWjJ)8wO
p`6cmbnK`Lbdc+fxOG*=S;>(KT3yM-xGLuVEQ}hb*(x;T_0RREHBUAtY

literal 0
HcmV?d00001

diff --git a/page_handler.py b/page_handler.py
index 75836c50..6f070670 100644
--- a/page_handler.py
+++ b/page_handler.py
@@ -85,11 +85,12 @@ def render_model_details(self, model_module,tab):
 
             # Display the scatter plot for predicted vs actual values
             #used clear_figure to clear the plot once displayed to avoid conflict
-            st.subheader("Model Prediction Plot")
-            st.pyplot(prediction_plot, clear_figure=True)
-
-            st.subheader("Error Plot")
-            st.pyplot(error_plot, clear_figure=True)
-
-            st.subheader("Model Performance Plot")
-            st.pyplot(performance_plot, clear_figure=True)
+            if prediction_plot is not None:
+                st.subheader("Model Prediction Plot")
+                st.pyplot(prediction_plot, clear_figure=True)
+            if error_plot is not None:
+                st.subheader("Error Plot")
+                st.pyplot(error_plot, clear_figure=True)
+            if performance_plot is not None:
+                st.subheader("Model Performance Plot")
+                st.pyplot(performance_plot, clear_figure=True)
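
For reference, a minimal usage sketch of the new model_details() entry point, assuming it is
run from the repository root (so the models package import resolves) with the pickle above in
place:

    from models.business_performance_forecasting.predict import model_details

    evaluator = model_details()  # loads saved_models/evaluation_results.pkl on construction
    metrics, prediction_plot, error_plot, performance_plot = evaluator.evaluate()

    # metrics is the dict written by save_evaluation_to_pickle; the three plot
    # slots are None, so render_model_details skips its st.pyplot calls.
    print(metrics["Train_R2"], metrics["Test_R2"], metrics["plot_file"])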