{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# k-NN classification of gym enrollment\n",
    "\n",
    "Predict `Enroll` from `Age`, `Income` and `Hours` (`gym.xlsx`) with a k-nearest-neighbours classifier.\n",
    "\n",
    "- Features are standardised inside a `Pipeline`, so the scaler is fit on training rows only (no test-set leakage, including inside each cross-validation fold).\n",
    "- `k` is selected by 5-fold cross-validation on the training set.\n",
    "- The selected model is evaluated once on a held-out 40% test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01c46189-a598-4bfc-9565-a914346decf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from sklearn.metrics import (\n",
    "    RocCurveDisplay,\n",
    "    accuracy_score,\n",
    "    confusion_matrix,\n",
    "    f1_score,\n",
    "    precision_score,\n",
    "    recall_score,\n",
    "    roc_auc_score,\n",
    "    roc_curve,\n",
    ")\n",
    "from sklearn.model_selection import GridSearchCV, train_test_split\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92834d50-c430-4431-b5d9-81b10415dca1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: everything a reader might want to tune lives here.\n",
    "DATA_PATH = \"gym.xlsx\"\n",
    "FEATURES = [\"Age\", \"Income\", \"Hours\"]\n",
    "TARGET = \"Enroll\"\n",
    "TEST_SIZE = 0.4\n",
    "SEED = 1\n",
    "CV_FOLDS = 5\n",
    "K_VALUES = list(range(1, 11))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-load",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53f58e2c-169e-434c-9b8e-1bd0cb5f7715",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data\n",
    "df = pd.read_excel(DATA_PATH)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5809395-d339-4a18-bef1-ea12b06cc9f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46f2b835-2f7d-4d65-a1e9-5dacaf6a370e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code-validate-input",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail early if the spreadsheet does not have the shape the analysis assumes.\n",
    "missing_cols = set(FEATURES + [TARGET]) - set(df.columns)\n",
    "assert not missing_cols, f\"missing expected columns: {missing_cols}\"\n",
    "assert df[FEATURES + [TARGET]].notna().all().all(), \"unexpected missing values\"\n",
    "assert set(df[TARGET].unique()) <= {0, 1}, \"target must be coded 0/1\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-split",
   "metadata": {},
   "source": [
    "## Features, target and train/test split\n",
    "\n",
    "The split happens **before** any scaling so that no statistic of the test rows can reach the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b789418-6284-41cb-9c41-5f2b5d2114ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Separate features and target variable\n",
    "X = df[FEATURES]\n",
    "y = df[TARGET]\n",
    "X.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9c82753-ed45-4d63-9f5a-75fd54d196fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class balance of the target (a summary instead of dumping all rows)\n",
    "y.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e6402c1-05f3-4cfe-b96d-9758d36aed57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the raw (unscaled) data into training and testing sets\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X,\n",
    "    y,\n",
    "    test_size=TEST_SIZE,\n",
    "    random_state=SEED,\n",
    "    stratify=y,\n",
    ")\n",
    "assert len(X_train) + len(X_test) == len(df)\n",
    "X_train.shape, X_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9b32cbd-f135-4c17-9b52-7078722caae3",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-model",
   "metadata": {},
   "source": [
    "## Model selection\n",
    "\n",
    "`StandardScaler` and `KNeighborsClassifier` are chained in a `Pipeline`. `GridSearchCV` refits the whole pipeline on each training fold, so scaling parameters never see the validation fold or the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2226b885-9cff-41aa-a0f4-35ef5682daaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perform k-NN classification with cross-validation to find the best k\n",
    "knn_pipeline = Pipeline(\n",
    "    [\n",
    "        (\"scaler\", StandardScaler()),\n",
    "        (\"knn\", KNeighborsClassifier()),\n",
    "    ]\n",
    ")\n",
    "# Pipeline step parameters are addressed as <step>__<param>\n",
    "param_grid = {\"knn__n_neighbors\": K_VALUES}\n",
    "grid_search = GridSearchCV(knn_pipeline, param_grid, cv=CV_FOLDS)\n",
    "grid_search.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "779e861d-04cf-4d8f-94e0-e42c96525b61",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Best parameters:\", grid_search.best_params_)\n",
    "print(\"Best cross-validation score:\", grid_search.best_score_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "046a78bc-4b77-44d9-a5d9-3d6084d1ad16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display detailed results (score columns only; timing columns are noise here)\n",
    "cv_results = pd.DataFrame(grid_search.cv_results_)\n",
    "cv_results[\n",
    "    [\"param_knn__n_neighbors\", \"mean_test_score\", \"std_test_score\", \"rank_test_score\"]\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-evaluate",
   "metadata": {},
   "source": [
    "## Evaluation on the held-out test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f78ee56-b9c8-4d1a-aef9-ad37dfa41858",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make predictions on the test set (the pipeline scales X_test with training statistics)\n",
    "best_knn = grid_search.best_estimator_\n",
    "predictions = best_knn.predict(X_test)\n",
    "# Combine y_test and predictions into a DataFrame\n",
    "results_df = pd.DataFrame({\"Actual\": y_test, \"Predicted\": predictions})\n",
    "results_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "777a6302-6139-4199-9ed0-68f5e8938612",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confusion matrix: scikit-learn returns rows = actual class, columns = predicted class\n",
    "conf_matrix = confusion_matrix(y_test, predictions, labels=[0, 1])\n",
    "\n",
    "# Transpose for display so that rows are predicted and columns are actual,\n",
    "# matching the row/column labels below.\n",
    "conf_matrix_df = pd.DataFrame(\n",
    "    conf_matrix.T,\n",
    "    index=[\"Predicted_0\", \"Predicted_1\"],\n",
    "    columns=[\"Actual_0\", \"Actual_1\"],\n",
    ")\n",
    "conf_matrix_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa493759-f506-4278-a212-4a3a0b5db15a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate metrics\n",
    "# ravel() of the 2x2 matrix (rows = actual, columns = predicted) yields tn, fp, fn, tp\n",
    "tn, fp, fn, tp = conf_matrix.ravel()\n",
    "test_metrics = pd.Series(\n",
    "    {\n",
    "        \"Accuracy\": accuracy_score(y_test, predictions),\n",
    "        \"Precision\": precision_score(y_test, predictions),\n",
    "        \"Recall (Sensitivity)\": recall_score(y_test, predictions),\n",
    "        \"Specificity\": tn / (tn + fp),\n",
    "        \"F1 Score\": f1_score(y_test, predictions),\n",
    "    }\n",
    ")\n",
    "test_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9a97f8e-0ebc-4b51-a4a1-ff059dd75535",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict probabilities of the positive class (column 1) for the ROC curve\n",
    "probs = best_knn.predict_proba(X_test)[:, 1]\n",
    "probs[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b23ed9f8-b1c0-491c-b048-112a0797299f",
   "metadata": {},
   "outputs": [],
   "source": [
    "roc_auc = roc_auc_score(y_test, probs)\n",
    "print(\"ROC AUC:\", roc_auc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffe52a48-533d-43db-b845-8aab740638ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot ROC curve\n",
    "fpr, tpr, _ = roc_curve(y_test, probs)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(fpr, tpr, label=f\"ROC curve (area = {roc_auc:.2f})\")\n",
    "ax.plot([0, 1], [0, 1], \"k--\")  # chance diagonal\n",
    "ax.set_xlim([0.0, 1.0])\n",
    "ax.set_ylim([0.0, 1.05])\n",
    "ax.set_xlabel(\"False Positive Rate\")\n",
    "ax.set_ylabel(\"True Positive Rate\")\n",
    "ax.set_title(\"Receiver Operating Characteristic\")\n",
    "ax.legend(loc=\"lower right\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}