{ "cells": [ { "cell_type": "code", "execution_count": 64, "id": "01c46189-a598-4bfc-9565-a914346decf7", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from sklearn.model_selection import train_test_split, GridSearchCV\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.neighbors import KNeighborsClassifier\n", "# from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, RocCurveDisplay\n", "from sklearn.metrics import confusion_matrix, recall_score, precision_score, roc_auc_score, roc_curve, accuracy_score, RocCurveDisplay\n", "\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 65, "id": "53f58e2c-169e-434c-9b8e-1bd0cb5f7715", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | Enroll | \n", "Age | \n", "Income | \n", "Hours | \n", "
|---|---|---|---|---|
| 0 | \n", "1 | \n", "26 | \n", "18000 | \n", "14 | \n", "
| 1 | \n", "0 | \n", "43 | \n", "13000 | \n", "9 | \n", "
| 2 | \n", "1 | \n", "55 | \n", "42000 | \n", "16 | \n", "
| 3 | \n", "1 | \n", "55 | \n", "100000 | \n", "13 | \n", "
| 4 | \n", "0 | \n", "55 | \n", "13000 | \n", "12 | \n", "
| \n", " | Enroll | \n", "Age | \n", "Income | \n", "Hours | \n", "
|---|---|---|---|---|
| count | \n", "1000.000000 | \n", "1000.000000 | \n", "1000.000000 | \n", "1000.000000 | \n", "
| mean | \n", "0.403000 | \n", "44.582000 | \n", "68340.000000 | \n", "10.182000 | \n", "
| std | \n", "0.490746 | \n", "13.876737 | \n", "44466.928247 | \n", "4.671263 | \n", "
| min | \n", "0.000000 | \n", "21.000000 | \n", "1000.000000 | \n", "2.000000 | \n", "
| 25% | \n", "0.000000 | \n", "32.000000 | \n", "31000.000000 | \n", "6.000000 | \n", "
| 50% | \n", "0.000000 | \n", "45.000000 | \n", "64000.000000 | \n", "10.000000 | \n", "
| 75% | \n", "1.000000 | \n", "57.000000 | \n", "97000.000000 | \n", "14.000000 | \n", "
| max | \n", "1.000000 | \n", "68.000000 | \n", "198000.000000 | \n", "18.000000 | \n", "
| \n", " | Age | \n", "Income | \n", "Hours | \n", "
|---|---|---|---|
| 0 | \n", "26 | \n", "18000 | \n", "14 | \n", "
| 1 | \n", "43 | \n", "13000 | \n", "9 | \n", "
| 2 | \n", "55 | \n", "42000 | \n", "16 | \n", "
| 3 | \n", "55 | \n", "100000 | \n", "13 | \n", "
| 4 | \n", "55 | \n", "13000 | \n", "12 | \n", "
| \n", " | Age | \n", "Income | \n", "Hours | \n", "Enroll | \n", "
|---|---|---|---|---|
| 0 | \n", "-1.339746 | \n", "-1.132644 | \n", "0.817747 | \n", "1 | \n", "
| 1 | \n", "-0.114061 | \n", "-1.245143 | \n", "-0.253163 | \n", "0 | \n", "
| 2 | \n", "0.751128 | \n", "-0.592647 | \n", "1.246111 | \n", "1 | \n", "
| 3 | \n", "0.751128 | \n", "0.712346 | \n", "0.603565 | \n", "1 | \n", "
| 4 | \n", "0.751128 | \n", "-1.245143 | \n", "0.389383 | \n", "0 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 995 | \n", "-0.402457 | \n", "-1.425142 | \n", "0.603565 | \n", "0 | \n", "
| 996 | \n", "-0.474556 | \n", "-1.425142 | \n", "-1.538255 | \n", "0 | \n", "
| 997 | \n", "0.751128 | \n", "0.037350 | \n", "1.460293 | \n", "1 | \n", "
| 998 | \n", "1.688417 | \n", "-0.097649 | \n", "1.031929 | \n", "1 | \n", "
| 999 | \n", "0.246435 | \n", "-0.030150 | \n", "1.674475 | \n", "0 | \n", "
1000 rows × 4 columns
\n", "| \n", " | Age | \n", "Income | \n", "Hours | \n", "
|---|---|---|---|
| 808 | \n", "0.030137 | \n", "0.824845 | \n", "1.674475 | \n", "
| 393 | \n", "-0.618755 | \n", "-0.975145 | \n", "0.175201 | \n", "
| 416 | \n", "-0.979250 | \n", "-1.425142 | \n", "0.817747 | \n", "
| 486 | \n", "0.679029 | \n", "-1.065144 | \n", "0.817747 | \n", "
| 422 | \n", "-0.114061 | \n", "-1.267643 | \n", "-1.324073 | \n", "
GridSearchCV(cv=5, estimator=KNeighborsClassifier(),\n",
" param_grid={'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]})In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. GridSearchCV(cv=5, estimator=KNeighborsClassifier(),\n",
" param_grid={'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]})KNeighborsClassifier()
KNeighborsClassifier()
| \n", " | mean_fit_time | \n", "std_fit_time | \n", "mean_score_time | \n", "std_score_time | \n", "param_n_neighbors | \n", "params | \n", "split0_test_score | \n", "split1_test_score | \n", "split2_test_score | \n", "split3_test_score | \n", "split4_test_score | \n", "mean_test_score | \n", "std_test_score | \n", "rank_test_score | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "0.001109 | \n", "0.000455 | \n", "0.003694 | \n", "0.000536 | \n", "1 | \n", "{'n_neighbors': 1} | \n", "0.866667 | \n", "0.908333 | \n", "0.866667 | \n", "0.900000 | \n", "0.916667 | \n", "0.891667 | \n", "0.021082 | \n", "9 | \n", "
| 1 | \n", "0.000886 | \n", "0.000090 | \n", "0.003221 | \n", "0.000229 | \n", "2 | \n", "{'n_neighbors': 2} | \n", "0.833333 | \n", "0.858333 | \n", "0.883333 | \n", "0.866667 | \n", "0.866667 | \n", "0.861667 | \n", "0.016330 | \n", "10 | \n", "
| 2 | \n", "0.000670 | \n", "0.000071 | \n", "0.002760 | \n", "0.000318 | \n", "3 | \n", "{'n_neighbors': 3} | \n", "0.908333 | \n", "0.916667 | \n", "0.883333 | \n", "0.933333 | \n", "0.900000 | \n", "0.908333 | \n", "0.016667 | \n", "2 | \n", "
| 3 | \n", "0.000545 | \n", "0.000020 | \n", "0.002311 | \n", "0.000066 | \n", "4 | \n", "{'n_neighbors': 4} | \n", "0.850000 | \n", "0.925000 | \n", "0.858333 | \n", "0.933333 | \n", "0.900000 | \n", "0.893333 | \n", "0.033912 | \n", "8 | \n", "
| 4 | \n", "0.000489 | \n", "0.000010 | \n", "0.002191 | \n", "0.000051 | \n", "5 | \n", "{'n_neighbors': 5} | \n", "0.875000 | \n", "0.941667 | \n", "0.883333 | \n", "0.950000 | \n", "0.925000 | \n", "0.915000 | \n", "0.030459 | \n", "1 | \n", "
| 5 | \n", "0.000457 | \n", "0.000006 | \n", "0.002055 | \n", "0.000038 | \n", "6 | \n", "{'n_neighbors': 6} | \n", "0.866667 | \n", "0.941667 | \n", "0.858333 | \n", "0.933333 | \n", "0.916667 | \n", "0.903333 | \n", "0.034400 | \n", "5 | \n", "
| 6 | \n", "0.000439 | \n", "0.000006 | \n", "0.001990 | \n", "0.000035 | \n", "7 | \n", "{'n_neighbors': 7} | \n", "0.875000 | \n", "0.933333 | \n", "0.858333 | \n", "0.941667 | \n", "0.916667 | \n", "0.905000 | \n", "0.032745 | \n", "4 | \n", "
| 7 | \n", "0.000436 | \n", "0.000005 | \n", "0.001974 | \n", "0.000008 | \n", "8 | \n", "{'n_neighbors': 8} | \n", "0.866667 | \n", "0.933333 | \n", "0.858333 | \n", "0.933333 | \n", "0.908333 | \n", "0.900000 | \n", "0.032059 | \n", "7 | \n", "
| 8 | \n", "0.000430 | \n", "0.000004 | \n", "0.001981 | \n", "0.000016 | \n", "9 | \n", "{'n_neighbors': 9} | \n", "0.866667 | \n", "0.925000 | \n", "0.883333 | \n", "0.941667 | \n", "0.916667 | \n", "0.906667 | \n", "0.027588 | \n", "3 | \n", "
| 9 | \n", "0.000472 | \n", "0.000047 | \n", "0.002084 | \n", "0.000079 | \n", "10 | \n", "{'n_neighbors': 10} | \n", "0.841667 | \n", "0.941667 | \n", "0.866667 | \n", "0.941667 | \n", "0.916667 | \n", "0.901667 | \n", "0.040620 | \n", "6 | \n", "
| \n", " | Actual | \n", "Predicted | \n", "
|---|---|---|
| 489 | \n", "1 | \n", "1 | \n", "
| 241 | \n", "0 | \n", "0 | \n", "
| 119 | \n", "0 | \n", "0 | \n", "
| 577 | \n", "0 | \n", "0 | \n", "
| 287 | \n", "0 | \n", "0 | \n", "
| ... | \n", "... | \n", "... | \n", "
| 804 | \n", "1 | \n", "1 | \n", "
| 974 | \n", "1 | \n", "1 | \n", "
| 810 | \n", "1 | \n", "1 | \n", "
| 395 | \n", "0 | \n", "0 | \n", "
| 861 | \n", "0 | \n", "0 | \n", "
400 rows × 2 columns
\n", "