diff --git a/subjects/ai/model-selection/README.md b/subjects/ai/model-selection/README.md index ca6702db5a..57bfa8caa7 100644 --- a/subjects/ai/model-selection/README.md +++ b/subjects/ai/model-selection/README.md @@ -33,13 +33,13 @@ I suggest to use the most recent one. ### Biais-Variance trade off, aka Underfitting/Overfitting: -- https://machinelearningmastery.com/gentle-introduction-to-the-bias-variance-trade-off-in-machine-learning/ +- [Bias-Variance Trade-Off in Machine Learning](https://machinelearningmastery.com/gentle-introduction-to-the-bias-variance-trade-off-in-machine-learning/) -- https://jakevdp.github.io/PythonDataScienceHandbook/05.03-hyperparameters-and-model-validation.html +- [Hyperparameters and Model Validation](https://jakevdp.github.io/PythonDataScienceHandbook/05.03-hyperparameters-and-model-validation.html) ### Cross-validation -- https://algotrading101.com/learn/train-test-split/ +- [Train/Test Split and Cross Validation](https://algotrading101.com/learn/train-test-split/) --- @@ -143,11 +143,11 @@ Standard deviation of scores on validation sets: ``` -**Note: It may be confusing that the key of the dictionary that returns the results on the validation sets is `test_score`. Sometimes, the validation sets are called test sets. In that case, we run the cross validation on X_train. It means that the scores are computed on sets in the initial train set. The X_test is not used for the cross-validation.** +> Note: It may be confusing that the key of the dictionary that returns the results on the validation sets is `test_score`. Sometimes, the validation sets are called test sets. In that case, we run the cross validation on X_train. It means that the scores are computed on sets in the initial train set. The X_test is not used for the cross-validation. -- https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html +- [sklearn.model_selection.cross_validate](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html) -- https://machinelearningmastery.com/how-to-configure-k-fold-cross-validation/ +- [Configure k-Fold Cross-Validation](https://machinelearningmastery.com/how-to-configure-k-fold-cross-validation/) --- @@ -192,13 +192,13 @@ X_train, X_test, y_train, y_test = train_test_split(X, _Hint_: The name of the metric to put in the parameter `scoring` is `neg_mean_squared_error`. The smaller the MSE is, the better the model is. At the contrary, The greater the R2 is the better the model is. `GridSearchCV` chooses the best model by selecting the one that maximized the score on the validation sets. And, in mathematic, maximizing a function or minimizing its opposite is equivalent. More details: -- https://stackoverflow.com/questions/21443865/scikit-learn-cross-validation-negative-values-with-mean-squared-error +- [scikit-learn cross validation](https://stackoverflow.com/questions/21443865/scikit-learn-cross-validation-negative-values-with-mean-squared-error) 2. Extract the best fitted estimator, print its parameters, its score on the validation set, and display `cv_results_`. 3. Compute the score on the test set. -**WARNING: For classification tasks using AUC score, an error or warning might occur if a fold contains only one class, rendering the AUC unable to be computed due to its definition.** +> **WARNING: For classification tasks using AUC score, an error or warning might occur if a fold contains only one class, rendering the AUC unable to be computed due to its definition.** --- @@ -206,7 +206,7 @@ _Hint_: The name of the metric to put in the parameter `scoring` is `neg_mean_sq # Exercise 4: Validation curve and Learning curve -The goal of this exercise is to learn to analyze the model's performance with two tools: +The goal of this exercise is to learn how to analyze the model's performance with two tools: - Validation curve - Learning curve @@ -229,7 +229,7 @@ X, y = make_classification(n_samples=100000, I do not expect that you implement all the plot from scratch, you'd better leverage the code here: -- https://scikit-learn.org/stable/auto_examples/model_selection/plot_validation_curve +- [Plotting Validation Curves](https://scikit-learn.org/stable/auto_examples/model_selection/plot_validation_curve) The plot should look like this: @@ -239,17 +239,13 @@ The plot should look like this: The interpretation is that from max_depth=10, the train score keeps increasing but the test score (or validation score) reaches a plateau. It means that choosing max_depth = 20 may lead to have an over fitted model. -Note: Given the time computation is is not possible to plot the validation curve for all parameters. It is useful to plot it for parameters that control the over fitting the most. - -More details: - -- https://chrisalbon.com/machine_learning/model_evaluation/plot_the_validation_curve/ +> Note: Given the time computation is is not possible to plot the validation curve for all parameters. It is useful to plot it for parameters that control the over fitting the most. 2. Let us assume the gridsearch returned `clf = RandomForestClassifier(max_depth=12)`. Let's check if the models under fits, over fit or fits correctly. Plot the learning curve. These two resources will help you a lot to understand how to analyze the learning curves and how to plot them: -- https://machinelearningmastery.com/learning-curves-for-diagnosing-machine-learning-model-performance/ +- [Learning Curves to Diagnose Machine Learning Model Performance](https://machinelearningmastery.com/learning-curves-for-diagnosing-machine-learning-model-performance/) -- https://scikit-learn.org/stable/auto_examples/model_selection/plot_learning_curve.html#sphx-glr-auto-examples-model-selection-plot-learning-curve-py +- [Plotting Learning Curves and Checking Models’ Scalability](https://scikit-learn.org/stable/auto_examples/model_selection/plot_learning_curve.html#sphx-glr-auto-examples-model-selection-plot-learning-curve-py) - **Re-use the function in the second resource**, change the cross validation to a classic 10-folds, run the learning curve data computation on all CPUs and plot the three plots as shown below.