# --- models/Dementia Prediction Model/model.py ---------------------------
# Relies on the file's module-level imports: pandas as pd,
# sklearn.model_selection.train_test_split,
# sklearn.linear_model.LinearRegression and joblib.
#
# NOTE(review): this file sits in the "Dementia Prediction Model" folder but
# trains a linear *regression* on a car-price CSV.  The accompanying notebook
# predicts the binary 'Group' column instead, so this looks like a template
# copied from another project -- confirm the intended dataset/estimator.


class CarPriceModel:
    """Small training wrapper around a scikit-learn ``LinearRegression``.

    Typical use: ``load_data`` -> ``preprocess_data`` -> ``train`` ->
    ``save_model``.
    """

    def __init__(self):
        self.model = LinearRegression()

    def load_data(self, filepath):
        """Read the CSV file at *filepath* and return it as a DataFrame."""
        return pd.read_csv(filepath)

    def preprocess_data(self, data, target_column='price', test_size=0.2,
                        random_state=42):
        """Split *data* into train/test features and targets.

        Args:
            data: DataFrame holding the feature columns and the target column.
            target_column: Name of the column to predict.  Defaults to
                ``'price'``, the value the original code hard-coded.
            test_size: Fraction of rows held out for testing.
            random_state: Seed that makes the split reproducible.

        Returns:
            ``(X_train, X_test, y_train, y_test)`` as produced by
            ``train_test_split``.

        Raises:
            KeyError: If *target_column* is not a column of *data*.
        """
        # Fail early with an explicit message instead of the less helpful
        # error raised from deep inside ``DataFrame.drop``.
        if target_column not in data.columns:
            raise KeyError(
                f"target column {target_column!r} not found; "
                f"available columns: {list(data.columns)}"
            )
        X = data.drop(columns=[target_column])
        y = data[target_column]
        return train_test_split(
            X, y, test_size=test_size, random_state=random_state
        )

    def train(self, X_train, y_train):
        """Fit the wrapped estimator on the training split."""
        self.model.fit(X_train, y_train)

    def save_model(self, model_path):
        """Persist the fitted estimator to *model_path* with joblib.

        The parent directory is created when missing; previously the dump
        failed with ``FileNotFoundError`` on a fresh checkout.
        """
        import os

        parent = os.path.dirname(model_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        joblib.dump(self.model, model_path)


if __name__ == "__main__":
    car_model = CarPriceModel()
    data = car_model.load_data('data/cleaned_car_data.csv')  # Adjust the path to your dataset
    X_train, X_test, y_train, y_test = car_model.preprocess_data(data)
    car_model.train(X_train, y_train)
    car_model.save_model('saved_models/car_price_model.pkl')


# --- models/Dementia Prediction Model/modelevalution.py ------------------
# Relies on the file's module-level imports: joblib, pandas as pd,
# sklearn.metrics.mean_squared_error and sklearn.metrics.r2_score.


class ModelEvaluator:
    """Load a persisted estimator and report regression metrics for it."""

    def __init__(self, model_path):
        # joblib.load unpickles arbitrary objects: only load trusted files.
        self.model = joblib.load(model_path)

    def evaluate(self, X_test, y_test):
        """Print and return the MSE and R^2 of the model on the given data.

        Returns:
            dict with keys ``"mse"`` and ``"r2"``.  The original version
            returned ``None``; the printed output is unchanged.
        """
        predictions = self.model.predict(X_test)
        mse = mean_squared_error(y_test, predictions)
        r2 = r2_score(y_test, predictions)
        print("Mean Squared Error:", mse)
        print("R^2 Score:", r2)
        return {"mse": mse, "r2": r2}


if __name__ == "__main__":
    # Local import: this script did not previously need train_test_split.
    from sklearn.model_selection import train_test_split

    data = pd.read_csv('data/cleaned_car_data.csv')  # Load your test data
    X = data.drop('price', axis=1)  # Replace 'price' with the actual target column name
    y = data['price']  # Replace 'price' with the actual target column name
    # Re-create the hold-out split used by the training script (same
    # test_size and random_state).  Scoring on the full file would include
    # the rows the model was fitted on and overstate its quality.
    _, X_test, _, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    evaluator = ModelEvaluator('saved_models/car_price_model.pkl')
    evaluator.evaluate(X_test, y_test)
colors\n", + "from sklearn.utils import resample # for downsample the dataset\n", + "from sklearn.model_selection import train_test_split # for splitting the dataset into train and test split\n", + "from sklearn.preprocessing import scale # scale and center the data\n", + "from sklearn.svm import SVC # will make a SVM for classification\n", + "from sklearn.model_selection import GridSearchCV # will do the cross validation\n", + "from sklearn.metrics import plot_confusion_matrix # will draw the confusion matrix\n", + "from sklearn.decomposition import PCA # to perform PCA to plot the data\n", + "from sklearn.impute import SimpleImputer\n", + "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", + "from sklearn.model_selection import cross_val_score\n", + "from sklearn.metrics import confusion_matrix, precision_score, accuracy_score, recall_score, roc_curve, auc\n", + "import seaborn as sns" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.044613, + "end_time": "2021-02-21T05:24:41.936713", + "exception": false, + "start_time": "2021-02-21T05:24:41.892100", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Load the data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0", + "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a", + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.029211Z", + "iopub.status.busy": "2021-02-21T05:24:42.028542Z", + "iopub.status.idle": "2021-02-21T05:24:42.058687Z", + "shell.execute_reply": "2021-02-21T05:24:42.059215Z" + }, + "papermill": { + "duration": 0.077355, + "end_time": "2021-02-21T05:24:42.059378", + "exception": false, + "start_time": "2021-02-21T05:24:41.982023", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "data = pd.read_csv(\"../input/mri-and-alzheimers/oasis_longitudinal.csv\")" + ] + }, + { + "cell_type": "markdown", + "metadata": 
{ + "papermill": { + "duration": 0.04491, + "end_time": "2021-02-21T05:24:42.149525", + "exception": false, + "start_time": "2021-02-21T05:24:42.104615", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Explore the data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.244342Z", + "iopub.status.busy": "2021-02-21T05:24:42.243655Z", + "iopub.status.idle": "2021-02-21T05:24:42.249012Z", + "shell.execute_reply": "2021-02-21T05:24:42.248425Z" + }, + "papermill": { + "duration": 0.053142, + "end_time": "2021-02-21T05:24:42.249140", + "exception": false, + "start_time": "2021-02-21T05:24:42.195998", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "pd.set_option('display.max_columns', None) # will show the all columns with pandas dataframe\n", + "pd.set_option('display.max_rows', None) # will show the all rows with pandas dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.359739Z", + "iopub.status.busy": "2021-02-21T05:24:42.359068Z", + "iopub.status.idle": "2021-02-21T05:24:42.371999Z", + "shell.execute_reply": "2021-02-21T05:24:42.372499Z" + }, + "papermill": { + "duration": 0.078228, + "end_time": "2021-02-21T05:24:42.372636", + "exception": false, + "start_time": "2021-02-21T05:24:42.294408", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Subject IDMRI IDGroupVisitMR DelayM/FHandAgeEDUCSESMMSECDReTIVnWBVASF
0OAS2_0001OAS2_0001_MR1Nondemented10MR87142.027.00.019870.6960.883
1OAS2_0001OAS2_0001_MR2Nondemented2457MR88142.030.00.020040.6810.876
2OAS2_0002OAS2_0002_MR1Demented10MR7512NaN23.00.516780.7361.046
3OAS2_0002OAS2_0002_MR2Demented2560MR7612NaN28.00.517380.7131.010
4OAS2_0002OAS2_0002_MR3Demented31895MR8012NaN22.00.516980.7011.034
\n", + "
" + ], + "text/plain": [ + " Subject ID MRI ID Group Visit MR Delay M/F Hand Age EDUC \\\n", + "0 OAS2_0001 OAS2_0001_MR1 Nondemented 1 0 M R 87 14 \n", + "1 OAS2_0001 OAS2_0001_MR2 Nondemented 2 457 M R 88 14 \n", + "2 OAS2_0002 OAS2_0002_MR1 Demented 1 0 M R 75 12 \n", + "3 OAS2_0002 OAS2_0002_MR2 Demented 2 560 M R 76 12 \n", + "4 OAS2_0002 OAS2_0002_MR3 Demented 3 1895 M R 80 12 \n", + "\n", + " SES MMSE CDR eTIV nWBV ASF \n", + "0 2.0 27.0 0.0 1987 0.696 0.883 \n", + "1 2.0 30.0 0.0 2004 0.681 0.876 \n", + "2 NaN 23.0 0.5 1678 0.736 1.046 \n", + "3 NaN 28.0 0.5 1738 0.713 1.010 \n", + "4 NaN 22.0 0.5 1698 0.701 1.034 " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.head()\n", + "# data.tail()\n", + "# data.size" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.468495Z", + "iopub.status.busy": "2021-02-21T05:24:42.467769Z", + "iopub.status.idle": "2021-02-21T05:24:42.473077Z", + "shell.execute_reply": "2021-02-21T05:24:42.472435Z" + }, + "papermill": { + "duration": 0.054333, + "end_time": "2021-02-21T05:24:42.473216", + "exception": false, + "start_time": "2021-02-21T05:24:42.418883", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(373, 15)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.571285Z", + "iopub.status.busy": "2021-02-21T05:24:42.570528Z", + "iopub.status.idle": "2021-02-21T05:24:42.586722Z", + "shell.execute_reply": "2021-02-21T05:24:42.585742Z" + }, + "papermill": { + "duration": 0.067194, + "end_time": "2021-02-21T05:24:42.586897", + "exception": false, + "start_time": "2021-02-21T05:24:42.519703", + "status": "completed" + }, + 
"tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 373 entries, 0 to 372\n", + "Data columns (total 15 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 Subject ID 373 non-null object \n", + " 1 MRI ID 373 non-null object \n", + " 2 Group 373 non-null object \n", + " 3 Visit 373 non-null int64 \n", + " 4 MR Delay 373 non-null int64 \n", + " 5 M/F 373 non-null object \n", + " 6 Hand 373 non-null object \n", + " 7 Age 373 non-null int64 \n", + " 8 EDUC 373 non-null int64 \n", + " 9 SES 354 non-null float64\n", + " 10 MMSE 371 non-null float64\n", + " 11 CDR 373 non-null float64\n", + " 12 eTIV 373 non-null int64 \n", + " 13 nWBV 373 non-null float64\n", + " 14 ASF 373 non-null float64\n", + "dtypes: float64(5), int64(5), object(5)\n", + "memory usage: 43.8+ KB\n" + ] + } + ], + "source": [ + "data.info()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.046597, + "end_time": "2021-02-21T05:24:42.681879", + "exception": false, + "start_time": "2021-02-21T05:24:42.635282", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Converting Categorical Data to Numerical Data" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.046699, + "end_time": "2021-02-21T05:24:42.776013", + "exception": false, + "start_time": "2021-02-21T05:24:42.729314", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "When **inplace = True** , the data is modified in place, which means it will return nothing and the dataframe is now updated. \n", + "When **inplace = False** , which is the *default*, then the operation is performed and it returns a copy of the object. You then need to save it to something." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.046777, + "end_time": "2021-02-21T05:24:42.872082", + "exception": false, + "start_time": "2021-02-21T05:24:42.825305", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "set axis=0 for rows or, just put axis='rows' to access the rows\n", + "\n", + "set axis=1 for columns or, just put axis='columns' to access the columns" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.982688Z", + "iopub.status.busy": "2021-02-21T05:24:42.981824Z", + "iopub.status.idle": "2021-02-21T05:24:42.985482Z", + "shell.execute_reply": "2021-02-21T05:24:42.986097Z" + }, + "papermill": { + "duration": 0.067422, + "end_time": "2021-02-21T05:24:42.986242", + "exception": false, + "start_time": "2021-02-21T05:24:42.918820", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 373 entries, 0 to 372\n", + "Data columns (total 15 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 Subject ID 373 non-null object \n", + " 1 MRI ID 373 non-null object \n", + " 2 Group 373 non-null int64 \n", + " 3 Visit 373 non-null int64 \n", + " 4 MR Delay 373 non-null int64 \n", + " 5 M/F 373 non-null int64 \n", + " 6 Hand 373 non-null object \n", + " 7 Age 373 non-null int64 \n", + " 8 EDUC 373 non-null int64 \n", + " 9 SES 354 non-null float64\n", + " 10 MMSE 371 non-null float64\n", + " 11 CDR 373 non-null float64\n", + " 12 eTIV 373 non-null int64 \n", + " 13 nWBV 373 non-null float64\n", + " 14 ASF 373 non-null float64\n", + "dtypes: float64(5), int64(7), object(3)\n", + "memory usage: 43.8+ KB\n" + ] + } + ], + "source": [ + "data['M/F'] = [1 if each == \"M\" else 0 for each in data['M/F']]\n", + "data['Group'] = [1 if each == \"Demented\" or each == \"Converted\" else 0 for each 
in data['Group']]\n", + "# data['Group'] = data['Group'].replace(['Converted'], ['Demented']) # Target variable\n", + "# data['Group'] = data['Group'].replace(['Demented', 'Nondemented'], [1,0]) # Target variable\n", + "data.info()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.046981, + "end_time": "2021-02-21T05:24:43.080526", + "exception": false, + "start_time": "2021-02-21T05:24:43.033545", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Note: In this dataset, **CDR** indicates the patient's condition, i.e. whether or not the patient has dementia.\n", + "\n", + "CDR Value Meaning:\n", + "\n", + "* 0 ---> Normal\n", + "* 0.5 ---> Very Mild Dementia\n", + "* 1 ---> Mild Dementia\n", + "* 2 ---> Moderate Dementia\n", + "* 3 ---> Severe Dementia" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.047171, + "end_time": "2021-02-21T05:24:43.175738", + "exception": false, + "start_time": "2021-02-21T05:24:43.128567", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Correlation Between Attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:43.274629Z", + "iopub.status.busy": "2021-02-21T05:24:43.273593Z", + "iopub.status.idle": "2021-02-21T05:24:43.287746Z", + "shell.execute_reply": "2021-02-21T05:24:43.287173Z" + }, + "papermill": { + "duration": 0.064432, + "end_time": "2021-02-21T05:24:43.287873", + "exception": false, + "start_time": "2021-02-21T05:24:43.223441", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Group 1.000000\n", + "CDR 0.778049\n", + "M/F 0.222146\n", + "SES 0.062463\n", + "ASF 0.032495\n", + "Age -0.005941\n", + "eTIV -0.042700\n", + "Visit -0.095507\n", + "MR Delay -0.120638\n", + "EDUC -0.193060\n", + "nWBV -0.311346\n", + "MMSE -0.524775\n", + "Name: Group, 
dtype: float64" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "correlation_matrix = data.corr()\n", + "data_corr = correlation_matrix['Group'].sort_values(ascending=False)\n", + "data_corr" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:43.387454Z", + "iopub.status.busy": "2021-02-21T05:24:43.386736Z", + "iopub.status.idle": "2021-02-21T05:24:45.327091Z", + "shell.execute_reply": "2021-02-21T05:24:45.326421Z" + }, + "papermill": { + "duration": 1.99135, + "end_time": "2021-02-21T05:24:45.327222", + "exception": false, + "start_time": "2021-02-21T05:24:43.335872", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[,\n", + " ,\n", + " ,\n", + " ,\n", + " ],\n", + " [,\n", + " ,\n", + " ,\n", + " ,\n", + " ],\n", + " [,\n", + " ,\n", + " ,\n", + " ,\n", + " ],\n", + " [,\n", + " ,\n", + " ,\n", + " ,\n", + " ],\n", + " [,\n", + " ,\n", + " ,\n", + " ,\n", + " ]],\n", + " dtype=object)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "from pandas.plotting import scatter_matrix\n", + "\n", + "attributes = [\"Group\", \"CDR\", \"M/F\", \"SES\", \"ASF\"]\n", + "\n", + "scatter_matrix(data[attributes], figsize=(15, 11), alpha=0.3)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:45.433703Z", + "iopub.status.busy": "2021-02-21T05:24:45.433017Z", + "iopub.status.idle": "2021-02-21T05:24:47.540106Z", + "shell.execute_reply": "2021-02-21T05:24:47.540628Z" + }, + "papermill": { + "duration": 2.163236, + "end_time": "2021-02-21T05:24:47.540769", + "exception": false, + "start_time": "2021-02-21T05:24:45.377533", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import plotly.express as px\n", + "\n", + "fig = px.scatter(data, x='Group', y='SES', color='Group')\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:47.650865Z", + "iopub.status.busy": "2021-02-21T05:24:47.649810Z", + "iopub.status.idle": "2021-02-21T05:24:47.721817Z", + "shell.execute_reply": "2021-02-21T05:24:47.722413Z" + }, + "papermill": { + "duration": 0.129299, + "end_time": "2021-02-21T05:24:47.722563", + "exception": false, + "start_time": "2021-02-21T05:24:47.593264", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import plotly.express as px\n", + "\n", + "fig = px.scatter(data, x='Group', y='Age', color='Group')\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:47.832887Z", + "iopub.status.busy": "2021-02-21T05:24:47.831861Z", + "iopub.status.idle": "2021-02-21T05:24:47.904040Z", + "shell.execute_reply": "2021-02-21T05:24:47.904522Z" + }, + "papermill": { + "duration": 0.130189, + "end_time": "2021-02-21T05:24:47.904674", + "exception": false, + "start_time": "2021-02-21T05:24:47.774485", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import plotly.express as px\n", + "\n", + "fig = px.scatter(data, x='Group', y='ASF', color='Group')\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.053096, + "end_time": "2021-02-21T05:24:48.011350", + "exception": false, + "start_time": "2021-02-21T05:24:47.958254", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Checking For Missing/Null Values" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:48.126795Z", + "iopub.status.busy": "2021-02-21T05:24:48.125994Z", + "iopub.status.idle": "2021-02-21T05:24:48.130278Z", + "shell.execute_reply": "2021-02-21T05:24:48.129746Z" + }, + "papermill": { + "duration": 0.065537, + "end_time": "2021-02-21T05:24:48.130397", + "exception": false, + "start_time": "2021-02-21T05:24:48.064860", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Subject ID 0\n", + "MRI ID 0\n", + "Group 0\n", + "Visit 0\n", + "MR Delay 0\n", + "M/F 0\n", + "Hand 0\n", + "Age 0\n", + "EDUC 0\n", + "SES 19\n", + "MMSE 2\n", + "CDR 0\n", + "eTIV 0\n", + "nWBV 0\n", + "ASF 0\n", + "dtype: int64" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.isnull().sum()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.053518, + "end_time": "2021-02-21T05:24:48.237952", + "exception": false, + "start_time": "2021-02-21T05:24:48.184434", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Taking median values for the missing values of MMSE" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:48.355837Z", + "iopub.status.busy": "2021-02-21T05:24:48.355128Z", + 
"iopub.status.idle": "2021-02-21T05:24:48.357901Z", + "shell.execute_reply": "2021-02-21T05:24:48.358483Z" + }, + "papermill": { + "duration": 0.066666, + "end_time": "2021-02-21T05:24:48.358626", + "exception": false, + "start_time": "2021-02-21T05:24:48.291960", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Subject ID 0\n", + "MRI ID 0\n", + "Group 0\n", + "Visit 0\n", + "MR Delay 0\n", + "M/F 0\n", + "Hand 0\n", + "Age 0\n", + "EDUC 0\n", + "SES 19\n", + "MMSE 0\n", + "CDR 0\n", + "eTIV 0\n", + "nWBV 0\n", + "ASF 0\n", + "dtype: int64" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "median = data['MMSE'].median()\n", + "data['MMSE'].fillna(median, inplace=True)\n", + "data.isnull().sum()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.054757, + "end_time": "2021-02-21T05:24:48.468036", + "exception": false, + "start_time": "2021-02-21T05:24:48.413279", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Taking median values for the missing values of SES" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:48.580903Z", + "iopub.status.busy": "2021-02-21T05:24:48.580254Z", + "iopub.status.idle": "2021-02-21T05:24:48.589172Z", + "shell.execute_reply": "2021-02-21T05:24:48.589795Z" + }, + "papermill": { + "duration": 0.067076, + "end_time": "2021-02-21T05:24:48.589954", + "exception": false, + "start_time": "2021-02-21T05:24:48.522878", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Subject ID 0\n", + "MRI ID 0\n", + "Group 0\n", + "Visit 0\n", + "MR Delay 0\n", + "M/F 0\n", + "Hand 0\n", + "Age 0\n", + "EDUC 0\n", + "SES 0\n", + "MMSE 0\n", + "CDR 0\n", + "eTIV 0\n", + "nWBV 0\n", + "ASF 0\n", + "dtype: int64" + ] + }, + "execution_count": 15, + "metadata": 
{}, + "output_type": "execute_result" + } + ], + "source": [ + "median = data['SES'].median()\n", + "data['SES'].fillna(median, inplace=True)\n", + "data.isnull().sum()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.054489, + "end_time": "2021-02-21T05:24:48.700670", + "exception": false, + "start_time": "2021-02-21T05:24:48.646181", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Train-Test Split" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.055984, + "end_time": "2021-02-21T05:24:48.812635", + "exception": false, + "start_time": "2021-02-21T05:24:48.756651", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Prepare the data for X and y where, \n", + "\n", + "1. X = The columns/features for **making the prediction**\n", + "2. y = The **predicted value**" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:48.926932Z", + "iopub.status.busy": "2021-02-21T05:24:48.926288Z", + "iopub.status.idle": "2021-02-21T05:24:48.933312Z", + "shell.execute_reply": "2021-02-21T05:24:48.932717Z" + }, + "papermill": { + "duration": 0.065077, + "end_time": "2021-02-21T05:24:48.933436", + "exception": false, + "start_time": "2021-02-21T05:24:48.868359", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "y = data['Group'].values\n", + "X = data[['M/F', 'Age', 'EDUC', 'SES', 'MMSE', 'eTIV', 'nWBV', 'ASF']]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.055011, + "end_time": "2021-02-21T05:24:49.044676", + "exception": false, + "start_time": "2021-02-21T05:24:48.989665", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Train-Test distribution Without Stratified Sampling" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "execution": { + "iopub.execute_input": 
"2021-02-21T05:24:49.159712Z", + "iopub.status.busy": "2021-02-21T05:24:49.158704Z", + "iopub.status.idle": "2021-02-21T05:24:49.174698Z", + "shell.execute_reply": "2021-02-21T05:24:49.175197Z" + }, + "papermill": { + "duration": 0.075266, + "end_time": "2021-02-21T05:24:49.175352", + "exception": false, + "start_time": "2021-02-21T05:24:49.100086", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In Training Split:\n", + "0 158\n", + "1 140\n", + "Name: 0, dtype: int64\n", + "\n", + "In Testing Split:\n", + "1 43\n", + "0 32\n", + "Name: 0, dtype: int64\n" + ] + } + ], + "source": [ + "# by default test_size= 0.25\n", + "X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size= 0.20, random_state=42)\n", + "\n", + "df_ytrain = pd.DataFrame(y_trainval)\n", + "df_ytest = pd.DataFrame(y_test)\n", + "\n", + "print('In Training Split:')\n", + "print(df_ytrain[0].value_counts())\n", + "\n", + "print('\\nIn Testing Split:')\n", + "print(df_ytest[0].value_counts())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.055463, + "end_time": "2021-02-21T05:24:49.286592", + "exception": false, + "start_time": "2021-02-21T05:24:49.231129", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### With Stratified Sampling" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:49.402123Z", + "iopub.status.busy": "2021-02-21T05:24:49.401433Z", + "iopub.status.idle": "2021-02-21T05:24:49.418062Z", + "shell.execute_reply": "2021-02-21T05:24:49.417434Z" + }, + "papermill": { + "duration": 0.075377, + "end_time": "2021-02-21T05:24:49.418177", + "exception": false, + "start_time": "2021-02-21T05:24:49.342800", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In Training Split:\n", + "0 
152\n", + "1 146\n", + "Name: 0, dtype: int64\n", + "\n", + "In Testing Split:\n", + "0 38\n", + "1 37\n", + "Name: 0, dtype: int64\n" + ] + } + ], + "source": [ + "# by default test_size= 0.25\n", + "X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size= 0.20, random_state=42, stratify=y)\n", + "\n", + "\n", + "df_ytrain = pd.DataFrame(y_trainval)\n", + "df_ytest = pd.DataFrame(y_test)\n", + "\n", + "print('In Training Split:')\n", + "print(df_ytrain[0].value_counts())\n", + "\n", + "print('\\nIn Testing Split:')\n", + "print(df_ytest[0].value_counts())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.056065, + "end_time": "2021-02-21T05:24:49.530652", + "exception": false, + "start_time": "2021-02-21T05:24:49.474587", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Scale the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:49.651835Z", + "iopub.status.busy": "2021-02-21T05:24:49.650741Z", + "iopub.status.idle": "2021-02-21T05:24:49.660072Z", + "shell.execute_reply": "2021-02-21T05:24:49.659484Z" + }, + "papermill": { + "duration": 0.073448, + "end_time": "2021-02-21T05:24:49.660201", + "exception": false, + "start_time": "2021-02-21T05:24:49.586753", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# here StandardScaler() means z = (x - u) / s\n", + "scaler = StandardScaler().fit(X_trainval)\n", + "#scaler = MinMaxScaler().fit(X_trainval)\n", + "X_trainval_scaled = scaler.transform(X_trainval)\n", + "X_test_scaled = scaler.transform(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:49.777554Z", + "iopub.status.busy": "2021-02-21T05:24:49.776541Z", + "iopub.status.idle": "2021-02-21T05:24:49.782686Z", + "shell.execute_reply": "2021-02-21T05:24:49.783257Z" + }, + 
"papermill": { + "duration": 0.066296, + "end_time": "2021-02-21T05:24:49.783400", + "exception": false, + "start_time": "2021-02-21T05:24:49.717104", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-0.87966444, 0.38449006, -0.87500081, ..., -0.30880564,\n", + " 0.16961408, 0.21548547],\n", + " [ 1.13679712, 1.04832574, -0.87500081, ..., 1.23805919,\n", + " -0.67676996, -1.21558429],\n", + " [-0.87966444, -0.9431813 , -1.22928103, ..., -1.08511327,\n", + " 0.4605586 , 1.14751551],\n", + " ...,\n", + " [-0.87966444, -0.01381135, 1.2506805 , ..., -0.92985174,\n", + " 0.01091708, 0.94202857],\n", + " [-0.87966444, -1.20871557, -0.87500081, ..., -0.00978344,\n", + " 0.5663566 , -0.10742258],\n", + " [-0.87966444, 0.11895579, 1.2506805 , ..., -1.38413547,\n", + " 0.4605586 , 1.56582821]])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_trainval_scaled" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:49.907580Z", + "iopub.status.busy": "2021-02-21T05:24:49.906515Z", + "iopub.status.idle": "2021-02-21T05:24:49.938315Z", + "shell.execute_reply": "2021-02-21T05:24:49.937670Z" + }, + "papermill": { + "duration": 0.09737, + "end_time": "2021-02-21T05:24:49.938439", + "exception": false, + "start_time": "2021-02-21T05:24:49.841069", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
M/FAgeEDUCSESMMSEeTIVnWBVASF
count298.000000298.000000298.000000298.000000298.000000298.000000298.000000298.000000
mean0.43624277.10402714.4697992.48993327.3557051483.7013420.7305871.198638
std0.4967527.5446542.8273721.1168593.689231174.1926490.0378710.136491
min0.00000060.0000006.0000001.0000004.0000001106.0000000.6440000.883000
25%0.00000071.25000012.0000002.00000027.0000001357.0000000.6992501.107250
50%0.00000077.00000014.0000002.00000029.0000001462.0000000.7310001.200500
75%1.00000082.00000016.0000003.00000030.0000001585.2500000.7570001.293000
max1.00000096.00000023.0000005.00000030.0000001987.0000000.8370001.587000
\n", + "
" + ], + "text/plain": [ + " M/F Age EDUC SES MMSE \\\n", + "count 298.000000 298.000000 298.000000 298.000000 298.000000 \n", + "mean 0.436242 77.104027 14.469799 2.489933 27.355705 \n", + "std 0.496752 7.544654 2.827372 1.116859 3.689231 \n", + "min 0.000000 60.000000 6.000000 1.000000 4.000000 \n", + "25% 0.000000 71.250000 12.000000 2.000000 27.000000 \n", + "50% 0.000000 77.000000 14.000000 2.000000 29.000000 \n", + "75% 1.000000 82.000000 16.000000 3.000000 30.000000 \n", + "max 1.000000 96.000000 23.000000 5.000000 30.000000 \n", + "\n", + " eTIV nWBV ASF \n", + "count 298.000000 298.000000 298.000000 \n", + "mean 1483.701342 0.730587 1.198638 \n", + "std 174.192649 0.037871 0.136491 \n", + "min 1106.000000 0.644000 0.883000 \n", + "25% 1357.000000 0.699250 1.107250 \n", + "50% 1462.000000 0.731000 1.200500 \n", + "75% 1585.250000 0.757000 1.293000 \n", + "max 1987.000000 0.837000 1.587000 " + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_trainval.describe()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.057412, + "end_time": "2021-02-21T05:24:50.054067", + "exception": false, + "start_time": "2021-02-21T05:24:49.996655", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Data Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:50.175069Z", + "iopub.status.busy": "2021-02-21T05:24:50.174376Z", + "iopub.status.idle": "2021-02-21T05:24:51.959428Z", + "shell.execute_reply": "2021-02-21T05:24:51.958796Z" + }, + "papermill": { + "duration": 1.847829, + "end_time": "2021-02-21T05:24:51.959546", + "exception": false, + "start_time": "2021-02-21T05:24:50.111717", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "X_trainval.hist(bins=30, figsize=(20,15))\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:52.090627Z", + "iopub.status.busy": "2021-02-21T05:24:52.089685Z", + "iopub.status.idle": "2021-02-21T05:24:52.168509Z", + "shell.execute_reply": "2021-02-21T05:24:52.167818Z" + }, + "papermill": { + "duration": 0.150146, + "end_time": "2021-02-21T05:24:52.168625", + "exception": false, + "start_time": "2021-02-21T05:24:52.018479", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import plotly.express as px\n", + "\n", + "x = ['M/F', 'Age', 'EDUC', 'SES', 'MMSE', 'eTIV', 'nWBV', 'ASF']\n", + "\n", + "fig = px.histogram(X_trainval, x='eTIV', nbins=50)\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:52.349421Z", + "iopub.status.busy": "2021-02-21T05:24:52.321695Z", + "iopub.status.idle": "2021-02-21T05:24:52.362716Z", + "shell.execute_reply": "2021-02-21T05:24:52.362055Z" + }, + "papermill": { + "duration": 0.133606, + "end_time": "2021-02-21T05:24:52.362833", + "exception": false, + "start_time": "2021-02-21T05:24:52.229227", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import plotly.express as px\n", + "\n", + "x = ['M/F', 'Age', 'EDUC', 'SES', 'MMSE', 'eTIV', 'nWBV', 'ASF']\n", + "\n", + "fig = px.scatter(X_trainval, x='eTIV')\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.060847, + "end_time": "2021-02-21T05:24:52.484599", + "exception": false, + "start_time": "2021-02-21T05:24:52.423752", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# SVM" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:52.614683Z", + "iopub.status.busy": "2021-02-21T05:24:52.613922Z", + "iopub.status.idle": "2021-02-21T05:24:52.791187Z", + "shell.execute_reply": "2021-02-21T05:24:52.790597Z" + }, + "papermill": { + "duration": 0.245707, + "end_time": "2021-02-21T05:24:52.791313", + "exception": false, + "start_time": "2021-02-21T05:24:52.545606", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "clf_svm = SVC(random_state=42)\n", + "clf_svm.fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# for test there are 94 cases\n", + "plot_confusion_matrix(clf_svm, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:52.925317Z", + "iopub.status.busy": "2021-02-21T05:24:52.924620Z", + "iopub.status.idle": "2021-02-21T05:24:52.935820Z", + "shell.execute_reply": "2021-02-21T05:24:52.935250Z" + }, + "papermill": { + "duration": 0.08155, + "end_time": "2021-02-21T05:24:52.935942", + "exception": false, + "start_time": "2021-02-21T05:24:52.854392", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy 0.8590604026845637\n", + "Test accuracy 0.7466666666666667\n", + "Test recall 0.5945945945945946\n", + "Test AUC 0.744665718349929\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "train_score = clf_svm.score(X_trainval_scaled, y_trainval)\n", + "test_score = clf_svm.score(X_test_scaled, y_test)\n", + "y_predict = clf_svm.predict(X_test_scaled)\n", + "\n", + "test_recall = recall_score(y_test, y_predict)\n", + "fpr, tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(fpr, tpr)\n", + "\n", + "\n", + "print(\"Train accuracy \", train_score)\n", + "print(\"Test accuracy \", test_score)\n", + "print(\"Test recall\", test_recall)\n", + "print(\"Test AUC\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.062406, + "end_time": "2021-02-21T05:24:53.061149", + "exception": false, + "start_time": "2021-02-21T05:24:52.998743", + 
"status": "completed" + }, + "tags": [] + }, + "source": [ + "### Optimize parameters(Finetuning) --> GridSearchCV() for SVM" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:53.196144Z", + "iopub.status.busy": "2021-02-21T05:24:53.195430Z", + "iopub.status.idle": "2021-02-21T05:25:04.937936Z", + "shell.execute_reply": "2021-02-21T05:25:04.938471Z" + }, + "papermill": { + "duration": 11.814617, + "end_time": "2021-02-21T05:25:04.938646", + "exception": false, + "start_time": "2021-02-21T05:24:53.124029", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 6, 'gamma': 1, 'kernel': 'rbf'}\n" + ] + } + ], + "source": [ + "# Normally, C = 1 and gamma = 'scale' are default values\n", + "# C controls how wide the margin will be with respect to how many misclassification we are allowing\n", + "# C is increasing --> reduce the size of the margin and fewer misclassification and vice versa\n", + "param_grid = [\n", + " {'C': [0.5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 30, 50, 80, 100],\n", + " 'gamma': ['scale', 0.5, 1, 0.1, 0.01, 0.001, 0.0001, 0.00001],\n", + " 'kernel': ['rbf', 'linear', 'poly', 'sigmoid']},\n", + "]\n", + "\n", + "optimal_params = GridSearchCV(SVC(),\n", + " param_grid,\n", + " cv=5, # we are taking 5-fold as in k-fold cross validation\n", + " scoring='accuracy', # try the other scoring if have time\n", + " verbose=0,\n", + " n_jobs=-1)\n", + "\n", + "optimal_params.fit(X_trainval_scaled, y_trainval)\n", + "print(optimal_params.best_params_)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:05.073634Z", + "iopub.status.busy": "2021-02-21T05:25:05.072673Z", + "iopub.status.idle": "2021-02-21T05:25:05.076294Z", + "shell.execute_reply": "2021-02-21T05:25:05.075675Z" + }, + "papermill": { + "duration": 0.074232, 
+ "end_time": "2021-02-21T05:25:05.076406", + "exception": false, + "start_time": "2021-02-21T05:25:05.002174", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "C = optimal_params.best_params_['C']\n", + "gamma = optimal_params.best_params_['gamma']\n", + "kernel = optimal_params.best_params_['kernel']" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:05.207035Z", + "iopub.status.busy": "2021-02-21T05:25:05.206326Z", + "iopub.status.idle": "2021-02-21T05:25:05.382139Z", + "shell.execute_reply": "2021-02-21T05:25:05.382632Z" + }, + "papermill": { + "duration": 0.242652, + "end_time": "2021-02-21T05:25:05.382801", + "exception": false, + "start_time": "2021-02-21T05:25:05.140149", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "clf_svm = SVC(random_state=42, C=C, gamma=gamma, kernel=kernel)\n", + "clf_svm.fit(X_trainval_scaled, y_trainval)\n", + "\n", + "plot_confusion_matrix(clf_svm, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:05.517061Z", + "iopub.status.busy": "2021-02-21T05:25:05.516041Z", + "iopub.status.idle": "2021-02-21T05:25:05.535028Z", + "shell.execute_reply": "2021-02-21T05:25:05.534191Z" + }, + "papermill": { + "duration": 0.08746, + "end_time": "2021-02-21T05:25:05.535192", + "exception": false, + "start_time": "2021-02-21T05:25:05.447732", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy 1.0\n", + "Test accuracy 0.92\n", + "Test recall 0.918918918918919\n", + "Test AUC 0.9199857752489331\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "train_score = clf_svm.score(X_trainval_scaled, y_trainval)\n", + "test_score = clf_svm.score(X_test_scaled, y_test)\n", + "y_predict = clf_svm.predict(X_test_scaled)\n", + "\n", + "test_recall = recall_score(y_test, y_predict)\n", + "svm_fpr, svm_tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(svm_fpr, svm_tpr)\n", + "\n", + "\n", + "print(\"Train accuracy \", train_score)\n", + "print(\"Test accuracy \", test_score)\n", + "print(\"Test recall\", test_recall)\n", + "print(\"Test AUC\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.064261, + "end_time": "2021-02-21T05:25:05.665444", + "exception": false, + "start_time": "2021-02-21T05:25:05.601183", + "status": 
"completed" + }, + "tags": [] + }, + "source": [ + "# Random Forest" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:05.798235Z", + "iopub.status.busy": "2021-02-21T05:25:05.797535Z", + "iopub.status.idle": "2021-02-21T05:25:05.857124Z", + "shell.execute_reply": "2021-02-21T05:25:05.856511Z" + }, + "papermill": { + "duration": 0.126931, + "end_time": "2021-02-21T05:25:05.857258", + "exception": false, + "start_time": "2021-02-21T05:25:05.730327", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from sklearn.ensemble import RandomForestClassifier" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:06.011800Z", + "iopub.status.busy": "2021-02-21T05:25:06.001387Z", + "iopub.status.idle": "2021-02-21T05:25:06.381618Z", + "shell.execute_reply": "2021-02-21T05:25:06.381041Z" + }, + "papermill": { + "duration": 0.459083, + "end_time": "2021-02-21T05:25:06.381735", + "exception": false, + "start_time": "2021-02-21T05:25:05.922652", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXMAAAEKCAYAAADgl7WbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAf9klEQVR4nO3deZwcVb338c93JkMSSAgEkhiVXQ0ikgABBSRPWEREvYIPuFxUFB5ZBAIKeNHrS0HlucgibmxBkUVAwACyKKtCWCUkhEASAggRlUhI2BKWJDP9u3/UGWnizHTV0DO9zPfNq16prq5z6tfdw69Pnzp1ShGBmZk1tpZaB2BmZm+dk7mZWRNwMjczawJO5mZmTcDJ3MysCTiZm5k1ASdzM7MakTRE0v2SHpI0V9KJafsJkv4haXZa9qpYl8eZm5nVhiQBa0XEckltwF3AUcCewPKIOC1vXYP6KEYzM6sgstb08vSwLS29amE7mdeR9Ue2xsYbtNU6DCvgsTlr1joEK2gZLyyJiFG9Lf+RXdaKpc935Np35pwVc4HXyzZNjYip5ftIagVmAu8CzoyIP0v6KHCEpC8CDwDHRMQLPR3L3Sx1ZOL4IXH/TRvUOgwr4CNvn1DrEKygW+O3MyNiYm/LZ/+fbphr39axj+c+lqR1gKuBI4HngCVkrfTvA2Mj4sCeyvsEqJlZAQGUcv5XqN6IF4HbgT0j4tmI6IiIEnAesH2l8k7mZmYFBMGq6Mi1VCJpVGqRI2kosDvwqKSxZbvtAzxSqS73mZuZFVS01d2DscCFqd+8BbgiIq6XdLGkCWQ/BBYCh1SqyMnczKyAIOio0rnGiJgDbN3F9i8UrcvJ3MysoFLvRg/2KSdzM7MCAuhwMjcza3xumZuZNbgAVtXh9TlO5mZmBQThbhYzs4YX0FF/udzJ3MysiOwK0PrjZG5mVojoQLUO4t84mZuZFZCdAHUyNzNraNk4cydzM7OGV3LL3MyssbllbmbWBALRUYezhzuZm5kV5G4WM7MGF4iV0VrrMP6Nk7mZWQHZRUPuZjEza3g+AWpm1uAiREe4ZW5m1vBKbpmbmTW27ARo/aXO+ovIzKyO+QSomVmT6PA4czOzxuYrQM3MmkTJo1nMzBpbNtGWk7mZWUMLxKoqXc4vaQgwHRhMlo9/GxHflTQSuBzYGFgIfDoiXuiprvr7ejEzq2MR0BEtuZYcVgC7RsR4YAKwp6QPAscDt0XEu4Hb0uMeOZmbmRUiSjmXSiKzPD1sS0sAnwQuTNsvBPauVJe7WczMCggocjn/+pIeKHs8NSKmlu8gqRWYCbwLODMi/ixpTEQsAoiIRZJGVzqQk7mZWUEFToAuiYiJPe0QER3ABEnrAFdL2rI3MTmZm5kVEKhPbk4RES9Kuh3YE3hW0tjUKh8LLK5U3n3mZmYFBLAqBuVaKpE0KrXIkTQU2B14FLgWOCDtdgDwu0p1uWVuZlaIqjmf+VjgwtRv3gJcERHXS7oXuELSQcDTwH6VKnIyNzMrIKjeFaARMQfYuovtS4HditTlZG5mVpDvNGRm1uAi5LlZzMwaXXYCtDqX81eTk7mZWSG+B6iZWcPLToC6z9zMrOF5ClwzswbXV1eAvlVO5mZmBfmGzmZmDS4CVpWczM3MGlrWzeJkbmbW8HwFqDW9la+LYz71LlatbKGjHXb+2Et88bh/cvFpb+MPl45kxMgOAL78zWfYfrdlNY7WVtc2uMTpVz1B2xpB66DgzhvW4eLT3lbrsOrKgBuaKCmAH0XEMenxscCwiDihCnVfAFwfEb99q3X14th7A49FxLyC5ZZHxLA+CqtutA0OTrnyLwxdq0T7Kvj63u9mu11fBmCfrzzHfoc9V+MIrSerVohv7LcZr7/aSuug4EfXPMGMPw7n0Vlr1Tq0OlKf3Sx9GdEK4FOS1u/DY9TC3sAWtQ6iXkkwdK0SAO2rRMcqofprxFi3xOuvZpeqD2o
LWtuCiBqHVIeqdQ/QaurLZN4OTAW+tvoTkjaSdJukOenfDdP2CyT9VNI9kp6UtG/aLkk/lzRP0g3A6LK6tpV0h6SZkm5Kd+VA0u2SzpA0XdJ8SdtJukrS45J+UFb+85LulzRb0rlpXmEkLZd0kqSHJN0naYykHYH/AE5N+2+WlhvT8e+UtHkqv4mkeyXNkPT9PnuX61BHBxy2+zg+s9WWbD1pGZtv8yoA1/1qFIfuNo7Tv7YBy16sv7ktLNPSEpx1ywIunzOXB6cPY8GDbpWXy0aztOZa+lNf/1Y4E9hf0ojVtv8cuCgitgIuAX5a9txY4EPAx4GT07Z9gHHA+4GvADsCSGoDfgbsGxHbAucDJ5XVtTIiJgHnkN2p43BgS+BLktaT9F7gM8BOETEB6AD2T2XXAu6LiPHAdOArEXEP2R1AjouICRHxF7IvrCPT8Y8FzkrlfwKcHRHbAf/s7g2SdLCkByQ98NzSjm7fyEbS2gpn37qAS2bOY8HsNVn46BA+fsASfnXvPM66ZQEjx6xi6olvr3WY1o1SSXz1w+PYf9stGDfhVTYa91qtQ6ornRcN5Vn6U58m84h4GbgImLLaUzsAl6b1i8mSd6drIqKU+qTHpG2TgMsioiMingH+mLaPI0vOt0iaDXwbeGdZXdemfx8G5kbEoohYATwJbEA2+fu2wIxUfjdg01RmJXB9Wp8JbLz665M0jOyL5cpU/lyyLyOAnYDLyl5jlyJiakRMjIiJo9ZrrtbqsBEdjN9hOTP+NJx1R7XT2gotLfDR/Z9nwew1ax2eVfDKy608dO8wttvFJ6pXV4/dLP0xmuXHwCzgVz3sU94rt6JsXd3sU/783IjYoZt6O+sqrVZviey1C7gwIr7ZRdlVEf/qLeyg6/eqBXgxteq7MuB6G19c2sqgQVkiX/GamHXncD59+GKWPjuI9ca0A3DPH0aw8bjXaxypdWXEyHba28UrL7eyxpAS2+y8nCvOHF254AAy4EazdIqI5yVdARxE1g0CcA/wWbIW6/7AXRWqmQ4cIukisv7yXcha9guAUZJ2iIh7U7fLeyJibs7wbgN+J+mMiFgsaSQwPCL+2kOZZcDw9NpelvSUpP0i4kpJAraKiIeAu9Nr/DVvdN00veefbeO0ozakVBKlEkz6xIt88MMvc8qRG/KXuUORYMw7VzLllL/VOlTrwsgxqzj2J0/T0pL9ipp+3Qj+fOvatQ6r7tTjaJb+Gmd+OnBE2eMpwPmSjgOeA75cofzVwK5k3SWPAXcARMTKdJL0p6lffhDZL4FcyTwi5kn6NnCzpBZgFVm/ek/J/DfAeZKmAPuSJeqzUz1t6fmHgKOASyUdBUzLE08z2HSL1znrlsf+bfs3fvZ0DaKxop6aP5TD9xhX6zDqWoRor8NkrvC4o7oxcfyQuP+mDWodhhXwkbd318Nm9erW+O3MiJjY2/Lrbj46Jv9yv1z7XvOhs97SsYrwFaBmZgUM2D5zM7Nm42RuZtbgfHMKM7Mm0d9jyPOov1OyZmZ1LALaSy25lkokbSDpT2nKkblp9BuSTpD0jzRtyGxJe1Wqyy1zM7OCqtjN0g4cExGzJA0HZkq6JT13RkSclrciJ3MzswKq2WceEYuARWl9maT5wDt6U5e7WczMCopQrgVYv3MivbQc3F2dkjYGtgb+nDYdkWaWPV/SupVicjI3MyuowERbSzon0kvL1K7qS5P2TQOOThMUng1sBkwga7mfXikmd7OYmRUQUd1x5mlOqWnAJRFxVXaMeLbs+fN4YwbXbjmZm5kVIjpyjFTJVVM2Od8vgfkR8aOy7WNTfzpk93N4pFJdTuZmZgVF9VrmOwFfAB5O90QA+BbwOUkTyGYPWAgcUqkiJ3MzswKqOTdLRNwFXV6B9PuidTmZm5kVEdTlTa6dzM3MCqrHy/mdzM3MCogqngCtJidzM7OC3M1iZtYEqjiapWqczM3MCohwMjczawq+OYWZWRNwn7mZWYMLRMmjWcz
MGl8dNsydzM3MCvEJUDOzJlGHTXMnczOzghqqZS7pZ/Tw/RMRU/okIjOzOhZAqdRAyRx4oN+iMDNrFAE0Uss8Ii4sfyxprYh4pe9DMjOrb/U4zrziYElJO0iaB8xPj8dLOqvPIzMzq1eRc+lHeUa+/xj4CLAUICIeAib1ZVBmZvVLRORb+lOu0SwR8bfsvqP/0tE34ZiZNYA67GbJk8z/JmlHICStAUwhdbmYmQ04AVGHo1nydLMcChwOvAP4BzAhPTYzG6CUc+k/FVvmEbEE2L8fYjEzawx12M2SZzTLppKuk/ScpMWSfidp0/4IzsysLjXoaJZLgSuAscDbgSuBy/oyKDOzutV50VCepR/lSeaKiIsjoj0tv6Yuf2SYmfWP7NZxlZf+1G0ylzRS0kjgT5KOl7SxpI0kfQO4of9CNDOrMyXlWyqQtIGkP0maL2mupKPS9pGSbpH0ePp33Up19XQCdCZZC7wzokPKngvg+xUjNTNrQqpeq7sdOCYiZkkaDsyUdAvwJeC2iDhZ0vHA8cB/9VRRT3OzbFK1cM3MmkUVT25GxCJgUVpfJmk+2TDwTwKT024XArfT22ReTtKWwBbAkLIgLioYt5lZEyh0cnN9SeUz0E6NiKld1iptDGwN/BkYkxI9EbFI0uhKB6qYzCV9l+wbYgvg98BHgbsAJ3MzG5jyt8yXRMTESjtJGgZMA46OiJdXmz4llzyjWfYFdgP+GRFfBsYDgwsfycysWZRyLjlIaiNL5JdExFVp87OSxqbnxwKLK9WTJ5m/FhEloF3S2qlSXzRkZgNTFceZK2uC/xKYHxE/KnvqWuCAtH4A8LtKdeXpM39A0jrAeWQjXJYD9+coZ2bWlKo4mmUn4AvAw5Jmp23fAk4GrpB0EPA0sF+livLMzfLVtHqOpBuBtSNiTq/CNjNrBtUbzXIX3c/ItVuRunq6ofM2PT0XEbOKHMjMzPpOTy3z03t4LoBdqxzLgPfo06PY+fBDKu9odWO9OxbWOgQrqgr3SatiN0vV9HTR0C79GYiZWUMIcl2q399yXTRkZmZlGqllbmZmXWuobhYzM+tGHSbzPHcakqTPS/pOeryhpO37PjQzszrVoHcaOgvYAfhcerwMOLPPIjIzq2OK/Et/ytPN8oGI2EbSgwAR8YKkNfo4LjOz+tWgo1lWSWol/WiQNIrcU8iYmTWfejwBmqeb5afA1cBoSSeRTX/7//s0KjOzelaHfeZ55ma5RNJMsnkCBOwdEfP7PDIzs3pUg/7wPPLcnGJD4FXguvJtEfF0XwZmZla3GjGZAzfwxo2dhwCbAAuA9/VhXGZmdUt1eNYwTzfL+8sfp9kUPRuUmVkdKXwFaETMkrRdXwRjZtYQGrGbRdLXyx62ANsAz/VZRGZm9axRT4ACw8vW28n60Kf1TThmZg2g0ZJ5ulhoWEQc10/xmJnVv0ZK5pIGRUR7T7ePMzMbaETjjWa5n6x/fLaka4ErgVc6n4yIq/o4NjOz+tPAfeYjgaVk9/zsHG8egJO5mQ1MDZbMR6eRLI/wRhLvVIcvxcysn9RhBuwpmbcCw3hzEu9Uhy/FzKx/NFo3y6KI+F6/RWJm1ijqMJn3NAVu/c2+bmZWa5GNZsmzVCLpfEmLJT1Stu0ESf+QNDste+UJq6dkvlueCszMBpzqzWd+AbBnF9vPiIgJafl9noq67WaJiOdzhWJmNsBUq888IqZL2rgadeW505CZmZXL3zJfX9IDZcvBOY9whKQ5qRtm3TwFnMzNzIrIm8izZL4kIiaWLVNzHOFsYDNgArAIOD1PWIWnwDUzG8hE3w5NjIhn/3Us6Tzg+jzl3DI3MytIkW/pVd3S2LKH+5BduFmRW+ZmZkVVqWUu6TJgMlnf+t+B7wKTJU1IR1lIzju7OZmbmRVVvdEsn+ti8y97U5eTuZlZEQ08a6KZmZVzMjcza3yNdnMKMzPrgrtZzMwaXf55V/qVk7mZWVFO5mZmja2vrwDtLSdzM7O
CVKq/bO5kbmZWhPvMzcyag7tZzMyagZO5mVnjc8vczKwZOJmbmTW48OX8ZmYNz+PMzcyaRdRfNncyNzMryC1za3rHf/52dtzyaV5YNpQDTtoPgMlbP8mBH5vJRmNe4OBT92HB06NqHKWVKy3uYNVJLxHPl6AFBn1iTQbtuyYA7dNepf3qV6EVWj84mLbDhtc42jpQpxcNNcQNnSV1SJotaa6khyR9XVJNYpe0jqSv9qLcCZKO7YuY6skf7hvHsWfu9aZtTz2zLv899cM89MTYbkpZLakV2g4fzpCL12fw2SNpv/pVSgvb6Zi1ko67VzD4/PUYcuH6DPrsWrUOtW6olG/pT43SMn8tIiYASBoNXAqMILv5aX9bB/gqcFYNjl33HnpiLG8buexN2/767Lo1isby0HqtaL3WbH3NFrTRIOK5Djquf41B/7kmWkPZc+s2RNuvX9TjaJaG+3QiYjFwMHCEMq2STpU0Q9IcSYcASJos6Q5JV0h6TNLJkvaXdL+khyVtlvYbJWlaKj9D0k5p+wmSzpd0u6QnJU1JIZwMbJZ+KZya9j2u7PgndsYq6b8lLZB0KzCuH98ms14pLeogHl9FyxZtlP7eQWnOKl4/dCkrpjxPaf6qWodXH4LsBGiepR81Ssv8TSLiydTNMhr4JPBSRGwnaTBwt6Sb067jgfcCzwNPAr+IiO0lHQUcCRwN/AQ4IyLukrQhcFMqA7A5sAswHFgg6WzgeGDLsl8KewDvBrYnG7V0raRJwCvAZ4Gtyd7nWcDM1V+LpIPJvpxYY+g61XqLzAqLV0us/M6LtB05HK3VAh1BLCsx+OyRxKPtrDzhRQb/Zn0k1TrUmvMJ0Orq/IvaA9hK0r7p8Qiy5LoSmBERiwAk/QXoTPIPkyVpgN2BLcr+QNeW1HmW54aIWAGskLQYGNNFHHuk5cH0eFg6/nDg6oh4NR3/2q5eRERMBaYCDFt3gzr8E7GBINqDld95idbdh9A6aQgAGtVK66QhSELvbYMWwUsB6ziZ1+MJ0IZM5pI2BTqAxWRJ/ciIuGm1fSYDK8o2lcoel3jjtbcAO0TEa6uVZ7XyHXT9fgn4n4g4d7XyR1OXH7nZm0UEq374MtpoEG2feeMkZ+uHBlOatZLWrdeg9Ld2WBUwwom8Xi8aarg+c0mjgHOAn0dEkHWLHCapLT3/HklFTrvfDBxRVv+ECvsvI2t1d7oJOFDSsFT+Hekk7XRgH0lDU0v/EwVialjf/fJtnHPsNWw45kWm/eASPrbDo+w8/imm/eAS3rfJs5xy2I2cfvjvax2mlSk9vIqOm1+nNGslrx+0lNcPWkrHfSto3WsosaiD17+0hJUnvkTbt0a4iwUgApXyLf2pUVrmQyXNBtqAduBi4EfpuV8AGwOzlP2lPQfsXaDuKcCZkuaQvR/TgUO72zkilkq6W9IjwB8i4jhJ7wXuTX/oy4HPR8QsSZcDs4G/AncWiKlhnfir3brcfudDm/RzJJZX61ZrMPSOrnoQYY1vj+jnaBpElfK0pPOBjwOLI2LLtG0kcDlZXlsIfDoiXqhYV9ThZakD1bB1N4jxux5V6zCsgPWOXljrEKyg6yadOTMiJva2/PB13hnb7Jzv/9Pp13+jx2OlwRLLgYvKkvkpwPMRcbKk44F1I+K/Kh2r4bpZzMxqKoBS5FsqVRUxnWy0XblPAhem9QvJ2dPQKN0sZmb1I3+HxvqSHih7PDWNYOvJmM5ReBGxKJ2Dq8jJ3MysoAKjWZa8lS6dIpzMzcwK6uORKs9KGpta5WPJhmBX5D5zM7MiosDSO9cCB6T1A4Df5SnklrmZWQHZRUPVaZlLugyYTNa3/neyyQNPBq6QdBDwNLBfnrqczM3MiqrSrIkR8blunur6go0eOJmbmRVUrZZ5NTmZm5kVUad3GnIyNzMrpP/nXcnDydzMrCh3s5iZNbioz9vGOZmbmRXllrmZWROov1zuZG5mVpRK9dfP4mRuZlZEULWLhqrJydz
MrAARvmjIzKwpOJmbmTUBJ3MzswbnPnMzs+bg0SxmZg0v3M1iZtbwAidzM7OmUH+9LE7mZmZFeZy5mVkzcDI3M2twEdBRf/0sTuZmZkW5ZW5m1gSczM3MGlwAvgeomVmjCwj3mZuZNbbAJ0DNzJpCFfvMJS0ElgEdQHtETOxNPU7mZmZFVf8E6C4RseStVOBkbmZWSH1OtNVS6wDMzBpKAKVSvgXWl/RA2XJwNzXeLGlmN8/n4pa5mVlR+VvmS3L0ge8UEc9IGg3cIunRiJheNCS3zM3MCkmX8+dZ8tQW8Uz6dzFwNbB9b6JyMjczKyIgopRrqUTSWpKGd64DewCP9CYsd7OYmRVVvStAxwBXS4IsH18aETf2piInczOzoqo0miUingTGV6MuJ3MzsyIiOkeq1BUnczOzoupwnLmTuZlZIUF0dNQ6iH/jZG5mVoSnwDUzaxKeAtfMrLEFEG6Zm5k1uPDNKczMmkI9ngBV1OEQm4FK0nPAX2sdRx9YH3hLczVbv2vmz2yjiBjV28KSbiR7f/JYEhF79vZYRTiZW5+T9EBv755iteHPrPF4oi0zsybgZG5m1gSczK0/TK11AFaYP7MG4z5zM7Mm4Ja5mVkTcDI3M2sCTuZNTFJIOr3s8bGSTqhS3RdI2rcadfXi2HtL2qIX5Zb3RTz9TVKHpNmS5kp6SNLXJdXk/2VJ60j6ai/KnSDp2L6IaaByMm9uK4BPScp7gUOj2BsonMybyGsRMSEi3gd8GNgL+G6NYlkHKJzMrfqczJtbO9mohK+t/oSkjSTdJmlO+nfDtP0CST+VdI+kJztb38r8XNI8STcAo8vq2lbSHZJmSrpJ0ti0/XZJZ0iaLmm+pO0kXSXpcUk/KCv/eUn3p9bmuZJa0/blkk5Krc/7JI2RtCPwH8Cpaf/N0nJjOv6dkjZP5TeRdK+kGZK+32fvcg2lO7ofDByRPqNWSaem1zxH0iEAkianz+gKSY9JOlnS/ul9f1jSZmm/UZKmpfIzJO2Utp8g6fz0mT4paUoK4WRgs/RZnJr2Pa7s+Cd2xirpvyUtkHQrMK4f36aBISK8NOkCLAfWBhYCI4BjgRPSc9cBB6T1A4Fr0voFwJVkX/RbAE+k7Z8CbgFagbcDLwL7Am3APcCotN9ngPPT+u3AD9P6UcAzwFhgMPB3YD3gvSmWtrTfWcAX03oAn0jrpwDfLotx37LXeRvw7rT+AeCPaf3asroOB5bX+jOp1ufaxbYXyG4OfHDZ+zQYeADYBJicPrPO9/8fwIlln82P0/qlwIfS+obA/LR+QvqcB5Ndyr40ffYbA4+UxbEHWQNC6W/oemASsC3wMLBm+pt8Aji21u9lMy2eaKvJRcTLki4CpgCvlT21A1mCBriYLFl2uiYiSsA8SWPStknAZRHRATwj6Y9p+zhgS+CWdIfxVmBRWV3Xpn8fBuZGxCIASU8CGwAfIvsffUYqPxRYnMqsJEsGADPJuhTeRNIwYEfgylQesoQDsBPwf8te4w9XL99EOl/8HsBWZeczRgDvJnsvZ5S9/38Bbk77PAzsktZ3B7Yoey/XljQ8rd8QESuAFZIWk315rG6PtDyYHg9Lxx8OXB0Rr6bjX9tFWXsLnMwHhh8Ds4Bf9bBP+QUHK8rW1c0+5c/PjYgduqm3s67SavWWyP7+BFwYEd/souyqSM09oIOu/15bgBcjYkI3x2/6CykkbUr2/iwmez+PjIibVttnMv/+/pd/Np3vbQuwQ0SUf/GTknt5+e4+DwH/ExHnrlb+aAbAZ1FL7jMfACLieeAK4KCyzfcAn03r+wN3VahmOvDZ1Cc7ljdacguAUZJ2AJDUJul9BcK7DdhX0uhUfqSkjSqUWUbW0iMiXgaekrRfKi9J49N+d/Pm19h0JI0CzgF+nr74bgIOk9SWnn+PpLUKVHkzcERZ/d19SXb612eR3AQcmH4xIekd6bOdDuwjaWhq6X+iQEyWg5P5wHE6b562cwrwZUlzgC+
Q9Zv25GrgcbKf5GcDdwBExEqyvvMfSnoImE3W7ZFLRMwDvg3cnGK5haxftye/AY6T9GA6cbc/cFA6/lzgk2m/o4DDJc0g625oFkPTCce5wK1kCbjzROMvgHnALEmPAOdS7Bf4FGBiOnk5Dzi0p50jYilwt6RHJJ0aETeT9bvfK+lh4LfA8IiYBVxO9vcxDbizQEyWgy/nNzNrAm6Zm5k1ASdzM7Mm4GRuZtYEnMzNzJqAk7mZWRNwMreGojdmDHxE0pWS1nwLdf1r5kdJv1APMzGmuU1yD7ksK7dQXUx01t321fYpNMujPBPhgOZkbo2mc8bALckuUX/TOGilSbqKioj/l8a8d2cyBcbPm/U3J3NrZHcC70qt5j9JuhR4uIeZA6XuZ368XdLEtL6npFnKZmu8TdLGZF8aX0u/CnbuYXbB9STdnC5oOpc3T4fQJUnXKJvxca6kg1d77vQUy23pak/UzSyRNrB5bhZrSJIGAR8Fbkybtge2jIinUkJ8KSK2kzSY7ArFm4GtySYGez/ZJFHzgPNXq3cUcB4wKdU1MiKel3QO2WyFp6X9LgXOiIi7lE0ffBPZDJDfBe6KiO9J+hjZLIaVHJiOMZRswrFp6crKtYBZEXGMpO+kuo8gm5Xw0Ih4XNIHyGaa3LUXb6M1ESdzazRDJc1O63cCvyTr/rg/Ip5K27ubObC7mR/LfRCY3llXmtemK93NLjiJNBtlRNwg6YUcr2mKpH3S+gYp1qVkE2Bdnrb/GrhKPc8SaQOYk7k1mtdWnyExJbVXyjfR9cyBe1F55j7l2Ad6nl0w9xwZymYz3D3V9aqk24Eh3eweVJ4l0gYo95lbM+pu5sDuZn4sdy/wfyRtksqOTNtXnx2wu9kFp5NmaJT0UWDdCrGOAF5IiXxzsl8GnVrIJjED+E+y7pueZom0AczJ3JpRdzMHdjnzY7mIeI6sn/uqNAtjZzfHdWRTuM6WtDPdzy54IjBJ0iyy7p6nK8R6IzBI2YyR3wfuK3vuFeB9kmaS9Yl/L23vbpZIG8A8a6KZWRNwy9zMrAk4mZuZNQEnczOzJuBkbmbWBJzMzcyagJO5mVkTcDI3M2sC/wv5xy0OHce1KwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# n_estimators(M) --> the number of trees in the forest\n", + "# max_features(d) --> the number of features to consider when looking for the best split\n", + "# max_depth(m) --> the maximum depth of the tree.\n", + "\n", + "rfc = RandomForestClassifier(random_state=42)\n", + "rfc.fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# for test there are 94 cases\n", + "plot_confusion_matrix(rfc, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:06.524845Z", + "iopub.status.busy": "2021-02-21T05:25:06.524139Z", + "iopub.status.idle": "2021-02-21T05:25:06.567152Z", + "shell.execute_reply": "2021-02-21T05:25:06.566168Z" + }, + "papermill": { + "duration": 0.119008, + "end_time": "2021-02-21T05:25:06.567316", + "exception": false, + "start_time": "2021-02-21T05:25:06.448308", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy 1.0\n", + "Test accuracy 0.8133333333333334\n", + "Test recall 0.7027027027027027\n", + "Test AUC 0.811877667140825\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "train_score = rfc.score(X_trainval_scaled, y_trainval)\n", + "test_score = rfc.score(X_test_scaled, y_test)\n", + "y_predict = rfc.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "fpr, tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(fpr, tpr)\n", + "\n", + "print(\"Train accuracy \", train_score)\n", + "print(\"Test accuracy \", test_score)\n", + "print(\"Test recall\", test_recall)\n", + "print(\"Test AUC\", test_auc)" + ] + }, + { + 
"cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.066938, + "end_time": "2021-02-21T05:25:06.702778", + "exception": false, + "start_time": "2021-02-21T05:25:06.635840", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Optimize parameters(Finetuning) --> GridSearchCV()" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:06.846366Z", + "iopub.status.busy": "2021-02-21T05:25:06.845356Z", + "iopub.status.idle": "2021-02-21T05:25:06.848964Z", + "shell.execute_reply": "2021-02-21T05:25:06.848404Z" + }, + "papermill": { + "duration": 0.078921, + "end_time": "2021-02-21T05:25:06.849108", + "exception": false, + "start_time": "2021-02-21T05:25:06.770187", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Number of trees in random forest\n", + "n_estimators = [int(x) for x in np.linspace(start = 10, stop = 100, num = 10)]\n", + "\n", + "# Number of features to consider at every split\n", + "max_features = ['auto', 'sqrt', 'log2']\n", + "\n", + "# Maximum number of levels in tree\n", + "max_depth = range(1,10)\n", + "\n", + "# measure the quality of a split\n", + "criterion = ['gini']\n", + "\n", + "# Method of selecting samples for training each tree\n", + "bootstrap = [True, False]" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:06.987850Z", + "iopub.status.busy": "2021-02-21T05:25:06.987210Z", + "iopub.status.idle": "2021-02-21T05:25:06.992272Z", + "shell.execute_reply": "2021-02-21T05:25:06.991634Z" + }, + "papermill": { + "duration": 0.07594, + "end_time": "2021-02-21T05:25:06.992384", + "exception": false, + "start_time": "2021-02-21T05:25:06.916444", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Create the param grid\n", + "param_grid = {'n_estimators': n_estimators,\n", + " 
'max_features': max_features,\n", + " 'max_depth': max_depth,\n", + " 'criterion': criterion,\n", + " 'bootstrap': bootstrap}" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:07.136382Z", + "iopub.status.busy": "2021-02-21T05:25:07.135353Z", + "iopub.status.idle": "2021-02-21T05:26:57.959858Z", + "shell.execute_reply": "2021-02-21T05:26:57.960433Z" + }, + "papermill": { + "duration": 110.901375, + "end_time": "2021-02-21T05:26:57.960599", + "exception": false, + "start_time": "2021-02-21T05:25:07.059224", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'bootstrap': False, 'criterion': 'gini', 'max_depth': 8, 'max_features': 'auto', 'n_estimators': 60}\n" + ] + } + ], + "source": [ + "optimal_params = GridSearchCV(RandomForestClassifier(),\n", + " param_grid,\n", + " cv=5, # we are taking 5-fold as in k-fold cross validation\n", + " scoring='accuracy', # try the other scoring if have time\n", + " verbose=0,\n", + " n_jobs=-1)\n", + "\n", + "optimal_params.fit(X_trainval_scaled, y_trainval)\n", + "print(optimal_params.best_params_)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:58.102512Z", + "iopub.status.busy": "2021-02-21T05:26:58.101661Z", + "iopub.status.idle": "2021-02-21T05:26:58.104540Z", + "shell.execute_reply": "2021-02-21T05:26:58.104022Z" + }, + "papermill": { + "duration": 0.076703, + "end_time": "2021-02-21T05:26:58.104654", + "exception": false, + "start_time": "2021-02-21T05:26:58.027951", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "bootstrap = optimal_params.best_params_['bootstrap']\n", + "criterion = optimal_params.best_params_['criterion']\n", + "max_depth = optimal_params.best_params_['max_depth']\n", + "max_features = 
optimal_params.best_params_['max_features']\n", + "n_estimators = optimal_params.best_params_['n_estimators']" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:58.264473Z", + "iopub.status.busy": "2021-02-21T05:26:58.263365Z", + "iopub.status.idle": "2021-02-21T05:26:58.573331Z", + "shell.execute_reply": "2021-02-21T05:26:58.572415Z" + }, + "papermill": { + "duration": 0.401085, + "end_time": "2021-02-21T05:26:58.573516", + "exception": false, + "start_time": "2021-02-21T05:26:58.172431", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "rfc = RandomForestClassifier(n_estimators=n_estimators, \n", + " max_features=max_features, \n", + " max_depth=max_depth, \n", + " criterion=criterion,\n", + " bootstrap=bootstrap,\n", + " random_state=42)\n", + "\n", + "rfc.fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# the test set has 75 cases (the recorded test accuracies are multiples of 1/75)\n", + "plot_confusion_matrix(rfc, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:58.733907Z", + "iopub.status.busy": "2021-02-21T05:26:58.732790Z", + "iopub.status.idle": "2021-02-21T05:26:58.764470Z", + "shell.execute_reply": "2021-02-21T05:26:58.765236Z" + }, + "papermill": { + "duration": 0.111182, + "end_time": "2021-02-21T05:26:58.765445", + "exception": false, + "start_time": "2021-02-21T05:26:58.654263", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy 0.9966442953020134\n", + "Test accuracy 0.7866666666666666\n", + "Test recall 0.7027027027027027\n", + "Test AUC 0.7855618776671409\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "train_score = rfc.score(X_trainval_scaled, y_trainval)\n", + "test_score = rfc.score(X_test_scaled, y_test)\n", + "y_predict = rfc.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "rfc_fpr, rfc_tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(rfc_fpr, rfc_tpr)\n", + "\n", + "print(\"Train accuracy \", train_score)\n", + "print(\"Test accuracy \", test_score)\n", + "print(\"Test recall\", test_recall)\n", + "print(\"Test AUC\", test_auc)" + ] + }, + { + "cell_type": "markdown", +
"metadata": { + "papermill": { + "duration": 0.071036, + "end_time": "2021-02-21T05:26:58.907460", + "exception": false, + "start_time": "2021-02-21T05:26:58.836424", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Logistic Regression" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:59.053932Z", + "iopub.status.busy": "2021-02-21T05:26:59.052923Z", + "iopub.status.idle": "2021-02-21T05:26:59.056562Z", + "shell.execute_reply": "2021-02-21T05:26:59.055881Z" + }, + "papermill": { + "duration": 0.078846, + "end_time": "2021-02-21T05:26:59.056676", + "exception": false, + "start_time": "2021-02-21T05:26:58.977830", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.tree import DecisionTreeClassifier\n", + "from sklearn.metrics import confusion_matrix, precision_score, accuracy_score, recall_score, roc_curve, auc" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:59.210264Z", + "iopub.status.busy": "2021-02-21T05:26:59.209530Z", + "iopub.status.idle": "2021-02-21T05:26:59.392907Z", + "shell.execute_reply": "2021-02-21T05:26:59.392347Z" + }, + "papermill": { + "duration": 0.263735, + "end_time": "2021-02-21T05:26:59.393060", + "exception": false, + "start_time": "2021-02-21T05:26:59.129325", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "log_reg_model = LogisticRegression().fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# for test there are 94 cases\n", + "plot_confusion_matrix(log_reg_model, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:59.553251Z", + "iopub.status.busy": "2021-02-21T05:26:59.545050Z", + "iopub.status.idle": "2021-02-21T05:26:59.562200Z", + "shell.execute_reply": "2021-02-21T05:26:59.562678Z" + }, + "papermill": { + "duration": 0.098292, + "end_time": "2021-02-21T05:26:59.562830", + "exception": false, + "start_time": "2021-02-21T05:26:59.464538", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy 0.8221476510067114\n", + "Test accuracy 0.7466666666666667\n", + "Test recall 0.7027027027027027\n", + "Test AUC 0.7460881934566145\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "log_reg_model = LogisticRegression().fit(X_trainval_scaled, y_trainval)\n", + "train_score = log_reg_model.score(X_trainval_scaled, y_trainval)\n", + "test_score = log_reg_model.score(X_test_scaled, y_test)\n", + "scores = log_reg_model.score(X_test_scaled, y_test)\n", + "y_predict = log_reg_model.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "fpr, tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(fpr, tpr)\n", + "\n", + "\n", + "print(\"Train accuracy \", train_score)\n", + "print(\"Test accuracy \", test_score)\n", + "print(\"Test recall\", test_recall)\n", + "print(\"Test AUC\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + 
"papermill": { + "duration": 0.072092, + "end_time": "2021-02-21T05:26:59.707748", + "exception": false, + "start_time": "2021-02-21T05:26:59.635656", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Optimize parameters(Finetuning) --> GridSearchCV()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:59.859870Z", + "iopub.status.busy": "2021-02-21T05:26:59.859194Z", + "iopub.status.idle": "2021-02-21T05:27:00.087736Z", + "shell.execute_reply": "2021-02-21T05:27:00.088576Z" + }, + "papermill": { + "duration": 0.308978, + "end_time": "2021-02-21T05:27:00.088785", + "exception": false, + "start_time": "2021-02-21T05:26:59.779807", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 2, 'penalty': 'l2'}\n" + ] + } + ], + "source": [ + "param_grid = {'penalty': ['l1','l2'], \n", + " 'C': [0.001,0.01,0.1,1, 2, 3, 5, 10,100,1000]}\n", + "\n", + "optimal_params = GridSearchCV(LogisticRegression(),\n", + " param_grid,\n", + " cv=5, # we are taking 5-fold as in k-fold cross validation\n", + " scoring='accuracy', # try the other scoring if have time\n", + " verbose=0,\n", + " n_jobs=-1)\n", + "\n", + "optimal_params.fit(X_trainval_scaled, y_trainval)\n", + "print(optimal_params.best_params_)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:00.240392Z", + "iopub.status.busy": "2021-02-21T05:27:00.239384Z", + "iopub.status.idle": "2021-02-21T05:27:00.243000Z", + "shell.execute_reply": "2021-02-21T05:27:00.242454Z" + }, + "papermill": { + "duration": 0.080874, + "end_time": "2021-02-21T05:27:00.243134", + "exception": false, + "start_time": "2021-02-21T05:27:00.162260", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# best_score = -10\n", + "# for c in range(1, 20): \n", + 
"# log_reg_model = LogisticRegression(C=c)\n", + "# scores = cross_val_score(log_reg_model, X_trainval_scaled, y_trainval, cv=5, scoring='accuracy')\n", + " \n", + "# mean_score = scores.mean()\n", + " \n", + "# if mean_score > best_score:\n", + "# best_score = mean_score\n", + "# best_c = c\n", + "# print(best_c)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:00.393338Z", + "iopub.status.busy": "2021-02-21T05:27:00.392646Z", + "iopub.status.idle": "2021-02-21T05:27:00.394698Z", + "shell.execute_reply": "2021-02-21T05:27:00.395162Z" + }, + "papermill": { + "duration": 0.079368, + "end_time": "2021-02-21T05:27:00.395308", + "exception": false, + "start_time": "2021-02-21T05:27:00.315940", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "best_C = optimal_params.best_params_['C']\n", + "best_penalty = optimal_params.best_params_['penalty']" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:00.543733Z", + "iopub.status.busy": "2021-02-21T05:27:00.543072Z", + "iopub.status.idle": "2021-02-21T05:27:00.726618Z", + "shell.execute_reply": "2021-02-21T05:27:00.725951Z" + }, + "papermill": { + "duration": 0.258901, + "end_time": "2021-02-21T05:27:00.726760", + "exception": false, + "start_time": "2021-02-21T05:27:00.467859", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "log_reg_model = LogisticRegression(C=best_C, penalty=best_penalty).fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# for test there are 94 cases\n", + "plot_confusion_matrix(log_reg_model, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:00.902586Z", + "iopub.status.busy": "2021-02-21T05:27:00.887719Z", + "iopub.status.idle": "2021-02-21T05:27:00.911752Z", + "shell.execute_reply": "2021-02-21T05:27:00.912556Z" + }, + "papermill": { + "duration": 0.108288, + "end_time": "2021-02-21T05:27:00.912736", + "exception": false, + "start_time": "2021-02-21T05:27:00.804448", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy with Logistec regression: 0.8288590604026845\n", + "Test accuracy with Logistec regression: 0.7466666666666667\n", + "Test recall with Logistec regression: 0.7027027027027027\n", + "Test AUC with Logistec regression: 0.7460881934566145\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "best_log_reg_model = LogisticRegression(C=best_C, penalty=best_penalty).fit(X_trainval_scaled, y_trainval)\n", + "train_score = best_log_reg_model.score(X_trainval_scaled, y_trainval)\n", + "test_score = best_log_reg_model.score(X_test_scaled, y_test)\n", + "y_predict = best_log_reg_model.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "lgr_fpr, lgr_tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(lgr_fpr, lgr_tpr)\n", + "\n", + "print(\"Train accuracy with Logistec regression:\", train_score)\n", + "print(\"Test accuracy with 
Logistec regression:\", test_score)\n", + "print(\"Test recall with Logistec regression:\", test_recall)\n", + "print(\"Test AUC with Logistec regression:\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.074221, + "end_time": "2021-02-21T05:27:01.064257", + "exception": false, + "start_time": "2021-02-21T05:27:00.990036", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Decision Tree" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:01.218547Z", + "iopub.status.busy": "2021-02-21T05:27:01.217821Z", + "iopub.status.idle": "2021-02-21T05:27:01.391666Z", + "shell.execute_reply": "2021-02-21T05:27:01.391137Z" + }, + "papermill": { + "duration": 0.253607, + "end_time": "2021-02-21T05:27:01.391785", + "exception": false, + "start_time": "2021-02-21T05:27:01.138178", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "dt_model = DecisionTreeClassifier().fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# the test set has 75 cases (the recorded test accuracies are multiples of 1/75)\n", + "plot_confusion_matrix(dt_model, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:01.552124Z", + "iopub.status.busy": "2021-02-21T05:27:01.551400Z", + "iopub.status.idle": "2021-02-21T05:27:01.561205Z", + "shell.execute_reply": "2021-02-21T05:27:01.560544Z" + }, + "papermill": { + "duration": 0.094559, + "end_time": "2021-02-21T05:27:01.561320", + "exception": false, + "start_time": "2021-02-21T05:27:01.466761", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy with DecisionTreeClassifier: 1.0\n", + "Test accuracy with DecisionTreeClassifier: 0.72\n", + "Test recall with DecisionTreeClassifier: 0.6756756756756757\n", + "Test AUC with DecisionTreeClassifier: 0.719416785206259\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "dt_model = DecisionTreeClassifier().fit(X_trainval_scaled, y_trainval)\n", + "train_score = dt_model.score(X_trainval_scaled, y_trainval)\n", + "test_score = dt_model.score(X_test_scaled, y_test)\n", + "y_predict = dt_model.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "fpr, tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(fpr, tpr)\n", + "\n", + "print(\"Train accuracy with DecisionTreeClassifier:\", train_score)\n", + "print(\"Test accuracy with DecisionTreeClassifier:\", test_score)\n", + "print(\"Test recall with DecisionTreeClassifier:\", test_recall)\n", + "print(\"Test
AUC with DecisionTreeClassifier:\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.075289, + "end_time": "2021-02-21T05:27:01.712002", + "exception": false, + "start_time": "2021-02-21T05:27:01.636713", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Optimize parameters(Finetuning) --> GridSearchCV()" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:01.872535Z", + "iopub.status.busy": "2021-02-21T05:27:01.871460Z", + "iopub.status.idle": "2021-02-21T05:27:01.957853Z", + "shell.execute_reply": "2021-02-21T05:27:01.957187Z" + }, + "papermill": { + "duration": 0.169737, + "end_time": "2021-02-21T05:27:01.957989", + "exception": false, + "start_time": "2021-02-21T05:27:01.788252", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'criterion': 'gini', 'max_depth': 2}\n" + ] + } + ], + "source": [ + "param_grid = {'criterion': ['gini'], \n", + " 'max_depth': range(1,10)}\n", + "\n", + "optimal_params = GridSearchCV(DecisionTreeClassifier(),\n", + " param_grid,\n", + " cv=5, # we are taking 5-fold as in k-fold cross validation\n", + " scoring='accuracy', # try the other scoring if have time\n", + " verbose=0,\n", + " n_jobs=-1)\n", + "\n", + "optimal_params.fit(X_trainval_scaled, y_trainval)\n", + "print(optimal_params.best_params_)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:02.116441Z", + "iopub.status.busy": "2021-02-21T05:27:02.115645Z", + "iopub.status.idle": "2021-02-21T05:27:02.118675Z", + "shell.execute_reply": "2021-02-21T05:27:02.118177Z" + }, + "papermill": { + "duration": 0.085294, + "end_time": "2021-02-21T05:27:02.118786", + "exception": false, + "start_time": "2021-02-21T05:27:02.033492", + "status": "completed" + }, + "tags": [] 
+ }, + "outputs": [], + "source": [ + "criterion = optimal_params.best_params_['criterion']\n", + "max_depth = optimal_params.best_params_['max_depth']" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:02.276080Z", + "iopub.status.busy": "2021-02-21T05:27:02.275361Z", + "iopub.status.idle": "2021-02-21T05:27:02.278765Z", + "shell.execute_reply": "2021-02-21T05:27:02.278075Z" + }, + "papermill": { + "duration": 0.08369, + "end_time": "2021-02-21T05:27:02.278885", + "exception": false, + "start_time": "2021-02-21T05:27:02.195195", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# best_score = -1\n", + "# for d in range(1, 25): \n", + "# dt_model = DecisionTreeClassifier(max_depth = d)\n", + "# scores = cross_val_score(dt_model, X_trainval_scaled, y_trainval, cv=5, scoring='accuracy')\n", + " \n", + "# mean_score = scores.mean()\n", + " \n", + "# if mean_score > best_score:\n", + "# best_score = mean_score\n", + "# best_d = d\n", + "# print(best_d)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:02.439628Z", + "iopub.status.busy": "2021-02-21T05:27:02.438933Z", + "iopub.status.idle": "2021-02-21T05:27:02.607174Z", + "shell.execute_reply": "2021-02-21T05:27:02.606465Z" + }, + "papermill": { + "duration": 0.251238, + "end_time": "2021-02-21T05:27:02.607295", + "exception": false, + "start_time": "2021-02-21T05:27:02.356057", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "dt_model = DecisionTreeClassifier(criterion=criterion, max_depth=max_depth).fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# the test set has 75 cases (the recorded test accuracies are multiples of 1/75)\n", + "plot_confusion_matrix(dt_model, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:02.771744Z", + "iopub.status.busy": "2021-02-21T05:27:02.771062Z", + "iopub.status.idle": "2021-02-21T05:27:02.780168Z", + "shell.execute_reply": "2021-02-21T05:27:02.780646Z" + }, + "papermill": { + "duration": 0.096343, + "end_time": "2021-02-21T05:27:02.780797", + "exception": false, + "start_time": "2021-02-21T05:27:02.684454", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy with DecisionTreeClassifier: 0.7751677852348994\n", + "Test accuracy with DecisionTreeClassifier: 0.8\n", + "Test recall with DecisionTreeClassifier: 0.5945945945945946\n", + "Test AUC with DecisionTreeClassifier: 0.7972972972972974\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "dt_model = DecisionTreeClassifier(criterion=criterion, max_depth=max_depth).fit(X_trainval_scaled, y_trainval)\n", + "train_score = dt_model.score(X_trainval_scaled, y_trainval)\n", + "test_score = dt_model.score(X_test_scaled, y_test)\n", + "y_predict = dt_model.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "dt_fpr, dt_tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(dt_fpr, dt_tpr)\n", + "\n", + "print(\"Train accuracy with DecisionTreeClassifier:\", train_score)\n", + "print(\"Test accuracy with
DecisionTreeClassifier:\", test_score)\n", + "print(\"Test recall with DecisionTreeClassifier:\", test_recall)\n", + "print(\"Test AUC with DecisionTreeClassifier:\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.079404, + "end_time": "2021-02-21T05:27:02.938142", + "exception": false, + "start_time": "2021-02-21T05:27:02.858738", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Plot ROC and compare AUC" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:03.110051Z", + "iopub.status.busy": "2021-02-21T05:27:03.109234Z", + "iopub.status.idle": "2021-02-21T05:27:03.296587Z", + "shell.execute_reply": "2021-02-21T05:27:03.295893Z" + }, + "papermill": { + "duration": 0.279336, + "end_time": "2021-02-21T05:27:03.296703", + "exception": false, + "start_time": "2021-02-21T05:27:03.017367", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(5, 5), dpi=100)\n", + "plt.plot(svm_fpr, svm_tpr, linestyle='-', label='SVM')\n", + "plt.plot(lgr_fpr, lgr_tpr, marker='.', label='Logistic')\n", + "plt.plot(rfc_fpr, rfc_tpr, linestyle=':', label='Random Forest')\n", + "plt.plot(dt_fpr, dt_tpr, linestyle='-.', label='Decision Tree')\n", + "\n", + "plt.xlabel('False Positive Rate')\n", + "plt.ylabel('True Positive Rate')\n", + "\n", + "plt.legend()\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + }, + "papermill": { + "duration": 147.570544, + "end_time": "2021-02-21T05:27:03.491568", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2021-02-21T05:24:35.921024", + "version": "2.1.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/models/Dementia Prediction Model/predict.py b/models/Dementia Prediction Model/predict.py new file mode 100644 index 00000000..95721d7b --- /dev/null +++ b/models/Dementia Prediction Model/predict.py @@ -0,0 +1,19 @@ +import joblib +import pandas as pd + +class CarPricePredictor: + def __init__(self, model_path): + self.model = joblib.load(model_path) + + def predict(self, input_data): + return self.model.predict(input_data) + +if __name__ == "__main__": + predictor = CarPricePredictor('saved_models/car_price_model.pkl') + + # Example input data, replace with actual data + # Ensure the input data has the same feature columns as used in training + input_data = 
pd.DataFrame([[2015, 'Toyota', 'Corolla', 50000, 'Petrol']], + columns=['year', 'company', 'model', 'kms_driven', 'fuel_type']) # Adjust columns accordingly + predictions = predictor.predict(input_data) + print("Predicted Price: ₹", predictions[0]) \ No newline at end of file