-
Notifications
You must be signed in to change notification settings - Fork 0
/
ClassifAI_ 2 - Linear Regression
1 lines (1 loc) · 15.6 KB
/
ClassifAI_ 2 - Linear Regression
1
{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "ClassifAI: 2 - Linear Regression",
   "provenance": [
    {
     "file_id": "14NGE9UHxFaiLs9XXzKGTIk0NHTbA8yM_",
     "timestamp": 1654019517612
    },
    {
     "file_id": "1lkCV4BrFpc7Su5LWsu-l0at15tHAKEtm",
     "timestamp": 1653848445956
    }
   ],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ClassifAI 2: Linear Regression\n",
    "\n",
    "Fit an ordinary least-squares line that predicts diabetes disease progression from a single feature (BMI) of the scikit-learn diabetes dataset, then evaluate it on the last 20 samples, which are held out as a test set.\n",
    "\n",
    "**Result:** coefficient ≈ 938, test mean squared error ≈ 2548, test $R^2$ ≈ 0.47."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "eK6wcn072y9Z"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn import datasets, linear_model\n",
    "from sklearn.metrics import mean_squared_error, r2_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tunable settings\n",
    "BMI_COLUMN = 2  # index of the BMI feature (third column) in the diabetes feature matrix\n",
    "TEST_SIZE = 20  # number of trailing samples held out for evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5c3ofkx48djO"
   },
   "source": [
    "## Get dataset and format it\n",
    "\n",
    "Load the diabetes data as a feature matrix and a target vector, keep only the BMI feature, and split off a test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "dqUgjgE08FQM"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[ 0.03807591 0.05068012 0.06169621 0.02187235 -0.0442235 -0.03482076\n",
      " -0.04340085 -0.00259226 0.01990842 -0.01764613]\n"
     ]
    }
   ],
   "source": [
    "diabetes_X, diabetes_y = datasets.load_diabetes(return_X_y=True)\n",
    "print(diabetes_X[0])  # first data point (all features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "3c63AiBa8MQZ"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[[ 0.06169621]\n",
      " [-0.05147406]\n",
      " [ 0.04445121]\n",
      " [-0.01159501]\n",
      " [-0.03638469]]\n"
     ]
    }
   ],
   "source": [
    "# Keep only the BMI column. Indexing with np.newaxis keeps the result 2-D with shape\n",
    "# (n_samples, 1), because scikit-learn estimators expect a 2-D feature matrix.\n",
    "# The result gets a new name so that re-running this cell is safe.\n",
    "diabetes_bmi = diabetes_X[:, np.newaxis, BMI_COLUMN]\n",
    "print(diabetes_bmi[:5])  # first five BMI values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "zhypKwZC8NNQ"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[151. 75. 141. 206. 135.]\n"
     ]
    }
   ],
   "source": [
    "# Hold out the last TEST_SIZE samples for testing (the data is not shuffled).\n",
    "assert 0 < TEST_SIZE < len(diabetes_y), \"TEST_SIZE must leave both a train and a test set\"\n",
    "\n",
    "diabetes_X_train = diabetes_bmi[:-TEST_SIZE]\n",
    "diabetes_X_test = diabetes_bmi[-TEST_SIZE:]\n",
    "\n",
    "diabetes_y_train = diabetes_y[:-TEST_SIZE]\n",
    "diabetes_y_test = diabetes_y[-TEST_SIZE:]\n",
    "\n",
    "print(diabetes_y[:5])  # first five disease-progression targets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-7B_CEwH8Z3v"
   },
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "t7rO59FN8HlN"
   },
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "LinearRegression()"
      ]
     },
     "metadata": {},
     "execution_count": 6
    }
   ],
   "source": [
    "bmi_model = linear_model.LinearRegression()  # ordinary least squares\n",
    "bmi_model.fit(diabetes_X_train, diabetes_y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "GqKdBuV58HQV"
   },
   "outputs": [],
   "source": [
    "diabetes_y_pred = bmi_model.predict(diabetes_X_test)  # predictions for every test sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "8wrT8i2K8Nuw"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Coefficients: \n",
      " [938.23786125]\n",
      "Mean squared error: 2548.07\n",
      "Coefficient of determination: 0.47\n"
     ]
    }
   ],
   "source": [
    "mse = mean_squared_error(diabetes_y_test, diabetes_y_pred)\n",
    "r2 = r2_score(diabetes_y_test, diabetes_y_pred)  # coefficient of determination; 1.0 is a perfect prediction\n",
    "\n",
    "print('Coefficients: \\n', bmi_model.coef_)\n",
    "print(f'Mean squared error: {mse:.2f}')\n",
    "print(f'Coefficient of determination: {r2:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ixmoQost8q9c"
   },
   "source": [
    "## Plot outputs\n",
    "\n",
    "Held-out test points (black) against the fitted regression line (blue)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "-oQG12EcvaDU"
   },
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {}
    }
   ],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.scatter(diabetes_X_test, diabetes_y_test, color='black', label='test data')\n",
    "ax.plot(diabetes_X_test, diabetes_y_pred, color='blue', linewidth=3, label='fitted line')\n",
    "\n",
    "ax.set(title='BMI vs. diabetes disease progression (test set)',\n",
    "       xlabel='BMI', ylabel='Diabetes disease progression')\n",
    "ax.legend()\n",
    "\n",
    "# The dataset's features are pre-scaled (BMI values are around +/-0.06), so x tick\n",
    "# labels would not be in real BMI units; hide them. The y axis is in target units.\n",
    "ax.set_xticks(())\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusions\n",
    "\n",
    "- BMI alone gives $R^2$ ≈ 0.47 on the held-out samples.\n",
    "- The test set is only the last 20 rows and is not shuffled, so this estimate is noisy. Cross-validation or a randomized split would give a more reliable figure.\n",
    "- **TODO:** compare against a model that uses all ten features."
   ]
  }
 ]
}