From 063ccc3a74afaae7b5d92a80ae6bebd5e6ebbc78 Mon Sep 17 00:00:00 2001 From: Simran Shaikh Date: Fri, 8 Nov 2024 22:34:59 +0530 Subject: [PATCH] added new project #208 --- models/Dementia Prediction Model/Readme.md | 44 + ...-prediction-using-different-ml-model.ipynb | 3276 +++++++++++++++++ 2 files changed, 3320 insertions(+) create mode 100644 models/Dementia Prediction Model/Readme.md create mode 100644 models/Dementia Prediction Model/dementia-prediction-using-different-ml-model.ipynb diff --git a/models/Dementia Prediction Model/Readme.md b/models/Dementia Prediction Model/Readme.md new file mode 100644 index 00000000..d646eca1 --- /dev/null +++ b/models/Dementia Prediction Model/Readme.md @@ -0,0 +1,44 @@ +# Dementia Prediction Model + +This repository contains a machine learning model that predicts dementia in patients based on medical and demographic data. The model utilizes various supervised learning techniques to analyze historical patient data and provides a probability score indicating the likelihood of dementia diagnosis. + +## Table of Contents +- [Introduction](#introduction) +- [Problem Statement](#problem-statement) +- [Solution Overview](#solution-overview) +- [Data](#data) + + +## Introduction + +Dementia is a chronic condition that affects cognitive function, leading to memory loss and impaired reasoning. Early prediction of dementia can facilitate timely intervention and improve patient care. This project uses machine learning models to predict dementia based on patient data, such as age, cognitive test results, and medical history. + +## Problem Statement + +Dementia is challenging to predict due to its complex and progressive nature. Several factors contribute to the risk of dementia, including: +- **Demographic Information**: Age, gender, education level. +- **Medical History**: Family history of dementia, comorbid conditions. +- **Cognitive Test Scores**: Results from standardized assessments. 
+ +The primary challenge is to develop a model that accurately predicts dementia by identifying patterns in patient data. + +## Solution Overview + +The model is trained using various machine learning algorithms to determine the most effective approach for dementia prediction. Algorithms include: +- **Logistic Regression** +- **Decision Trees** +- **Random Forests** +- **Support Vector Machines (SVM)** + +The model uses supervised learning and is trained on a dataset of patient records containing features such as age, cognitive test scores, and medical history. + +## Data + +The dataset includes: +- **Patient Demographics**: Age, gender, education level. +- **Cognitive Scores**: Scores from memory and cognitive assessments. +- **Medical History**: Previous diagnoses, family history, and comorbid conditions. + +Data should be placed in the `data/` folder in CSV format for model training and evaluation. + + diff --git a/models/Dementia Prediction Model/dementia-prediction-using-different-ml-model.ipynb b/models/Dementia Prediction Model/dementia-prediction-using-different-ml-model.ipynb new file mode 100644 index 00000000..990a4df2 --- /dev/null +++ b/models/Dementia Prediction Model/dementia-prediction-using-different-ml-model.ipynb @@ -0,0 +1,3276 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.045864, + "end_time": "2021-02-21T05:24:40.532878", + "exception": false, + "start_time": "2021-02-21T05:24:40.487014", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Importing Libs" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2021-02-21T05:24:40.625317Z", + "iopub.status.busy": "2021-02-21T05:24:40.624657Z", + "iopub.status.idle": "2021-02-21T05:24:41.846215Z", + "shell.execute_reply": "2021-02-21T05:24:41.846810Z" + }, + 
"papermill": { + "duration": 1.269347, + "end_time": "2021-02-21T05:24:41.847025", + "exception": false, + "start_time": "2021-02-21T05:24:40.577678", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import pandas as pd # used to load, manipulate the data and for one-hot encoding\n", + "import numpy as np # data manipulation\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.colors as colors\n", + "from sklearn.utils import resample # for downsample the dataset\n", + "from sklearn.model_selection import train_test_split # for splitting the dataset into train and test split\n", + "from sklearn.preprocessing import scale # scale and center the data\n", + "from sklearn.svm import SVC # will make a SVM for classification\n", + "from sklearn.model_selection import GridSearchCV # will do the cross validation\n", + "from sklearn.metrics import plot_confusion_matrix # will draw the confusion matrix\n", + "from sklearn.decomposition import PCA # to perform PCA to plot the data\n", + "from sklearn.impute import SimpleImputer\n", + "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", + "from sklearn.model_selection import cross_val_score\n", + "from sklearn.metrics import confusion_matrix, precision_score, accuracy_score, recall_score, roc_curve, auc\n", + "import seaborn as sns" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.044613, + "end_time": "2021-02-21T05:24:41.936713", + "exception": false, + "start_time": "2021-02-21T05:24:41.892100", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Load the data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0", + "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a", + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.029211Z", + "iopub.status.busy": "2021-02-21T05:24:42.028542Z", + 
"iopub.status.idle": "2021-02-21T05:24:42.058687Z", + "shell.execute_reply": "2021-02-21T05:24:42.059215Z" + }, + "papermill": { + "duration": 0.077355, + "end_time": "2021-02-21T05:24:42.059378", + "exception": false, + "start_time": "2021-02-21T05:24:41.982023", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "data = pd.read_csv(\"../input/mri-and-alzheimers/oasis_longitudinal.csv\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.04491, + "end_time": "2021-02-21T05:24:42.149525", + "exception": false, + "start_time": "2021-02-21T05:24:42.104615", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Explore the data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.244342Z", + "iopub.status.busy": "2021-02-21T05:24:42.243655Z", + "iopub.status.idle": "2021-02-21T05:24:42.249012Z", + "shell.execute_reply": "2021-02-21T05:24:42.248425Z" + }, + "papermill": { + "duration": 0.053142, + "end_time": "2021-02-21T05:24:42.249140", + "exception": false, + "start_time": "2021-02-21T05:24:42.195998", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "pd.set_option('display.max_columns', None) # will show the all columns with pandas dataframe\n", + "pd.set_option('display.max_rows', None) # will show the all rows with pandas dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.359739Z", + "iopub.status.busy": "2021-02-21T05:24:42.359068Z", + "iopub.status.idle": "2021-02-21T05:24:42.371999Z", + "shell.execute_reply": "2021-02-21T05:24:42.372499Z" + }, + "papermill": { + "duration": 0.078228, + "end_time": "2021-02-21T05:24:42.372636", + "exception": false, + "start_time": "2021-02-21T05:24:42.294408", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": 
{ + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Subject IDMRI IDGroupVisitMR DelayM/FHandAgeEDUCSESMMSECDReTIVnWBVASF
0OAS2_0001OAS2_0001_MR1Nondemented10MR87142.027.00.019870.6960.883
1OAS2_0001OAS2_0001_MR2Nondemented2457MR88142.030.00.020040.6810.876
2OAS2_0002OAS2_0002_MR1Demented10MR7512NaN23.00.516780.7361.046
3OAS2_0002OAS2_0002_MR2Demented2560MR7612NaN28.00.517380.7131.010
4OAS2_0002OAS2_0002_MR3Demented31895MR8012NaN22.00.516980.7011.034
\n", + "
" + ], + "text/plain": [ + " Subject ID MRI ID Group Visit MR Delay M/F Hand Age EDUC \\\n", + "0 OAS2_0001 OAS2_0001_MR1 Nondemented 1 0 M R 87 14 \n", + "1 OAS2_0001 OAS2_0001_MR2 Nondemented 2 457 M R 88 14 \n", + "2 OAS2_0002 OAS2_0002_MR1 Demented 1 0 M R 75 12 \n", + "3 OAS2_0002 OAS2_0002_MR2 Demented 2 560 M R 76 12 \n", + "4 OAS2_0002 OAS2_0002_MR3 Demented 3 1895 M R 80 12 \n", + "\n", + " SES MMSE CDR eTIV nWBV ASF \n", + "0 2.0 27.0 0.0 1987 0.696 0.883 \n", + "1 2.0 30.0 0.0 2004 0.681 0.876 \n", + "2 NaN 23.0 0.5 1678 0.736 1.046 \n", + "3 NaN 28.0 0.5 1738 0.713 1.010 \n", + "4 NaN 22.0 0.5 1698 0.701 1.034 " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.head()\n", + "# data.tail()\n", + "# data.size" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.468495Z", + "iopub.status.busy": "2021-02-21T05:24:42.467769Z", + "iopub.status.idle": "2021-02-21T05:24:42.473077Z", + "shell.execute_reply": "2021-02-21T05:24:42.472435Z" + }, + "papermill": { + "duration": 0.054333, + "end_time": "2021-02-21T05:24:42.473216", + "exception": false, + "start_time": "2021-02-21T05:24:42.418883", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(373, 15)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.571285Z", + "iopub.status.busy": "2021-02-21T05:24:42.570528Z", + "iopub.status.idle": "2021-02-21T05:24:42.586722Z", + "shell.execute_reply": "2021-02-21T05:24:42.585742Z" + }, + "papermill": { + "duration": 0.067194, + "end_time": "2021-02-21T05:24:42.586897", + "exception": false, + "start_time": "2021-02-21T05:24:42.519703", + "status": "completed" + }, + 
"tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 373 entries, 0 to 372\n", + "Data columns (total 15 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 Subject ID 373 non-null object \n", + " 1 MRI ID 373 non-null object \n", + " 2 Group 373 non-null object \n", + " 3 Visit 373 non-null int64 \n", + " 4 MR Delay 373 non-null int64 \n", + " 5 M/F 373 non-null object \n", + " 6 Hand 373 non-null object \n", + " 7 Age 373 non-null int64 \n", + " 8 EDUC 373 non-null int64 \n", + " 9 SES 354 non-null float64\n", + " 10 MMSE 371 non-null float64\n", + " 11 CDR 373 non-null float64\n", + " 12 eTIV 373 non-null int64 \n", + " 13 nWBV 373 non-null float64\n", + " 14 ASF 373 non-null float64\n", + "dtypes: float64(5), int64(5), object(5)\n", + "memory usage: 43.8+ KB\n" + ] + } + ], + "source": [ + "data.info()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.046597, + "end_time": "2021-02-21T05:24:42.681879", + "exception": false, + "start_time": "2021-02-21T05:24:42.635282", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Converting Categorical Data to Numerical Data" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.046699, + "end_time": "2021-02-21T05:24:42.776013", + "exception": false, + "start_time": "2021-02-21T05:24:42.729314", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "When **inplace = True** , the data is modified in place, which means it will return nothing and the dataframe is now updated. \n", + "When **inplace = False** , which is the *default*, then the operation is performed and it returns a copy of the object. You then need to save it to something." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.046777, + "end_time": "2021-02-21T05:24:42.872082", + "exception": false, + "start_time": "2021-02-21T05:24:42.825305", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "set axis=0 for rows or, just put axis='rows' to access the rows\n", + "\n", + "set axis=1 for columns or, just put axis='columns' to access the columns" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:42.982688Z", + "iopub.status.busy": "2021-02-21T05:24:42.981824Z", + "iopub.status.idle": "2021-02-21T05:24:42.985482Z", + "shell.execute_reply": "2021-02-21T05:24:42.986097Z" + }, + "papermill": { + "duration": 0.067422, + "end_time": "2021-02-21T05:24:42.986242", + "exception": false, + "start_time": "2021-02-21T05:24:42.918820", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 373 entries, 0 to 372\n", + "Data columns (total 15 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 Subject ID 373 non-null object \n", + " 1 MRI ID 373 non-null object \n", + " 2 Group 373 non-null int64 \n", + " 3 Visit 373 non-null int64 \n", + " 4 MR Delay 373 non-null int64 \n", + " 5 M/F 373 non-null int64 \n", + " 6 Hand 373 non-null object \n", + " 7 Age 373 non-null int64 \n", + " 8 EDUC 373 non-null int64 \n", + " 9 SES 354 non-null float64\n", + " 10 MMSE 371 non-null float64\n", + " 11 CDR 373 non-null float64\n", + " 12 eTIV 373 non-null int64 \n", + " 13 nWBV 373 non-null float64\n", + " 14 ASF 373 non-null float64\n", + "dtypes: float64(5), int64(7), object(3)\n", + "memory usage: 43.8+ KB\n" + ] + } + ], + "source": [ + "data['M/F'] = [1 if each == \"M\" else 0 for each in data['M/F']]\n", + "data['Group'] = [1 if each == \"Demented\" or each == \"Converted\" else 0 for each 
in data['Group']]\n", + "# data['Group'] = data['Group'].replace(['Converted'], ['Demented']) # Target variable\n", + "# data['Group'] = data['Group'].replace(['Demented', 'Nondemented'], [1,0]) # Target variable\n", + "data.info()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.046981, + "end_time": "2021-02-21T05:24:43.080526", + "exception": false, + "start_time": "2021-02-21T05:24:43.033545", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "Note: In this dataset, **CDR** (Clinical Dementia Rating) describes the patient's condition, i.e., whether the patient has dementia and how severe it is.\n", + "\n", + "CDR Value Meaning:\n", + "\n", + "* 0 ---> Normal\n", + "* 0.5 ---> Very Mild Dementia\n", + "* 1 ---> Mild Dementia\n", + "* 2 ---> Moderate Dementia\n", + "* 3 ---> Severe Dementia" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.047171, + "end_time": "2021-02-21T05:24:43.175738", + "exception": false, + "start_time": "2021-02-21T05:24:43.128567", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Correlation Between Attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:43.274629Z", + "iopub.status.busy": "2021-02-21T05:24:43.273593Z", + "iopub.status.idle": "2021-02-21T05:24:43.287746Z", + "shell.execute_reply": "2021-02-21T05:24:43.287173Z" + }, + "papermill": { + "duration": 0.064432, + "end_time": "2021-02-21T05:24:43.287873", + "exception": false, + "start_time": "2021-02-21T05:24:43.223441", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Group 1.000000\n", + "CDR 0.778049\n", + "M/F 0.222146\n", + "SES 0.062463\n", + "ASF 0.032495\n", + "Age -0.005941\n", + "eTIV -0.042700\n", + "Visit -0.095507\n", + "MR Delay -0.120638\n", + "EDUC -0.193060\n", + "nWBV -0.311346\n", + "MMSE -0.524775\n", + "Name: Group, 
dtype: float64" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "correlation_matrix = data.corr()\n", + "data_corr = correlation_matrix['Group'].sort_values(ascending=False)\n", + "data_corr" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:43.387454Z", + "iopub.status.busy": "2021-02-21T05:24:43.386736Z", + "iopub.status.idle": "2021-02-21T05:24:45.327091Z", + "shell.execute_reply": "2021-02-21T05:24:45.326421Z" + }, + "papermill": { + "duration": 1.99135, + "end_time": "2021-02-21T05:24:45.327222", + "exception": false, + "start_time": "2021-02-21T05:24:43.335872", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[,\n", + " ,\n", + " ,\n", + " ,\n", + " ],\n", + " [,\n", + " ,\n", + " ,\n", + " ,\n", + " ],\n", + " [,\n", + " ,\n", + " ,\n", + " ,\n", + " ],\n", + " [,\n", + " ,\n", + " ,\n", + " ,\n", + " ],\n", + " [,\n", + " ,\n", + " ,\n", + " ,\n", + " ]],\n", + " dtype=object)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "from pandas.plotting import scatter_matrix\n", + "\n", + "attributes = [\"Group\", \"CDR\", \"M/F\", \"SES\", \"ASF\"]\n", + "\n", + "scatter_matrix(data[attributes], figsize=(15, 11), alpha=0.3)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:45.433703Z", + "iopub.status.busy": "2021-02-21T05:24:45.433017Z", + "iopub.status.idle": "2021-02-21T05:24:47.540106Z", + "shell.execute_reply": "2021-02-21T05:24:47.540628Z" + }, + "papermill": { + "duration": 2.163236, + "end_time": "2021-02-21T05:24:47.540769", + "exception": false, + "start_time": "2021-02-21T05:24:45.377533", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import plotly.express as px\n", + "\n", + "fig = px.scatter(data, x='Group', y='SES', color='Group')\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:47.650865Z", + "iopub.status.busy": "2021-02-21T05:24:47.649810Z", + "iopub.status.idle": "2021-02-21T05:24:47.721817Z", + "shell.execute_reply": "2021-02-21T05:24:47.722413Z" + }, + "papermill": { + "duration": 0.129299, + "end_time": "2021-02-21T05:24:47.722563", + "exception": false, + "start_time": "2021-02-21T05:24:47.593264", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import plotly.express as px\n", + "\n", + "fig = px.scatter(data, x='Group', y='Age', color='Group')\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:47.832887Z", + "iopub.status.busy": "2021-02-21T05:24:47.831861Z", + "iopub.status.idle": "2021-02-21T05:24:47.904040Z", + "shell.execute_reply": "2021-02-21T05:24:47.904522Z" + }, + "papermill": { + "duration": 0.130189, + "end_time": "2021-02-21T05:24:47.904674", + "exception": false, + "start_time": "2021-02-21T05:24:47.774485", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import plotly.express as px\n", + "\n", + "fig = px.scatter(data, x='Group', y='ASF', color='Group')\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.053096, + "end_time": "2021-02-21T05:24:48.011350", + "exception": false, + "start_time": "2021-02-21T05:24:47.958254", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Checking For Missing/Null Values" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:48.126795Z", + "iopub.status.busy": "2021-02-21T05:24:48.125994Z", + "iopub.status.idle": "2021-02-21T05:24:48.130278Z", + "shell.execute_reply": "2021-02-21T05:24:48.129746Z" + }, + "papermill": { + "duration": 0.065537, + "end_time": "2021-02-21T05:24:48.130397", + "exception": false, + "start_time": "2021-02-21T05:24:48.064860", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Subject ID 0\n", + "MRI ID 0\n", + "Group 0\n", + "Visit 0\n", + "MR Delay 0\n", + "M/F 0\n", + "Hand 0\n", + "Age 0\n", + "EDUC 0\n", + "SES 19\n", + "MMSE 2\n", + "CDR 0\n", + "eTIV 0\n", + "nWBV 0\n", + "ASF 0\n", + "dtype: int64" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.isnull().sum()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.053518, + "end_time": "2021-02-21T05:24:48.237952", + "exception": false, + "start_time": "2021-02-21T05:24:48.184434", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Taking median values for the missing values of MMSE" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:48.355837Z", + "iopub.status.busy": "2021-02-21T05:24:48.355128Z", + 
"iopub.status.idle": "2021-02-21T05:24:48.357901Z", + "shell.execute_reply": "2021-02-21T05:24:48.358483Z" + }, + "papermill": { + "duration": 0.066666, + "end_time": "2021-02-21T05:24:48.358626", + "exception": false, + "start_time": "2021-02-21T05:24:48.291960", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Subject ID 0\n", + "MRI ID 0\n", + "Group 0\n", + "Visit 0\n", + "MR Delay 0\n", + "M/F 0\n", + "Hand 0\n", + "Age 0\n", + "EDUC 0\n", + "SES 19\n", + "MMSE 0\n", + "CDR 0\n", + "eTIV 0\n", + "nWBV 0\n", + "ASF 0\n", + "dtype: int64" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "median = data['MMSE'].median()\n", + "data['MMSE'].fillna(median, inplace=True)\n", + "data.isnull().sum()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.054757, + "end_time": "2021-02-21T05:24:48.468036", + "exception": false, + "start_time": "2021-02-21T05:24:48.413279", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Taking median values for the missing values of SES" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:48.580903Z", + "iopub.status.busy": "2021-02-21T05:24:48.580254Z", + "iopub.status.idle": "2021-02-21T05:24:48.589172Z", + "shell.execute_reply": "2021-02-21T05:24:48.589795Z" + }, + "papermill": { + "duration": 0.067076, + "end_time": "2021-02-21T05:24:48.589954", + "exception": false, + "start_time": "2021-02-21T05:24:48.522878", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Subject ID 0\n", + "MRI ID 0\n", + "Group 0\n", + "Visit 0\n", + "MR Delay 0\n", + "M/F 0\n", + "Hand 0\n", + "Age 0\n", + "EDUC 0\n", + "SES 0\n", + "MMSE 0\n", + "CDR 0\n", + "eTIV 0\n", + "nWBV 0\n", + "ASF 0\n", + "dtype: int64" + ] + }, + "execution_count": 15, + "metadata": 
{}, + "output_type": "execute_result" + } + ], + "source": [ + "median = data['SES'].median()\n", + "data['SES'].fillna(median, inplace=True)\n", + "data.isnull().sum()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.054489, + "end_time": "2021-02-21T05:24:48.700670", + "exception": false, + "start_time": "2021-02-21T05:24:48.646181", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Train-Test Split" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.055984, + "end_time": "2021-02-21T05:24:48.812635", + "exception": false, + "start_time": "2021-02-21T05:24:48.756651", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Prepare the data for X and y, where:\n", + "\n", + "1. X = The columns/features used for **making the prediction**\n", + "2. y = The **target value** (ground-truth label) to be predicted" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:48.926932Z", + "iopub.status.busy": "2021-02-21T05:24:48.926288Z", + "iopub.status.idle": "2021-02-21T05:24:48.933312Z", + "shell.execute_reply": "2021-02-21T05:24:48.932717Z" + }, + "papermill": { + "duration": 0.065077, + "end_time": "2021-02-21T05:24:48.933436", + "exception": false, + "start_time": "2021-02-21T05:24:48.868359", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "y = data['Group'].values\n", + "X = data[['M/F', 'Age', 'EDUC', 'SES', 'MMSE', 'eTIV', 'nWBV', 'ASF']]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.055011, + "end_time": "2021-02-21T05:24:49.044676", + "exception": false, + "start_time": "2021-02-21T05:24:48.989665", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Train-Test distribution Without Stratified Sampling" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "execution": { + "iopub.execute_input": 
"2021-02-21T05:24:49.159712Z", + "iopub.status.busy": "2021-02-21T05:24:49.158704Z", + "iopub.status.idle": "2021-02-21T05:24:49.174698Z", + "shell.execute_reply": "2021-02-21T05:24:49.175197Z" + }, + "papermill": { + "duration": 0.075266, + "end_time": "2021-02-21T05:24:49.175352", + "exception": false, + "start_time": "2021-02-21T05:24:49.100086", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In Training Split:\n", + "0 158\n", + "1 140\n", + "Name: 0, dtype: int64\n", + "\n", + "In Testing Split:\n", + "1 43\n", + "0 32\n", + "Name: 0, dtype: int64\n" + ] + } + ], + "source": [ + "# by default test_size= 0.25\n", + "X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size= 0.20, random_state=42)\n", + "\n", + "df_ytrain = pd.DataFrame(y_trainval)\n", + "df_ytest = pd.DataFrame(y_test)\n", + "\n", + "print('In Training Split:')\n", + "print(df_ytrain[0].value_counts())\n", + "\n", + "print('\\nIn Testing Split:')\n", + "print(df_ytest[0].value_counts())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.055463, + "end_time": "2021-02-21T05:24:49.286592", + "exception": false, + "start_time": "2021-02-21T05:24:49.231129", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### With Stratified Sampling" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:49.402123Z", + "iopub.status.busy": "2021-02-21T05:24:49.401433Z", + "iopub.status.idle": "2021-02-21T05:24:49.418062Z", + "shell.execute_reply": "2021-02-21T05:24:49.417434Z" + }, + "papermill": { + "duration": 0.075377, + "end_time": "2021-02-21T05:24:49.418177", + "exception": false, + "start_time": "2021-02-21T05:24:49.342800", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In Training Split:\n", + "0 
152\n", + "1 146\n", + "Name: 0, dtype: int64\n", + "\n", + "In Testing Split:\n", + "0 38\n", + "1 37\n", + "Name: 0, dtype: int64\n" + ] + } + ], + "source": [ + "# by default test_size= 0.25\n", + "X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size= 0.20, random_state=42, stratify=y)\n", + "\n", + "\n", + "df_ytrain = pd.DataFrame(y_trainval)\n", + "df_ytest = pd.DataFrame(y_test)\n", + "\n", + "print('In Training Split:')\n", + "print(df_ytrain[0].value_counts())\n", + "\n", + "print('\\nIn Testing Split:')\n", + "print(df_ytest[0].value_counts())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.056065, + "end_time": "2021-02-21T05:24:49.530652", + "exception": false, + "start_time": "2021-02-21T05:24:49.474587", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Scale the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:49.651835Z", + "iopub.status.busy": "2021-02-21T05:24:49.650741Z", + "iopub.status.idle": "2021-02-21T05:24:49.660072Z", + "shell.execute_reply": "2021-02-21T05:24:49.659484Z" + }, + "papermill": { + "duration": 0.073448, + "end_time": "2021-02-21T05:24:49.660201", + "exception": false, + "start_time": "2021-02-21T05:24:49.586753", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# here StandardScaler() means z = (x - u) / s\n", + "scaler = StandardScaler().fit(X_trainval)\n", + "#scaler = MinMaxScaler().fit(X_trainval)\n", + "X_trainval_scaled = scaler.transform(X_trainval)\n", + "X_test_scaled = scaler.transform(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:49.777554Z", + "iopub.status.busy": "2021-02-21T05:24:49.776541Z", + "iopub.status.idle": "2021-02-21T05:24:49.782686Z", + "shell.execute_reply": "2021-02-21T05:24:49.783257Z" + }, + 
"papermill": { + "duration": 0.066296, + "end_time": "2021-02-21T05:24:49.783400", + "exception": false, + "start_time": "2021-02-21T05:24:49.717104", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-0.87966444, 0.38449006, -0.87500081, ..., -0.30880564,\n", + " 0.16961408, 0.21548547],\n", + " [ 1.13679712, 1.04832574, -0.87500081, ..., 1.23805919,\n", + " -0.67676996, -1.21558429],\n", + " [-0.87966444, -0.9431813 , -1.22928103, ..., -1.08511327,\n", + " 0.4605586 , 1.14751551],\n", + " ...,\n", + " [-0.87966444, -0.01381135, 1.2506805 , ..., -0.92985174,\n", + " 0.01091708, 0.94202857],\n", + " [-0.87966444, -1.20871557, -0.87500081, ..., -0.00978344,\n", + " 0.5663566 , -0.10742258],\n", + " [-0.87966444, 0.11895579, 1.2506805 , ..., -1.38413547,\n", + " 0.4605586 , 1.56582821]])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_trainval_scaled" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:49.907580Z", + "iopub.status.busy": "2021-02-21T05:24:49.906515Z", + "iopub.status.idle": "2021-02-21T05:24:49.938315Z", + "shell.execute_reply": "2021-02-21T05:24:49.937670Z" + }, + "papermill": { + "duration": 0.09737, + "end_time": "2021-02-21T05:24:49.938439", + "exception": false, + "start_time": "2021-02-21T05:24:49.841069", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
M/FAgeEDUCSESMMSEeTIVnWBVASF
count298.000000298.000000298.000000298.000000298.000000298.000000298.000000298.000000
mean0.43624277.10402714.4697992.48993327.3557051483.7013420.7305871.198638
std0.4967527.5446542.8273721.1168593.689231174.1926490.0378710.136491
min0.00000060.0000006.0000001.0000004.0000001106.0000000.6440000.883000
25%0.00000071.25000012.0000002.00000027.0000001357.0000000.6992501.107250
50%0.00000077.00000014.0000002.00000029.0000001462.0000000.7310001.200500
75%1.00000082.00000016.0000003.00000030.0000001585.2500000.7570001.293000
max1.00000096.00000023.0000005.00000030.0000001987.0000000.8370001.587000
\n", + "
" + ], + "text/plain": [ + " M/F Age EDUC SES MMSE \\\n", + "count 298.000000 298.000000 298.000000 298.000000 298.000000 \n", + "mean 0.436242 77.104027 14.469799 2.489933 27.355705 \n", + "std 0.496752 7.544654 2.827372 1.116859 3.689231 \n", + "min 0.000000 60.000000 6.000000 1.000000 4.000000 \n", + "25% 0.000000 71.250000 12.000000 2.000000 27.000000 \n", + "50% 0.000000 77.000000 14.000000 2.000000 29.000000 \n", + "75% 1.000000 82.000000 16.000000 3.000000 30.000000 \n", + "max 1.000000 96.000000 23.000000 5.000000 30.000000 \n", + "\n", + " eTIV nWBV ASF \n", + "count 298.000000 298.000000 298.000000 \n", + "mean 1483.701342 0.730587 1.198638 \n", + "std 174.192649 0.037871 0.136491 \n", + "min 1106.000000 0.644000 0.883000 \n", + "25% 1357.000000 0.699250 1.107250 \n", + "50% 1462.000000 0.731000 1.200500 \n", + "75% 1585.250000 0.757000 1.293000 \n", + "max 1987.000000 0.837000 1.587000 " + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_trainval.describe()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.057412, + "end_time": "2021-02-21T05:24:50.054067", + "exception": false, + "start_time": "2021-02-21T05:24:49.996655", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Data Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:50.175069Z", + "iopub.status.busy": "2021-02-21T05:24:50.174376Z", + "iopub.status.idle": "2021-02-21T05:24:51.959428Z", + "shell.execute_reply": "2021-02-21T05:24:51.958796Z" + }, + "papermill": { + "duration": 1.847829, + "end_time": "2021-02-21T05:24:51.959546", + "exception": false, + "start_time": "2021-02-21T05:24:50.111717", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "X_trainval.hist(bins=30, figsize=(20,15))\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:52.090627Z", + "iopub.status.busy": "2021-02-21T05:24:52.089685Z", + "iopub.status.idle": "2021-02-21T05:24:52.168509Z", + "shell.execute_reply": "2021-02-21T05:24:52.167818Z" + }, + "papermill": { + "duration": 0.150146, + "end_time": "2021-02-21T05:24:52.168625", + "exception": false, + "start_time": "2021-02-21T05:24:52.018479", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import plotly.express as px\n", + "\n", + "x = ['M/F', 'Age', 'EDUC', 'SES', 'MMSE', 'eTIV', 'nWBV', 'ASF']\n", + "\n", + "fig = px.histogram(X_trainval, x='eTIV', nbins=50)\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:52.349421Z", + "iopub.status.busy": "2021-02-21T05:24:52.321695Z", + "iopub.status.idle": "2021-02-21T05:24:52.362716Z", + "shell.execute_reply": "2021-02-21T05:24:52.362055Z" + }, + "papermill": { + "duration": 0.133606, + "end_time": "2021-02-21T05:24:52.362833", + "exception": false, + "start_time": "2021-02-21T05:24:52.229227", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import plotly.express as px\n", + "\n", + "x = ['M/F', 'Age', 'EDUC', 'SES', 'MMSE', 'eTIV', 'nWBV', 'ASF']\n", + "\n", + "fig = px.scatter(X_trainval, x='eTIV')\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.060847, + "end_time": "2021-02-21T05:24:52.484599", + "exception": false, + "start_time": "2021-02-21T05:24:52.423752", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# SVM" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:52.614683Z", + "iopub.status.busy": "2021-02-21T05:24:52.613922Z", + "iopub.status.idle": "2021-02-21T05:24:52.791187Z", + "shell.execute_reply": "2021-02-21T05:24:52.790597Z" + }, + "papermill": { + "duration": 0.245707, + "end_time": "2021-02-21T05:24:52.791313", + "exception": false, + "start_time": "2021-02-21T05:24:52.545606", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "clf_svm = SVC(random_state=42)\n", + "clf_svm.fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# the test split has 75 cases (38 of class 0, 37 of class 1)\n", + "plot_confusion_matrix(clf_svm, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:52.925317Z", + "iopub.status.busy": "2021-02-21T05:24:52.924620Z", + "iopub.status.idle": "2021-02-21T05:24:52.935820Z", + "shell.execute_reply": "2021-02-21T05:24:52.935250Z" + }, + "papermill": { + "duration": 0.08155, + "end_time": "2021-02-21T05:24:52.935942", + "exception": false, + "start_time": "2021-02-21T05:24:52.854392", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy 0.8590604026845637\n", + "Test accuracy 0.7466666666666667\n", + "Test recall 0.5945945945945946\n", + "Test AUC 0.744665718349929\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "train_score = clf_svm.score(X_trainval_scaled, y_trainval)\n", + "test_score = clf_svm.score(X_test_scaled, y_test)\n", + "y_predict = clf_svm.predict(X_test_scaled)\n", + "\n", + "test_recall = recall_score(y_test, y_predict)\n", + "fpr, tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(fpr, tpr)\n", + "\n", + "\n", + "print(\"Train accuracy \", train_score)\n", + "print(\"Test accuracy \", test_score)\n", + "print(\"Test recall\", test_recall)\n", + "print(\"Test AUC\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.062406, + "end_time": "2021-02-21T05:24:53.061149", + "exception": false, + "start_time": "2021-02-21T05:24:52.998743", + 
"status": "completed" + }, + "tags": [] + }, + "source": [ + "### Optimize parameters (fine-tuning) with GridSearchCV() for SVM" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:24:53.196144Z", + "iopub.status.busy": "2021-02-21T05:24:53.195430Z", + "iopub.status.idle": "2021-02-21T05:25:04.937936Z", + "shell.execute_reply": "2021-02-21T05:25:04.938471Z" + }, + "papermill": { + "duration": 11.814617, + "end_time": "2021-02-21T05:25:04.938646", + "exception": false, + "start_time": "2021-02-21T05:24:53.124029", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 6, 'gamma': 1, 'kernel': 'rbf'}\n" + ] + } + ], + "source": [ + "# The SVC defaults are C = 1 and gamma = 'scale'\n", + "# C controls how wide the margin is relative to how many training misclassifications we allow\n", + "# A larger C gives a narrower margin and fewer training misclassifications; a smaller C does the reverse\n", + "param_grid = [\n", + " {'C': [0.5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 30, 50, 80, 100],\n", + " 'gamma': ['scale', 0.5, 1, 0.1, 0.01, 0.001, 0.0001, 0.00001],\n", + " 'kernel': ['rbf', 'linear', 'poly', 'sigmoid']},\n", + "]\n", + "\n", + "optimal_params = GridSearchCV(SVC(),\n", + " param_grid,\n", + " cv=5, # 5-fold cross-validation\n", + " scoring='accuracy', # TODO: try other scoring metrics\n", + " verbose=0,\n", + " n_jobs=-1)\n", + "\n", + "optimal_params.fit(X_trainval_scaled, y_trainval)\n", + "print(optimal_params.best_params_)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:05.073634Z", + "iopub.status.busy": "2021-02-21T05:25:05.072673Z", + "iopub.status.idle": "2021-02-21T05:25:05.076294Z", + "shell.execute_reply": "2021-02-21T05:25:05.075675Z" + }, + "papermill": { + "duration": 0.074232, 
+ "end_time": "2021-02-21T05:25:05.076406", + "exception": false, + "start_time": "2021-02-21T05:25:05.002174", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "C = optimal_params.best_params_['C']\n", + "gamma = optimal_params.best_params_['gamma']\n", + "kernel = optimal_params.best_params_['kernel']" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:05.207035Z", + "iopub.status.busy": "2021-02-21T05:25:05.206326Z", + "iopub.status.idle": "2021-02-21T05:25:05.382139Z", + "shell.execute_reply": "2021-02-21T05:25:05.382632Z" + }, + "papermill": { + "duration": 0.242652, + "end_time": "2021-02-21T05:25:05.382801", + "exception": false, + "start_time": "2021-02-21T05:25:05.140149", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXMAAAEKCAYAAADgl7WbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3deZwcVb338c93JoGEJAQiSQyyIwYRTYCARoQHBNmuC/iAy0XlKo+gAkFluXj1uYBe7kXZXFiDIIuAgCyyKGETA4KShZCQhACyCpGQsIYlyUz/7h91BppxZrpr6Jmu7vm+edVrqms59evq8OvTp06dUkRgZmaNraXeAZiZ2TvnZG5m1gSczM3MmoCTuZlZE3AyNzNrAk7mZmZNwMnczKxOJA2RdK+k+yXNl3R8Wn6cpKclzUnTXhXLcj9zM7P6kCRgWEQslzQYuAs4HNgDWB4RJ1db1qA+itHMzCqIrDa9PL0cnKZe1bCdzAtknVGtsdH6g+sdhuXw0Nw16h2C5fQKLyyNiNG93X/3nYfFsufbq9p21twV84E3yhZNjYip5dtIagVmAe8FzoiIv0raEzhU0leAmcAREfFCT8dyM0uBTJowJO6dtn69w7Acdl93Yr1DsJxujd/OiohJvd0/+/90g6q2bR33cNXHkrQWcA1wGPAcsJSslv4jYFxEfK2n/X0B1MwshwBKVf6Xq9yIF4E7gD0i4tmIaI+IEnAusF2l/Z3MzcxyCIJV0V7VVImk0alGjqShwK7Ag5LGlW22D/BApbLcZm5mllPeWncPxgEXpnbzFuCKiLhB0sWSJpL9EHgcOLhSQU7mZmY5BEF7ja41RsRcYKsuln85b1lO5mZmOZV613uwTzmZm5nlEEC7k7mZWeNzzdzMrMEFsKqA9+c4mZuZ5RCEm1nMzBpeQHvxcrmTuZlZHtkdoMXjZG5mlotoR/UO4p84mZuZ5ZBdAHUyNzNraFk/cydzM7OGV3LN3MyssblmbmbWBALRXsDRw53MzcxycjOLmVmDC8TKaK13GP/EydzMLIfspiE3s5iZNTxfADUza3ARoj1cMzcza3gl18zNzBpbdgG0eKmzeBGZmRWYL4CamTWJdvczNzNrbL4D1MysSZTcm8XMrLFlA205mZuZNbRArKrR7fyShgDTgdXJ8vFvI+JYSaOAy4GNgMeBz0XECz2VVbyvFzOzAouA9mipaqrCCuDjETEBmAjsIekjwDHAbRGxGXBbet0jJ3Mzs1xEqcqpksgsTy8HpymAzwAXpuUXAntXKsvNLGZmOQTkuZ1/HUkzy15PjYip5RtIagVmAe8FzoiIv0oaGxGLASJisaQxlQ7kZG5mllOOC6BLI2JSTxtERDswUdJawDWStuxNTE7mZmY5BOqTh1NExIuS7gD2AJ6VNC7VyscBSyrt7zZzM7McAlgVg6qaKpE0OtXIkTQU2BV4ELgOOCBtdgDwu0pluWZuZpaLajme+TjgwtRu3gJcERE3SLoHuELSgcCTwH6VCnIyNzPLIajdHaARMRfYqovly4Bd8pTlZG5mlpOfNGRm1uAi5LFZzMwaXXYBtDa389eSk7mZWS5+BqiZWcPLLoC6zdzMrOF5CFwzswbXV3eAvlNO5mZmOfmBzmZmDS4CVpWczM3MGlrWzOJkbmbW8HwHqDW9lW+IIz77XlatbKG9DXb4l5f4ylH/4OKT380fLh3FyFHtAHz1e8+w3S6v1Dla62zw6iVOufoRBq8WtA4K7rxxLS4++d31DqtQBlzXREkBnBoRR6TXRwLDI+K4GpR9AXBDRPz2nZbVi2PvDTwUEQty7rc8Iob3UViFMXj14CdX/o2hw0q0rYLv7r0Z2378ZQD2+fpz7PfN5+ocofVk1Qpx9H6b8sZrrbQOCk699hFm3D6CB2cPq3doBVLMZpa+jGgF8FlJ6/ThMephb2CLegdRVBIMHVYCoG2VaF8lVLxKjHVLvPFadqv6oMF
B6+Agos4hFVCtngFaS32ZzNuAqcB3Oq+QtKGk2yTNTX83SMsvkPRzSXdLelTSvmm5JJ0uaYGkG4ExZWVtI+lPkmZJmpaeyoGkOySdJmm6pIWStpV0taSHJf1X2f5fknSvpDmSzknjCiNpuaQTJN0v6S+Sxkr6KPBp4KS0/aZpuikd/05Jm6f9N5Z0j6QZkn7UZ2e5gNrb4Zu7jufzH9qSrXZ8hc23fg2A6381mm/sMp5TvrM+r7xYvLEtLNPSEpx5yyIunzuf+6YPZ9F9rpWXy3qztFY19ae+/q1wBrC/pJGdlp8OXBQRHwIuAX5etm4c8DHgk8CJadk+wHjgg8DXgY8CSBoM/ALYNyK2Ac4HTigra2VE7AicTfakjkOALYF/k/QuSe8HPg9sHxETgXZg/7TvMOAvETEBmA58PSLuJnsCyFERMTEi/kb2hXVYOv6RwJlp/58BZ0XEtsA/ujtBkg6SNFPSzOeWtXd7IhtJayucdesiLpm1gEVz1uDxB4fwyQOW8qt7FnDmLYsYNXYVU49ft95hWjdKJfGtT4xn/222YPzE19hw/Ov1DqlQOm4aqmbqT32azCPiZeAiYEqnVZOBS9P8xWTJu8O1EVFKbdJj07Idgcsioj0ingFuT8vHkyXnWyTNAX4ArFdW1nXp7zxgfkQsjogVwKPA+mSDv28DzEj77wJskvZZCdyQ5mcBG3V+f5KGk32xXJn2P4fsywhge+CysvfYpYiYGhGTImLS6Hc1V211+Mh2Jkxezow/jmDt0W20tkJLC+y5//MsmrNGvcOzCl59uZX77xnOtjv7QnVnRWxm6Y/eLD8FZgO/6mGb8la5FWXz6mab8vXzI2JyN+V2lFXqVG6J7L0LuDAivtfFvqsi3mwtbKfrc9UCvJhq9V0ZcK2NLy5rZdCgLJGveF3MvnMEnztkCcueHcS7xrYBcPcfRrLR+DfqHKl1ZeSoNtraxKsvt7LakBJb77CcK84YU3nHAWTA9WbpEBHPS7oCOJCsGQTgbuALZDXW/YG7KhQzHThY0kVk7eU7k9XsFwGjJU2OiHtSs8v7ImJ+leHdBvxO0mkRsUTSKGBERDzRwz6vACPSe3tZ0mOS9ouIKyUJ+FBE3A/8Ob3HX/NW003Te/7ZwZx8+AaUSqJUgh0/9SIf+cTL/OSwDfjb/KFIMHa9lUz5yVP1DtW6MGrsKo782ZO0tGS/oqZfP5K/3rpmvcMqnCL2ZumvfuanAIeWvZ4CnC/pKOA54KsV9r8G+DhZc8lDwJ8AImJlukj689QuP4jsl0BVyTwiFkj6AXCzpBZgFVm7ek/J/DfAuZKmAPuSJeqzUjmD0/r7gcOBSyUdDlxVTTzNYJMt3uDMWx76p+VH/+LJOkRjeT22cCiH7Da+3mEUWoRoK2AyV7jfUWFMmjAk7p22fr3DsBx2X7e7FjYrqlvjt7MiYlJv91978zGx03n7VbXttR878x0dKw/fAWpmlsOAbTM3M2s2TuZmZg3OD6cwM2sS/d2HvBrFuyRrZlZgEdBWaqlqqkTS+pL+mIYcmZ96vyHpOElPp2FD5kjaq1JZrpmbmeVUw2aWNuCIiJgtaQQwS9Itad1pEXFytQU5mZuZ5VDLNvOIWAwsTvOvSFoIvKc3ZbmZxcwspwhVNQHrdAykl6aDuitT0kbAVsBf06JD08iy50tau1JMTuZmZjnlGGhracdAemma2lV5adC+q4BvpwEKzwI2BSaS1dxPqRSTm1nMzHKIqG0/8zSm1FXAJRFxdXaMeLZs/bm8NYJrt5zMzcxyEe1V9FSpqqRscL7zgIURcWrZ8nGpPR2y5zk8UKksJ3Mzs5yidjXz7YEvA/PSMxEA/gP4oqSJZKMHPA4cXKkgJ3MzsxxqOTZLRNwFXd6B9Pu8ZTmZm5nlERTyIddO5mZmORXxdn4nczOzHKKGF0BrycnczCwnN7OYmTWBGvZmqRknczOzHCKczM3MmoIfTmFm1gTcZm5m1uACUXJvFjO
zxlfAirmTuZlZLr4AambWJApYNXcyNzPLqaFq5pJ+QQ/fPxExpU8iMjMrsABKpQZK5sDMfovCzKxRBNBINfOIuLD8taRhEfFq34dkZlZsRexnXrGzpKTJkhYAC9PrCZLO7PPIzMyKKqqc+lE1Pd9/CuwOLAOIiPuBHfsyKDOz4hIR1U39qareLBHxVPbc0Te19004ZmYNoIDNLNUk86ckfRQISasBU0hNLmZmA05AFLA3SzXNLN8ADgHeAzwNTEyvzcwGKFU59Z+KNfOIWArs3w+xmJk1hgI2s1TTm2UTSddLek7SEkm/k7RJfwRnZlZIDdqb5VLgCmAcsC5wJXBZXwZlZlZYHTcNVTP1o2qSuSLi4ohoS9OvKeSPDDOz/pE9Oq7y1J+6TeaSRkkaBfxR0jGSNpK0oaSjgRv7L0Qzs4IpqbqpAknrS/qjpIWS5ks6PC0fJekWSQ+nv2tXKqunC6CzyGrgHREdXLYugB9VjNTMrAmpdrXuNuCIiJgtaQQwS9ItwL8Bt0XEiZKOAY4B/r2ngnoam2XjmoVrZtYsanhxMyIWA4vT/CuSFpJ1A/8MsFPa7ELgDnqbzMtJ2hLYAhhSFsRFOeM2M2sCuS5uriOpfATaqRExtctSpY2ArYC/AmNToiciFksaU+lAFZO5pGPJviG2AH4P7AncBTiZm9nAVH3NfGlETKq0kaThwFXAtyPi5U7Dp1Slmt4s+wK7AP+IiK8CE4DVcx/JzKxZlKqcqiBpMFkivyQirk6Ln5U0Lq0fByypVE41yfz1iCgBbZLWTIX6piEzG5hq2M9cWRX8PGBhRJxatuo64IA0fwDwu0plVdNmPlPSWsC5ZD1clgP3VrGfmVlTqmFvlu2BLwPzJM1Jy/4DOBG4QtKBwJPAfpUKqmZslm+l2bMl3QSsGRFzexW2mVkzqF1vlrvofkSuXfKU1dMDnbfuaV1EzM5zIDMz6zs91cxP6WFdAB+vcSwD3kNz12D3dSfWOwzL4TdP3V3vECynddZ752XUsJmlZnq6aWjn/gzEzKwhBFXdqt/fqrppyMzMyjRSzdzMzLrWUM0sZmbWjQIm82qeNCRJX5L0n+n1BpK26/vQzMwKqkGfNHQmMBn4Ynr9CnBGn0VkZlZgiuqn/lRNM8uHI2JrSfcBRMQLklbr47jMzIqrQXuzrJLUSvrRIGk0VQ8hY2bWfIp4AbSaZpafA9cAYySdQDb87X/3aVRmZkVWwDbzasZmuUTSLLJxAgTsHREL+zwyM7MiqkN7eDWqeTjFBsBrwPXlyyLiyb4MzMyssBoxmQM38taDnYcAGwOLgA/0YVxmZoWlAl41rKaZ5YPlr9Noigf3WURmZpZb7jtAI2K2pG37Ihgzs4bQiM0skr5b9rIF2Bp4rs8iMjMrska9AAqMKJtvI2tDv6pvwjEzawCNlszTzULDI+KoforHzKz4GimZSxoUEW09PT7OzGygEY3Xm+VesvbxOZKuA64EXu1YGRFX93FsZmbF08Bt5qOAZWTP/Ozobx6Ak7mZDUwNlszHpJ4sD/BWEu9QwLdiZtZPCpgBe0rmrcBw3p7EOxTwrZiZ9Y9Ga2ZZHBE/7LdIzMwaRQGTeU9D4BZv9HUzs3qLrDdLNVMlks6XtETSA2XLjpP0tKQ5adqrmrB6Sua7VFOAmdmAU7vxzC8A9uhi+WkRMTFNv6+moG6bWSLi+apCMTMbYGrVZh4R0yVtVIuyqnnSkJmZlau+Zr6OpJll00FVHuFQSXNTM8za1ezgZG5mlke1iTxL5ksjYlLZNLWKI5wFbApMBBYDp1QTVu4hcM3MBjLRt10TI+LZN48lnQvcUM1+rpmbmeWkqG7qVdnSuLKX+5DduFmRa+ZmZnnVqGYu6TJgJ7K29b8DxwI7SZqYjvI4VT7ZzcnczCyv2vVm+WIXi8/rTVlO5mZmeTTwqIlmZlbOydzMrPE12sMpzMysC25mMTNrdNWPu9KvnMzNzPJyMjcza2x9fQdobzm
Zm5nlpFLxsrmTuZlZHm4zNzNrDm5mMTNrBk7mZmaNzzVzM7Nm4GRuZtbgwrfzm5k1PPczNzNrFlG8bO5kbmaWk2vmNqAMXr3EKVc/wuDVgtZBwZ03rsXFJ7+73mFZJyvfEMfvuyWrVrZQahcf3msZ+x3x1Jvrrz97XS45YSOm3n8va45qq2OkBeGbhnpPUjswDxgMtAEXAj+NiH6/DCFpLeBfI+LMnPsdByyPiJP7JLACWrVCHL3fprzxWiutg4JTr32EGbeP4MHZw+odmpUZvHrw/y+fz5BhJdpWiWM/uyUTd36BzbZeztJnVmPenSNZ5z0r6h1moRTxAmhLvQOo0usRMTEiPgB8AtiL7MGn9bAW8K06HbvBiDdeawVg0OCgdXAUsalxwJNgyLAsO7W3ifY2ZVf5gIuO35j9v/9EMdsV6kil6qb+1CjJ/E0RsQQ4CDhUmVZJJ0maIWmupIMBJO0k6U+SrpD0kKQTJe0v6V5J8yRtmrYbLemqtP8MSdun5cdJOl/SHZIelTQlhXAisKmkOZJOStseVXb84ztilfR9SYsk3QqM78fTVBgtLcGZtyzi8rnzuW/6cBbd51p5EZXa4d93n8BBE7flgzu8xGZbLWfmzWsz6t0r2HCL1+odXrEE2QXQaqZ+1BDNLJ1FxKOSWoAxwGeAlyJiW0mrA3+WdHPadALwfuB54FHglxGxnaTDgcOAbwM/A06LiLskbQBMS/sAbA7sDIwAFkk6CzgG2DIiJgJI2g3YDNiOrD5znaQdgVeBLwBbkZ3n2cCszu9F0kFkX04MYY1anaLCKJXEtz4xnmFrtnPseY+x4fjXeWLR0HqHZZ20tMKPp93Pqy+1csrXN+eJhWtwzS/W4/uXLKh3aIVUxB8qDZnMk/RDkN2AD0naN70eSZZcVwIzImIxgKS/AR1Jfh5ZkgbYFdhC6iiONSWNSPM3RsQKYIWkJcDYLuLYLU33pdfD0/FHANdExGvp+Nd19SYiYiowNTvwqAL+E6mNV19u5f57hrPtzq84mRfYsJHtbDH5JWZOG8VzTw3h6N0nAPD84tX53p4TOOH6uaw1ZlWdoyyAAv6f2pDJXNImQDuwhCypHxYR0zptsxNQftWmVPa6xFvvvQWYHBGvd9qfTvu30/X5EvA/EXFOp/2/TSE/8v4zclQbbW3i1ZdbWW1Iia13WM4VZ4ypd1jWycvLBtE6KBg2sp2Vr7cw7861+PS3nmbqnBlvbnPo5K357xvnujcLvmmoZiSNBs4GTo+IkDQN+Kak2yNilaT3AU/nKPJm4FCgo/17YkTM6WH7V8hq3R2mAT+SdElELJf0HmAVMB24QNKJZOf5U8A5/1RaExs1dhVH/uxJWlqgpQWmXz+Sv966Zr3Dsk5eWLIaZ33nvZTaRakkJn9qKdvs+kK9wyquCD+c4h0YKmkOb3VNvBg4Na37JbARMFtZdfo5YO8cZU8BzpA0l+x8TAe+0d3GEbFM0p8lPQD8ISKOkvR+4J5Um18OfCkiZku6HJgDPAHcmSOmpvDYwqEcstuAvO7bUDZ8/2uceNPcHrc5/Z7Z/RRNg6hRLpd0PvBJYElEbJmWjQIuJ8trjwOfi4iK364K9xUrjDU1Kj6sXeodhuXwm6furncIltM66z0zKyIm9Xb/EWutF1vvcHhV206/4egej5U6SywHLipL5j8Bno+IEyUdA6wdEf9e6VgN1zXRzKyuAihFdVOloiKmk/W2K/cZshsjSX+ramlolGYWM7PiqL5BYx1JM8teT0092HoytqMXXkQsllRVrwEnczOznHL0Zln6Tpp08nAyNzPLqY97szwraVyqlY8j64JdkdvMzczyiBxT71wHHJDmDwB+V81OrpmbmeWQ3TRUm5q5pMuAncja1v9ONoDgicAVkg4EngT2q6YsJ3Mzs7xqNCJiRHyxm1W5+yg7mZuZ5VSrmnktOZmbmeXhJw2ZmTUDj81iZtYc3MxiZtbgopjPAHUyNzPLyzVzM7M
mULxc7mRuZpaXSsVrZ3EyNzPLI6jZTUO15GRuZpaDCN80ZGbWFJzMzcyagJO5mVmDc5u5mVlzcG8WM7OGF25mMTNreIGTuZlZUyheK4uTuZlZXu5nbmbWDJzMzcwaXAS0F6+dxcnczCwv18zNzJqAk7mZWYMLwM8ANTNrdAHhNnMzs8YW+AKomVlTqGGbuaTHgVeAdqAtIib1phwnczOzvGp/AXTniFj6TgpwMjczy6WYA2211DsAM7OGEkCpVN0E60iaWTYd1E2JN0ua1c36qrhmbmaWV/U186VVtIFvHxHPSBoD3CLpwYiYnjck18zNzHJJt/NXM1VTWsQz6e8S4Bpgu95E5WRuZpZHQESpqqkSScMkjeiYB3YDHuhNWG5mMTPLq3Z3gI4FrpEEWT6+NCJu6k1BTuZmZnnVqDdLRDwKTKhFWU7mZmZ5RHT0VCkUJ3Mzs7wK2M/cydzMLJcg2tvrHcQ/cTI3M8vDQ+CamTUJD4FrZtbYAgjXzM3MGlz44RRmZk2hiBdAFQXsYjNQSXoOeKLecfSBdYB3NFaz9btm/sw2jIjRvd1Z0k1k56caSyNij94eKw8nc+tzkmb29ukpVh/+zBqPB9oyM2sCTuZmZk3Aydz6w9R6B2C5+TNrMG4zNzNrAq6Zm5k1ASdzM7Mm4GTexCSFpFPKXh8p6bgalX2BpH1rUVYvjr23pC16sd/yvoinv0lqlzRH0nxJ90v6rqS6/L8saS1J3+rFfsdJOrIvYhqonMyb2wrgs5KqvcGhUewN5E7mTeT1iJgYER8APgHsBRxbp1jWAnInc6s9J/Pm1kbWK+E7nVdI2lDSbZLmpr8bpOUXSPq5pLslPdpR+1bmdEkLJN0IjCkraxtJf5I0S9I0SePS8jsknSZpuqSFkraVdLWkhyX9V9n+X5J0b6ptniOpNS1fLumEVPv8i6Sxkj4KfBo4KW2/aZpuSse/U9Lmaf+NJd0jaYakH/XZWa6j9ET3g4BD02fUKumk9J7nSjoYQNJO6TO6QtJDkk6UtH867/MkbZq2Gy3pqrT/DEnbp+XHSTo/faaPSpqSQjgR2DR9FielbY8qO/7xHbFK+r6kRZJuBcb342kaGCLCU5NOwHJgTeBxYCRwJHBcWnc9cECa/xpwbZq/ALiS7It+C+CRtPyzwC1AK7Au8CKwLzAYuBsYnbb7PHB+mr8D+HGaPxx4BhgHrA78HXgX8P4Uy+C03ZnAV9J8AJ9K8z8BflAW475l7/M2YLM0/2Hg9jR/XVlZhwDL6/2Z1Opz7WLZC2QPBz6o7DytDswENgZ2Sp9Zx/l/Gji+7LP5aZq/FPhYmt8AWJjmj0uf8+pkt7IvS5/9RsADZXHsRlaBUPo3dAOwI7ANMA9YI/2bfAQ4st7nspkmD7TV5CLiZUkXAVOA18tWTSZL0AAXkyXLDtdGRAlYIGlsWrYjcFlEtAPPSLo9LR8PbAnckp4w3gosLivruvR3HjA/IhYDSHoUWB/4GNn/6DPS/kOBJWmflWTJAGAWWZPC20gaDnwUuDLtD1nCAdge+L9l7/HHnfdvIh1vfjfgQ2XXM0YCm5Gdyxll5/9vwM1pm3nAzml+V2CLsnO5pqQRaf7GiFgBrJC0hOzLo7Pd0nRfej08HX8EcE1EvJaOf10X+9o74GQ+MPwUmA38qodtym84WFE2r262KV8/PyImd1NuR1mlTuWWyP79CbgwIr7Xxb6rIlX3gHa6/vfaArwYERO7OX7T30ghaROy87OE7HweFhHTOm2zE/98/ss/m45z2wJMjojyL35Sci/fv7vPQ8D/RMQ5nfb/NgPgs6gnt5kPABHxPHAFcGDZ4ruBL6T5/YG7KhQzHfhCapMdx1s1uUXAaEmTASQNlvSBHOHdBuwraUzaf5SkDSvs8wpZTY+IeBl4TNJ+aX9JmpC2+zNvf49NR9Jo4Gzg9PTFNw34pqTBaf37JA3LUeTNwKFl5Xf3Jdnhzc8imQZ8Lf1iQtJ70mc7Hdh
H0tBU0/9UjpisCk7mA8cpvH3YzinAVyXNBb5M1m7ak2uAh8l+kp8F/AkgIlaStZ3/WNL9wByyZo+qRMQC4AfAzSmWW8jadXvyG+AoSfelC3f7Awem488HPpO2Oxw4RNIMsuaGZjE0XXCcD9xKloA7LjT+ElgAzJb0AHAO+X6BTwEmpYuXC4Bv9LRxRCwD/izpAUknRcTNZO3u90iaB/wWGBERs4HLyf59XAXcmSMmq4Jv5zczawKumZuZNQEnczOzJuBkbmbWBJzMzcyagJO5mVkTcDK3hqK3Rgx8QNKVktZ4B2W9OfKjpF+qh5EY09gmVXe5LNvvcXUx0Fl3yzttk2uUR3kkwgHNydwaTceIgVuS3aL+tn7QSoN05RUR/y/1ee/OTuToP2/W35zMrZHdCbw31Zr/KOlSYF4PIwdK3Y/8eIekSWl+D0mzlY3WeJukjci+NL6TfhXs0MPogu+SdHO6oekc3j4cQpckXatsxMf5kg7qtO6UFMtt6W5P1M0okTaweWwWa0iSBgF7AjelRdsBW0bEYykhvhQR20panewOxZuBrcgGBvsg2SBRC4DzO5U7GjgX2DGVNSoinpd0NtlohSen7S4FTouIu5QNHzyNbATIY4G7IuKHkv6FbBTDSr6WjjGUbMCxq9KdlcOA2RFxhKT/TGUfSjYq4Tci4mFJHyYbafLjvTiN1kSczK3RDJU0J83fCZxH1vxxb0Q8lpZ3N3JgdyM/lvsIML2jrDSuTVe6G11wR9JolBFxo6QXqnhPUyTtk+bXT7EuIxsA6/K0/NfA1ep5lEgbwJzMrdG83nmExJTUXi1fRNcjB+5F5ZH7VMU20PPoglWPkaFsNMNdU1mvSboDGNLN5kHlUSJtgHKbuTWj7kYO7G7kx3L3AP9H0sZp31FpeefRAbsbXXA6aYRGSXsCa1eIdSTwQkrkm5P9MujQQjaIGcC/kjXf9DRKpA1gTubWjLobObDLkR/LRcRzZO3cV6dRGDuaOa4nG8J1jqQd6H50weOBHSXNJmvuebJCrDcBg5SNGPkj4C9l614FPiBpFlmb+A/T8u5GibQBzKMmmpk1AdfMzcyagJO5mVkTcDI3M2sCTuZmZk3AydzMrAT0TJMAAAASSURBVAk4mZuZNQEnczOzJvC/An0+zo7Ryk4AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "clf_svm = SVC(random_state=42, C=C, gamma=gamma, kernel=kernel)\n", + "clf_svm.fit(X_trainval_scaled, y_trainval)\n", + "\n", + "plot_confusion_matrix(clf_svm, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:05.517061Z", + "iopub.status.busy": "2021-02-21T05:25:05.516041Z", + "iopub.status.idle": "2021-02-21T05:25:05.535028Z", + "shell.execute_reply": "2021-02-21T05:25:05.534191Z" + }, + "papermill": { + "duration": 0.08746, + "end_time": "2021-02-21T05:25:05.535192", + "exception": false, + "start_time": "2021-02-21T05:25:05.447732", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy 1.0\n", + "Test accuracy 0.92\n", + "Test recall 0.918918918918919\n", + "Test AUC 0.9199857752489331\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "train_score = clf_svm.score(X_trainval_scaled, y_trainval)\n", + "test_score = clf_svm.score(X_test_scaled, y_test)\n", + "y_predict = clf_svm.predict(X_test_scaled)\n", + "\n", + "test_recall = recall_score(y_test, y_predict)\n", + "svm_fpr, svm_tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(svm_fpr, svm_tpr)\n", + "\n", + "\n", + "print(\"Train accuracy \", train_score)\n", + "print(\"Test accuracy \", test_score)\n", + "print(\"Test recall\", test_recall)\n", + "print(\"Test AUC\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.064261, + "end_time": "2021-02-21T05:25:05.665444", + "exception": false, + "start_time": "2021-02-21T05:25:05.601183", + "status": 
"completed" + }, + "tags": [] + }, + "source": [ + "# Random Forest" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:05.798235Z", + "iopub.status.busy": "2021-02-21T05:25:05.797535Z", + "iopub.status.idle": "2021-02-21T05:25:05.857124Z", + "shell.execute_reply": "2021-02-21T05:25:05.856511Z" + }, + "papermill": { + "duration": 0.126931, + "end_time": "2021-02-21T05:25:05.857258", + "exception": false, + "start_time": "2021-02-21T05:25:05.730327", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from sklearn.ensemble import RandomForestClassifier" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:06.011800Z", + "iopub.status.busy": "2021-02-21T05:25:06.001387Z", + "iopub.status.idle": "2021-02-21T05:25:06.381618Z", + "shell.execute_reply": "2021-02-21T05:25:06.381041Z" + }, + "papermill": { + "duration": 0.459083, + "end_time": "2021-02-21T05:25:06.381735", + "exception": false, + "start_time": "2021-02-21T05:25:05.922652", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# n_estimators(M) --> the number of trees in the forest\n", + "# max_features(d) --> the number of features to consider when looking for the best split\n", + "# max_depth(m) --> the maximum depth of the tree.\n", + "\n", + "rfc = RandomForestClassifier(random_state=42)\n", + "rfc.fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# the test split has 75 cases (38 of class 0, 37 of class 1)\n", + "plot_confusion_matrix(rfc, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:06.524845Z", + "iopub.status.busy": "2021-02-21T05:25:06.524139Z", + "iopub.status.idle": "2021-02-21T05:25:06.567152Z", + "shell.execute_reply": "2021-02-21T05:25:06.566168Z" + }, + "papermill": { + "duration": 0.119008, + "end_time": "2021-02-21T05:25:06.567316", + "exception": false, + "start_time": "2021-02-21T05:25:06.448308", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy 1.0\n", + "Test accuracy 0.8133333333333334\n", + "Test recall 0.7027027027027027\n", + "Test AUC 0.811877667140825\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "train_score = rfc.score(X_trainval_scaled, y_trainval)\n", + "test_score = rfc.score(X_test_scaled, y_test)\n", + "y_predict = rfc.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "fpr, tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(fpr, tpr)\n", + "\n", + "print(\"Train accuracy \", train_score)\n", + "print(\"Test accuracy \", test_score)\n", + "print(\"Test recall\", test_recall)\n", + "print(\"Test AUC\", test_auc)" + ] + }, + { + 
"cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.066938, + "end_time": "2021-02-21T05:25:06.702778", + "exception": false, + "start_time": "2021-02-21T05:25:06.635840", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Optimize parameters(Finetuning) --> GridSearchCV()" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:06.846366Z", + "iopub.status.busy": "2021-02-21T05:25:06.845356Z", + "iopub.status.idle": "2021-02-21T05:25:06.848964Z", + "shell.execute_reply": "2021-02-21T05:25:06.848404Z" + }, + "papermill": { + "duration": 0.078921, + "end_time": "2021-02-21T05:25:06.849108", + "exception": false, + "start_time": "2021-02-21T05:25:06.770187", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Number of trees in random forest\n", + "n_estimators = [int(x) for x in np.linspace(start = 10, stop = 100, num = 10)]\n", + "\n", + "# Number of features to consider at every split\n", + "max_features = ['auto', 'sqrt', 'log2']\n", + "\n", + "# Maximum number of levels in tree\n", + "max_depth = range(1,10)\n", + "\n", + "# measure the quality of a split\n", + "criterion = ['gini']\n", + "\n", + "# Method of selecting samples for training each tree\n", + "bootstrap = [True, False]" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:06.987850Z", + "iopub.status.busy": "2021-02-21T05:25:06.987210Z", + "iopub.status.idle": "2021-02-21T05:25:06.992272Z", + "shell.execute_reply": "2021-02-21T05:25:06.991634Z" + }, + "papermill": { + "duration": 0.07594, + "end_time": "2021-02-21T05:25:06.992384", + "exception": false, + "start_time": "2021-02-21T05:25:06.916444", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Create the param grid\n", + "param_grid = {'n_estimators': n_estimators,\n", + " 
'max_features': max_features,\n", + " 'max_depth': max_depth,\n", + " 'criterion': criterion,\n", + " 'bootstrap': bootstrap}" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:25:07.136382Z", + "iopub.status.busy": "2021-02-21T05:25:07.135353Z", + "iopub.status.idle": "2021-02-21T05:26:57.959858Z", + "shell.execute_reply": "2021-02-21T05:26:57.960433Z" + }, + "papermill": { + "duration": 110.901375, + "end_time": "2021-02-21T05:26:57.960599", + "exception": false, + "start_time": "2021-02-21T05:25:07.059224", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'bootstrap': False, 'criterion': 'gini', 'max_depth': 8, 'max_features': 'auto', 'n_estimators': 60}\n" + ] + } + ], + "source": [ + "optimal_params = GridSearchCV(RandomForestClassifier(),\n", + " param_grid,\n", + " cv=5, # we are taking 5-fold as in k-fold cross validation\n", + " scoring='accuracy', # try the other scoring if have time\n", + " verbose=0,\n", + " n_jobs=-1)\n", + "\n", + "optimal_params.fit(X_trainval_scaled, y_trainval)\n", + "print(optimal_params.best_params_)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:58.102512Z", + "iopub.status.busy": "2021-02-21T05:26:58.101661Z", + "iopub.status.idle": "2021-02-21T05:26:58.104540Z", + "shell.execute_reply": "2021-02-21T05:26:58.104022Z" + }, + "papermill": { + "duration": 0.076703, + "end_time": "2021-02-21T05:26:58.104654", + "exception": false, + "start_time": "2021-02-21T05:26:58.027951", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "bootstrap = optimal_params.best_params_['bootstrap']\n", + "criterion = optimal_params.best_params_['criterion']\n", + "max_depth = optimal_params.best_params_['max_depth']\n", + "max_features = 
optimal_params.best_params_['max_features']\n", + "n_estimators = optimal_params.best_params_['n_estimators']" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:58.264473Z", + "iopub.status.busy": "2021-02-21T05:26:58.263365Z", + "iopub.status.idle": "2021-02-21T05:26:58.573331Z", + "shell.execute_reply": "2021-02-21T05:26:58.572415Z" + }, + "papermill": { + "duration": 0.401085, + "end_time": "2021-02-21T05:26:58.573516", + "exception": false, + "start_time": "2021-02-21T05:26:58.172431", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "rfc = RandomForestClassifier(n_estimators=n_estimators, \n", + " max_features=max_features, \n", + " max_depth=max_depth, \n", + " criterion=criterion,\n", + " bootstrap=bootstrap,\n", + " random_state=42)\n", + "\n", + "rfc.fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# confusion matrix on the held-out test set (75 cases per the recorded metrics)\n", + "plot_confusion_matrix(rfc, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:58.733907Z", + "iopub.status.busy": "2021-02-21T05:26:58.732790Z", + "iopub.status.idle": "2021-02-21T05:26:58.764470Z", + "shell.execute_reply": "2021-02-21T05:26:58.765236Z" + }, + "papermill": { + "duration": 0.111182, + "end_time": "2021-02-21T05:26:58.765445", + "exception": false, + "start_time": "2021-02-21T05:26:58.654263", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy 0.9966442953020134\n", + "Test accuracy 0.7866666666666666\n", + "Test recall 0.7027027027027027\n", + "Test AUC 0.7855618776671409\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "train_score = rfc.score(X_trainval_scaled, y_trainval)\n", + "test_score = rfc.score(X_test_scaled, y_test)\n", + "y_predict = rfc.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "rfc_fpr, rfc_tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(rfc_fpr, rfc_tpr)\n", + "\n", + "print(\"Train accuracy \", train_score)\n", + "print(\"Test accuracy \", test_score)\n", + "print(\"Test recall\", test_recall)\n", + "print(\"Test AUC\", test_auc)" + ] + }, + { + "cell_type": "markdown", + 
"metadata": { + "papermill": { + "duration": 0.071036, + "end_time": "2021-02-21T05:26:58.907460", + "exception": false, + "start_time": "2021-02-21T05:26:58.836424", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Logistic Regression" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:59.053932Z", + "iopub.status.busy": "2021-02-21T05:26:59.052923Z", + "iopub.status.idle": "2021-02-21T05:26:59.056562Z", + "shell.execute_reply": "2021-02-21T05:26:59.055881Z" + }, + "papermill": { + "duration": 0.078846, + "end_time": "2021-02-21T05:26:59.056676", + "exception": false, + "start_time": "2021-02-21T05:26:58.977830", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.tree import DecisionTreeClassifier\n", + "from sklearn.metrics import confusion_matrix, precision_score, accuracy_score, recall_score, roc_curve, auc" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:59.210264Z", + "iopub.status.busy": "2021-02-21T05:26:59.209530Z", + "iopub.status.idle": "2021-02-21T05:26:59.392907Z", + "shell.execute_reply": "2021-02-21T05:26:59.392347Z" + }, + "papermill": { + "duration": 0.263735, + "end_time": "2021-02-21T05:26:59.393060", + "exception": false, + "start_time": "2021-02-21T05:26:59.129325", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "log_reg_model = LogisticRegression().fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# confusion matrix on the held-out test set (75 cases per the recorded metrics)\n", + "plot_confusion_matrix(log_reg_model, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:59.553251Z", + "iopub.status.busy": "2021-02-21T05:26:59.545050Z", + "iopub.status.idle": "2021-02-21T05:26:59.562200Z", + "shell.execute_reply": "2021-02-21T05:26:59.562678Z" + }, + "papermill": { + "duration": 0.098292, + "end_time": "2021-02-21T05:26:59.562830", + "exception": false, + "start_time": "2021-02-21T05:26:59.464538", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy 0.8221476510067114\n", + "Test accuracy 0.7466666666666667\n", + "Test recall 0.7027027027027027\n", + "Test AUC 0.7460881934566145\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "log_reg_model = LogisticRegression().fit(X_trainval_scaled, y_trainval)\n", + "train_score = log_reg_model.score(X_trainval_scaled, y_trainval)\n", + "test_score = log_reg_model.score(X_test_scaled, y_test)\n", + "scores = log_reg_model.score(X_test_scaled, y_test)\n", + "y_predict = log_reg_model.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "fpr, tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(fpr, tpr)\n", + "\n", + "\n", + "print(\"Train accuracy \", train_score)\n", + "print(\"Test accuracy \", test_score)\n", + "print(\"Test recall\", test_recall)\n", + "print(\"Test AUC\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + 
"papermill": { + "duration": 0.072092, + "end_time": "2021-02-21T05:26:59.707748", + "exception": false, + "start_time": "2021-02-21T05:26:59.635656", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Optimize parameters (fine-tuning) with GridSearchCV()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:26:59.859870Z", + "iopub.status.busy": "2021-02-21T05:26:59.859194Z", + "iopub.status.idle": "2021-02-21T05:27:00.087736Z", + "shell.execute_reply": "2021-02-21T05:27:00.088576Z" + }, + "papermill": { + "duration": 0.308978, + "end_time": "2021-02-21T05:27:00.088785", + "exception": false, + "start_time": "2021-02-21T05:26:59.779807", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'C': 2, 'penalty': 'l2'}\n" + ] + } + ], + "source": [ + "param_grid = {'penalty': ['l1','l2'], \n", + " 'C': [0.001,0.01,0.1,1, 2, 3, 5, 10,100,1000]}\n", + "\n", + "optimal_params = GridSearchCV(LogisticRegression(),\n", + " param_grid,\n", + " cv=5, # 5-fold cross-validation\n", + " scoring='accuracy', # TODO: try other scoring metrics if time permits\n", + " verbose=0,\n", + " n_jobs=-1)\n", + "\n", + "optimal_params.fit(X_trainval_scaled, y_trainval)\n", + "print(optimal_params.best_params_)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:00.240392Z", + "iopub.status.busy": "2021-02-21T05:27:00.239384Z", + "iopub.status.idle": "2021-02-21T05:27:00.243000Z", + "shell.execute_reply": "2021-02-21T05:27:00.242454Z" + }, + "papermill": { + "duration": 0.080874, + "end_time": "2021-02-21T05:27:00.243134", + "exception": false, + "start_time": "2021-02-21T05:27:00.162260", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# best_score = -10\n", + "# for c in range(1, 20): \n", + 
"# log_reg_model = LogisticRegression(C=c)\n", + "# scores = cross_val_score(log_reg_model, X_trainval_scaled, y_trainval, cv=5, scoring='accuracy')\n", + " \n", + "# mean_score = scores.mean()\n", + " \n", + "# if mean_score > best_score:\n", + "# best_score = mean_score\n", + "# best_c = c\n", + "# print(best_c)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:00.393338Z", + "iopub.status.busy": "2021-02-21T05:27:00.392646Z", + "iopub.status.idle": "2021-02-21T05:27:00.394698Z", + "shell.execute_reply": "2021-02-21T05:27:00.395162Z" + }, + "papermill": { + "duration": 0.079368, + "end_time": "2021-02-21T05:27:00.395308", + "exception": false, + "start_time": "2021-02-21T05:27:00.315940", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "best_C = optimal_params.best_params_['C']\n", + "best_penalty = optimal_params.best_params_['penalty']" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:00.543733Z", + "iopub.status.busy": "2021-02-21T05:27:00.543072Z", + "iopub.status.idle": "2021-02-21T05:27:00.726618Z", + "shell.execute_reply": "2021-02-21T05:27:00.725951Z" + }, + "papermill": { + "duration": 0.258901, + "end_time": "2021-02-21T05:27:00.726760", + "exception": false, + "start_time": "2021-02-21T05:27:00.467859", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "log_reg_model = LogisticRegression(C=best_C, penalty=best_penalty).fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# confusion matrix on the held-out test set (75 cases per the recorded metrics)\n", + "plot_confusion_matrix(log_reg_model, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:00.902586Z", + "iopub.status.busy": "2021-02-21T05:27:00.887719Z", + "iopub.status.idle": "2021-02-21T05:27:00.911752Z", + "shell.execute_reply": "2021-02-21T05:27:00.912556Z" + }, + "papermill": { + "duration": 0.108288, + "end_time": "2021-02-21T05:27:00.912736", + "exception": false, + "start_time": "2021-02-21T05:27:00.804448", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy with Logistec regression: 0.8288590604026845\n", + "Test accuracy with Logistec regression: 0.7466666666666667\n", + "Test recall with Logistec regression: 0.7027027027027027\n", + "Test AUC with Logistec regression: 0.7460881934566145\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "best_log_reg_model = LogisticRegression(C=best_C, penalty=best_penalty).fit(X_trainval_scaled, y_trainval)\n", + "train_score = best_log_reg_model.score(X_trainval_scaled, y_trainval)\n", + "test_score = best_log_reg_model.score(X_test_scaled, y_test)\n", + "y_predict = best_log_reg_model.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "lgr_fpr, lgr_tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(lgr_fpr, lgr_tpr)\n", + "\n", + "print(\"Train accuracy with Logistec regression:\", train_score)\n", + "print(\"Test accuracy with 
Logistec regression:\", test_score)\n", + "print(\"Test recall with Logistec regression:\", test_recall)\n", + "print(\"Test AUC with Logistec regression:\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.074221, + "end_time": "2021-02-21T05:27:01.064257", + "exception": false, + "start_time": "2021-02-21T05:27:00.990036", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "# Decision Tree" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:01.218547Z", + "iopub.status.busy": "2021-02-21T05:27:01.217821Z", + "iopub.status.idle": "2021-02-21T05:27:01.391666Z", + "shell.execute_reply": "2021-02-21T05:27:01.391137Z" + }, + "papermill": { + "duration": 0.253607, + "end_time": "2021-02-21T05:27:01.391785", + "exception": false, + "start_time": "2021-02-21T05:27:01.138178", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "dt_model = DecisionTreeClassifier().fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# confusion matrix on the held-out test set (75 cases per the recorded metrics)\n", + "plot_confusion_matrix(dt_model, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:01.552124Z", + "iopub.status.busy": "2021-02-21T05:27:01.551400Z", + "iopub.status.idle": "2021-02-21T05:27:01.561205Z", + "shell.execute_reply": "2021-02-21T05:27:01.560544Z" + }, + "papermill": { + "duration": 0.094559, + "end_time": "2021-02-21T05:27:01.561320", + "exception": false, + "start_time": "2021-02-21T05:27:01.466761", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy with DecisionTreeClassifier: 1.0\n", + "Test accuracy with DecisionTreeClassifier: 0.72\n", + "Test recall with DecisionTreeClassifier: 0.6756756756756757\n", + "Test AUC with DecisionTreeClassifier: 0.719416785206259\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "dt_model = DecisionTreeClassifier().fit(X_trainval_scaled, y_trainval)\n", + "train_score = dt_model.score(X_trainval_scaled, y_trainval)\n", + "test_score = dt_model.score(X_test_scaled, y_test)\n", + "y_predict = dt_model.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "fpr, tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(fpr, tpr)\n", + "\n", + "print(\"Train accuracy with DecisionTreeClassifier:\", train_score)\n", + "print(\"Test accuracy with DecisionTreeClassifier:\", test_score)\n", + "print(\"Test recall with DecisionTreeClassifier:\", test_recall)\n", + "print(\"Test 
AUC with DecisionTreeClassifier:\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.075289, + "end_time": "2021-02-21T05:27:01.712002", + "exception": false, + "start_time": "2021-02-21T05:27:01.636713", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "### Optimize parameters (fine-tuning) with GridSearchCV()" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:01.872535Z", + "iopub.status.busy": "2021-02-21T05:27:01.871460Z", + "iopub.status.idle": "2021-02-21T05:27:01.957853Z", + "shell.execute_reply": "2021-02-21T05:27:01.957187Z" + }, + "papermill": { + "duration": 0.169737, + "end_time": "2021-02-21T05:27:01.957989", + "exception": false, + "start_time": "2021-02-21T05:27:01.788252", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'criterion': 'gini', 'max_depth': 2}\n" + ] + } + ], + "source": [ + "param_grid = {'criterion': ['gini'], \n", + " 'max_depth': range(1,10)}\n", + "\n", + "optimal_params = GridSearchCV(DecisionTreeClassifier(),\n", + " param_grid,\n", + " cv=5, # 5-fold cross-validation\n", + " scoring='accuracy', # TODO: try other scoring metrics if time permits\n", + " verbose=0,\n", + " n_jobs=-1)\n", + "\n", + "optimal_params.fit(X_trainval_scaled, y_trainval)\n", + "print(optimal_params.best_params_)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:02.116441Z", + "iopub.status.busy": "2021-02-21T05:27:02.115645Z", + "iopub.status.idle": "2021-02-21T05:27:02.118675Z", + "shell.execute_reply": "2021-02-21T05:27:02.118177Z" + }, + "papermill": { + "duration": 0.085294, + "end_time": "2021-02-21T05:27:02.118786", + "exception": false, + "start_time": "2021-02-21T05:27:02.033492", + "status": "completed" + }, + "tags": [] 
+ }, + "outputs": [], + "source": [ + "criterion = optimal_params.best_params_['criterion']\n", + "max_depth = optimal_params.best_params_['max_depth']" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:02.276080Z", + "iopub.status.busy": "2021-02-21T05:27:02.275361Z", + "iopub.status.idle": "2021-02-21T05:27:02.278765Z", + "shell.execute_reply": "2021-02-21T05:27:02.278075Z" + }, + "papermill": { + "duration": 0.08369, + "end_time": "2021-02-21T05:27:02.278885", + "exception": false, + "start_time": "2021-02-21T05:27:02.195195", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# best_score = -1\n", + "# for d in range(1, 25): \n", + "# dt_model = DecisionTreeClassifier(max_depth = d)\n", + "# scores = cross_val_score(dt_model, X_trainval_scaled, y_trainval, cv=5, scoring='accuracy')\n", + " \n", + "# mean_score = scores.mean()\n", + " \n", + "# if mean_score > best_score:\n", + "# best_score = mean_score\n", + "# best_d = d\n", + "# print(best_d)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:02.439628Z", + "iopub.status.busy": "2021-02-21T05:27:02.438933Z", + "iopub.status.idle": "2021-02-21T05:27:02.607174Z", + "shell.execute_reply": "2021-02-21T05:27:02.606465Z" + }, + "papermill": { + "duration": 0.251238, + "end_time": "2021-02-21T05:27:02.607295", + "exception": false, + "start_time": "2021-02-21T05:27:02.356057", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "dt_model = DecisionTreeClassifier(criterion=criterion, max_depth=max_depth).fit(X_trainval_scaled, y_trainval)\n", + "\n", + "# confusion matrix on the held-out test set (75 cases per the recorded metrics)\n", + "plot_confusion_matrix(dt_model, \n", + " X_test_scaled, \n", + " y_test, \n", + " values_format='d', \n", + " display_labels=['Nondemented', 'Demented'])" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:02.771744Z", + "iopub.status.busy": "2021-02-21T05:27:02.771062Z", + "iopub.status.idle": "2021-02-21T05:27:02.780168Z", + "shell.execute_reply": "2021-02-21T05:27:02.780646Z" + }, + "papermill": { + "duration": 0.096343, + "end_time": "2021-02-21T05:27:02.780797", + "exception": false, + "start_time": "2021-02-21T05:27:02.684454", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy with DecisionTreeClassifier: 0.7751677852348994\n", + "Test accuracy with DecisionTreeClassifier: 0.8\n", + "Test recall with DecisionTreeClassifier: 0.5945945945945946\n", + "Test AUC with DecisionTreeClassifier: 0.7972972972972974\n" + ] + } + ], + "source": [ + "train_score = 0\n", + "test_score = 0\n", + "test_recall = 0\n", + "test_auc = 0\n", + "\n", + "dt_model = DecisionTreeClassifier(criterion=criterion, max_depth=max_depth).fit(X_trainval_scaled, y_trainval)\n", + "train_score = dt_model.score(X_trainval_scaled, y_trainval)\n", + "test_score = dt_model.score(X_test_scaled, y_test)\n", + "y_predict = dt_model.predict(X_test_scaled)\n", + "test_recall = recall_score(y_test, y_predict)\n", + "dt_fpr, dt_tpr, thresholds = roc_curve(y_test, y_predict)\n", + "test_auc = auc(dt_fpr, dt_tpr)\n", + "\n", + "print(\"Train accuracy with DecisionTreeClassifier:\", train_score)\n", + "print(\"Test accuracy with 
DecisionTreeClassifier:\", test_score)\n", + "print(\"Test recall with DecisionTreeClassifier:\", test_recall)\n", + "print(\"Test AUC with DecisionTreeClassifier:\", test_auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.079404, + "end_time": "2021-02-21T05:27:02.938142", + "exception": false, + "start_time": "2021-02-21T05:27:02.858738", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "## Plot ROC and compare AUC" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": { + "execution": { + "iopub.execute_input": "2021-02-21T05:27:03.110051Z", + "iopub.status.busy": "2021-02-21T05:27:03.109234Z", + "iopub.status.idle": "2021-02-21T05:27:03.296587Z", + "shell.execute_reply": "2021-02-21T05:27:03.295893Z" + }, + "papermill": { + "duration": 0.279336, + "end_time": "2021-02-21T05:27:03.296703", + "exception": false, + "start_time": "2021-02-21T05:27:03.017367", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(5, 5), dpi=100)\n", + "plt.plot(svm_fpr, svm_tpr, linestyle='-', label='SVM')\n", + "plt.plot(lgr_fpr, lgr_tpr, marker='.', label='Logistic')\n", + "plt.plot(rfc_fpr, rfc_tpr, linestyle=':', label='Random Forest')\n", + "plt.plot(dt_fpr, dt_tpr, linestyle='-.', label='Decision Tree')\n", + "\n", + "plt.xlabel('False Positive Rate')\n", + "plt.ylabel('True Positive Rate')\n", + "\n", + "plt.legend()\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + }, + "papermill": { + "duration": 147.570544, + "end_time": "2021-02-21T05:27:03.491568", + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2021-02-21T05:24:35.921024", + "version": "2.1.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}