{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "RcqBh1HEjvGX"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"df = pd.read_csv(\"https://raw.githubusercontent.com/kirenz/datasets/master/Hitters.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mbzK2GHBCUeA"
},
"source": [
"Data frame called “Hitters” with 20 variables and 322 observations of major league players\n",
"\n",
"We want to predict a baseball player’s salary on the basis of various statistics associated with performance in the previous year."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FzH_hgP4_B-9"
},
"source": [
"A data frame with 322 observations of major league players on the following 20 variables.\n",
"- AtBat Number of times at bat in 1986\n",
"- Hits Number of hits in 1986\n",
"- HmRun Number of home runs in 1986\n",
"- Runs Number of runs in 1986\n",
"- RBI Number of runs batted in in 1986\n",
"- Walks Number of walks in 1986\n",
"- Years Number of years in the major leagues\n",
"- CAtBat Number of times at bat during his career\n",
"- CHits Number of hits during his career\n",
"- CHmRun Number of home runs during his career\n",
"- CRuns Number of runs during his career\n",
"- CRBI Number of runs batted in during his career\n",
"- CWalks Number of walks during his career\n",
"- League A factor with levels A and N indicating player’s league at the end of 1986\n",
"- Division A factor with levels E and W indicating player’s division at the end of 1986\n",
"- PutOuts Number of put outs in 1986\n",
"- Assists Number of assists in 1986\n",
"- Errors Number of errors in 1986\n",
"- Salary 1987 annual salary on opening day in thousands of dollars\n",
"- NewLeague A factor with levels A and N indicating player’s league at the beginning of 1987"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 424
},
"id": "76Sa_jrQpNzc",
"outputId": "9ec35b84-7d85-4e76-bcec-353ce6b517cb"
},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" AtBat \n",
" Hits \n",
" HmRun \n",
" Runs \n",
" RBI \n",
" Walks \n",
" Years \n",
" CAtBat \n",
" CHits \n",
" CHmRun \n",
" CRuns \n",
" CRBI \n",
" CWalks \n",
" League \n",
" Division \n",
" PutOuts \n",
" Assists \n",
" Errors \n",
" Salary \n",
" NewLeague \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 293 \n",
" 66 \n",
" 1 \n",
" 30 \n",
" 29 \n",
" 14 \n",
" 1 \n",
" 293 \n",
" 66 \n",
" 1 \n",
" 30 \n",
" 29 \n",
" 14 \n",
" A \n",
" E \n",
" 446 \n",
" 33 \n",
" 20 \n",
" NaN \n",
" A \n",
" \n",
" \n",
" 1 \n",
" 315 \n",
" 81 \n",
" 7 \n",
" 24 \n",
" 38 \n",
" 39 \n",
" 14 \n",
" 3449 \n",
" 835 \n",
" 69 \n",
" 321 \n",
" 414 \n",
" 375 \n",
" N \n",
" W \n",
" 632 \n",
" 43 \n",
" 10 \n",
" 475.0 \n",
" N \n",
" \n",
" \n",
" 2 \n",
" 479 \n",
" 130 \n",
" 18 \n",
" 66 \n",
" 72 \n",
" 76 \n",
" 3 \n",
" 1624 \n",
" 457 \n",
" 63 \n",
" 224 \n",
" 266 \n",
" 263 \n",
" A \n",
" W \n",
" 880 \n",
" 82 \n",
" 14 \n",
" 480.0 \n",
" A \n",
" \n",
" \n",
" 3 \n",
" 496 \n",
" 141 \n",
" 20 \n",
" 65 \n",
" 78 \n",
" 37 \n",
" 11 \n",
" 5628 \n",
" 1575 \n",
" 225 \n",
" 828 \n",
" 838 \n",
" 354 \n",
" N \n",
" E \n",
" 200 \n",
" 11 \n",
" 3 \n",
" 500.0 \n",
" N \n",
" \n",
" \n",
" 4 \n",
" 321 \n",
" 87 \n",
" 10 \n",
" 39 \n",
" 42 \n",
" 30 \n",
" 2 \n",
" 396 \n",
" 101 \n",
" 12 \n",
" 48 \n",
" 46 \n",
" 33 \n",
" N \n",
" E \n",
" 805 \n",
" 40 \n",
" 4 \n",
" 91.5 \n",
" N \n",
" \n",
" \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" \n",
" \n",
" 317 \n",
" 497 \n",
" 127 \n",
" 7 \n",
" 65 \n",
" 48 \n",
" 37 \n",
" 5 \n",
" 2703 \n",
" 806 \n",
" 32 \n",
" 379 \n",
" 311 \n",
" 138 \n",
" N \n",
" E \n",
" 325 \n",
" 9 \n",
" 3 \n",
" 700.0 \n",
" N \n",
" \n",
" \n",
" 318 \n",
" 492 \n",
" 136 \n",
" 5 \n",
" 76 \n",
" 50 \n",
" 94 \n",
" 12 \n",
" 5511 \n",
" 1511 \n",
" 39 \n",
" 897 \n",
" 451 \n",
" 875 \n",
" A \n",
" E \n",
" 313 \n",
" 381 \n",
" 20 \n",
" 875.0 \n",
" A \n",
" \n",
" \n",
" 319 \n",
" 475 \n",
" 126 \n",
" 3 \n",
" 61 \n",
" 43 \n",
" 52 \n",
" 6 \n",
" 1700 \n",
" 433 \n",
" 7 \n",
" 217 \n",
" 93 \n",
" 146 \n",
" A \n",
" W \n",
" 37 \n",
" 113 \n",
" 7 \n",
" 385.0 \n",
" A \n",
" \n",
" \n",
" 320 \n",
" 573 \n",
" 144 \n",
" 9 \n",
" 85 \n",
" 60 \n",
" 78 \n",
" 8 \n",
" 3198 \n",
" 857 \n",
" 97 \n",
" 470 \n",
" 420 \n",
" 332 \n",
" A \n",
" E \n",
" 1314 \n",
" 131 \n",
" 12 \n",
" 960.0 \n",
" A \n",
" \n",
" \n",
" 321 \n",
" 631 \n",
" 170 \n",
" 9 \n",
" 77 \n",
" 44 \n",
" 31 \n",
" 11 \n",
" 4908 \n",
" 1457 \n",
" 30 \n",
" 775 \n",
" 357 \n",
" 249 \n",
" A \n",
" W \n",
" 408 \n",
" 4 \n",
" 3 \n",
" 1000.0 \n",
" A \n",
" \n",
" \n",
"
\n",
"
322 rows × 20 columns
\n",
"
"
],
"text/plain": [
" AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun \\\n",
"0 293 66 1 30 29 14 1 293 66 1 \n",
"1 315 81 7 24 38 39 14 3449 835 69 \n",
"2 479 130 18 66 72 76 3 1624 457 63 \n",
"3 496 141 20 65 78 37 11 5628 1575 225 \n",
"4 321 87 10 39 42 30 2 396 101 12 \n",
".. ... ... ... ... ... ... ... ... ... ... \n",
"317 497 127 7 65 48 37 5 2703 806 32 \n",
"318 492 136 5 76 50 94 12 5511 1511 39 \n",
"319 475 126 3 61 43 52 6 1700 433 7 \n",
"320 573 144 9 85 60 78 8 3198 857 97 \n",
"321 631 170 9 77 44 31 11 4908 1457 30 \n",
"\n",
" CRuns CRBI CWalks League Division PutOuts Assists Errors Salary \\\n",
"0 30 29 14 A E 446 33 20 NaN \n",
"1 321 414 375 N W 632 43 10 475.0 \n",
"2 224 266 263 A W 880 82 14 480.0 \n",
"3 828 838 354 N E 200 11 3 500.0 \n",
"4 48 46 33 N E 805 40 4 91.5 \n",
".. ... ... ... ... ... ... ... ... ... \n",
"317 379 311 138 N E 325 9 3 700.0 \n",
"318 897 451 875 A E 313 381 20 875.0 \n",
"319 217 93 146 A W 37 113 7 385.0 \n",
"320 470 420 332 A E 1314 131 12 960.0 \n",
"321 775 357 249 A W 408 4 3 1000.0 \n",
"\n",
" NewLeague \n",
"0 A \n",
"1 N \n",
"2 A \n",
"3 N \n",
"4 N \n",
".. ... \n",
"317 N \n",
"318 A \n",
"319 A \n",
"320 A \n",
"321 A \n",
"\n",
"[322 rows x 20 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qwa2XRUJpVuS",
"outputId": "78a35cbe-cfb0-402a-d065-d8f95ea3e67b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"RangeIndex: 322 entries, 0 to 321\n",
"Data columns (total 20 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 AtBat 322 non-null int64 \n",
" 1 Hits 322 non-null int64 \n",
" 2 HmRun 322 non-null int64 \n",
" 3 Runs 322 non-null int64 \n",
" 4 RBI 322 non-null int64 \n",
" 5 Walks 322 non-null int64 \n",
" 6 Years 322 non-null int64 \n",
" 7 CAtBat 322 non-null int64 \n",
" 8 CHits 322 non-null int64 \n",
" 9 CHmRun 322 non-null int64 \n",
" 10 CRuns 322 non-null int64 \n",
" 11 CRBI 322 non-null int64 \n",
" 12 CWalks 322 non-null int64 \n",
" 13 League 322 non-null object \n",
" 14 Division 322 non-null object \n",
" 15 PutOuts 322 non-null int64 \n",
" 16 Assists 322 non-null int64 \n",
" 17 Errors 322 non-null int64 \n",
" 18 Salary 263 non-null float64\n",
" 19 NewLeague 322 non-null object \n",
"dtypes: float64(1), int64(16), object(3)\n",
"memory usage: 50.4+ KB\n"
]
}
],
"source": [
"df.info()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 320
},
"id": "AANCReIW_HY3",
"outputId": "7d9994af-961e-4331-a092-5e63550f09e8"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" AtBat \n",
" Hits \n",
" HmRun \n",
" Runs \n",
" RBI \n",
" Walks \n",
" Years \n",
" CAtBat \n",
" CHits \n",
" CHmRun \n",
" CRuns \n",
" CRBI \n",
" CWalks \n",
" PutOuts \n",
" Assists \n",
" Errors \n",
" Salary \n",
" \n",
" \n",
" \n",
" \n",
" count \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.00000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 322.000000 \n",
" 263.000000 \n",
" \n",
" \n",
" mean \n",
" 380.928571 \n",
" 101.024845 \n",
" 10.770186 \n",
" 50.909938 \n",
" 48.027950 \n",
" 38.742236 \n",
" 7.444099 \n",
" 2648.68323 \n",
" 717.571429 \n",
" 69.490683 \n",
" 358.795031 \n",
" 330.118012 \n",
" 260.239130 \n",
" 288.937888 \n",
" 106.913043 \n",
" 8.040373 \n",
" 535.925882 \n",
" \n",
" \n",
" std \n",
" 153.404981 \n",
" 46.454741 \n",
" 8.709037 \n",
" 26.024095 \n",
" 26.166895 \n",
" 21.639327 \n",
" 4.926087 \n",
" 2324.20587 \n",
" 654.472627 \n",
" 86.266061 \n",
" 334.105886 \n",
" 333.219617 \n",
" 267.058085 \n",
" 280.704614 \n",
" 136.854876 \n",
" 6.368359 \n",
" 451.118681 \n",
" \n",
" \n",
" min \n",
" 16.000000 \n",
" 1.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" 19.00000 \n",
" 4.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 67.500000 \n",
" \n",
" \n",
" 25% \n",
" 255.250000 \n",
" 64.000000 \n",
" 4.000000 \n",
" 30.250000 \n",
" 28.000000 \n",
" 22.000000 \n",
" 4.000000 \n",
" 816.75000 \n",
" 209.000000 \n",
" 14.000000 \n",
" 100.250000 \n",
" 88.750000 \n",
" 67.250000 \n",
" 109.250000 \n",
" 7.000000 \n",
" 3.000000 \n",
" 190.000000 \n",
" \n",
" \n",
" 50% \n",
" 379.500000 \n",
" 96.000000 \n",
" 8.000000 \n",
" 48.000000 \n",
" 44.000000 \n",
" 35.000000 \n",
" 6.000000 \n",
" 1928.00000 \n",
" 508.000000 \n",
" 37.500000 \n",
" 247.000000 \n",
" 220.500000 \n",
" 170.500000 \n",
" 212.000000 \n",
" 39.500000 \n",
" 6.000000 \n",
" 425.000000 \n",
" \n",
" \n",
" 75% \n",
" 512.000000 \n",
" 137.000000 \n",
" 16.000000 \n",
" 69.000000 \n",
" 64.750000 \n",
" 53.000000 \n",
" 11.000000 \n",
" 3924.25000 \n",
" 1059.250000 \n",
" 90.000000 \n",
" 526.250000 \n",
" 426.250000 \n",
" 339.250000 \n",
" 325.000000 \n",
" 166.000000 \n",
" 11.000000 \n",
" 750.000000 \n",
" \n",
" \n",
" max \n",
" 687.000000 \n",
" 238.000000 \n",
" 40.000000 \n",
" 130.000000 \n",
" 121.000000 \n",
" 105.000000 \n",
" 24.000000 \n",
" 14053.00000 \n",
" 4256.000000 \n",
" 548.000000 \n",
" 2165.000000 \n",
" 1659.000000 \n",
" 1566.000000 \n",
" 1378.000000 \n",
" 492.000000 \n",
" 32.000000 \n",
" 2460.000000 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" AtBat Hits HmRun Runs RBI Walks \\\n",
"count 322.000000 322.000000 322.000000 322.000000 322.000000 322.000000 \n",
"mean 380.928571 101.024845 10.770186 50.909938 48.027950 38.742236 \n",
"std 153.404981 46.454741 8.709037 26.024095 26.166895 21.639327 \n",
"min 16.000000 1.000000 0.000000 0.000000 0.000000 0.000000 \n",
"25% 255.250000 64.000000 4.000000 30.250000 28.000000 22.000000 \n",
"50% 379.500000 96.000000 8.000000 48.000000 44.000000 35.000000 \n",
"75% 512.000000 137.000000 16.000000 69.000000 64.750000 53.000000 \n",
"max 687.000000 238.000000 40.000000 130.000000 121.000000 105.000000 \n",
"\n",
" Years CAtBat CHits CHmRun CRuns \\\n",
"count 322.000000 322.00000 322.000000 322.000000 322.000000 \n",
"mean 7.444099 2648.68323 717.571429 69.490683 358.795031 \n",
"std 4.926087 2324.20587 654.472627 86.266061 334.105886 \n",
"min 1.000000 19.00000 4.000000 0.000000 1.000000 \n",
"25% 4.000000 816.75000 209.000000 14.000000 100.250000 \n",
"50% 6.000000 1928.00000 508.000000 37.500000 247.000000 \n",
"75% 11.000000 3924.25000 1059.250000 90.000000 526.250000 \n",
"max 24.000000 14053.00000 4256.000000 548.000000 2165.000000 \n",
"\n",
" CRBI CWalks PutOuts Assists Errors \\\n",
"count 322.000000 322.000000 322.000000 322.000000 322.000000 \n",
"mean 330.118012 260.239130 288.937888 106.913043 8.040373 \n",
"std 333.219617 267.058085 280.704614 136.854876 6.368359 \n",
"min 0.000000 0.000000 0.000000 0.000000 0.000000 \n",
"25% 88.750000 67.250000 109.250000 7.000000 3.000000 \n",
"50% 220.500000 170.500000 212.000000 39.500000 6.000000 \n",
"75% 426.250000 339.250000 325.000000 166.000000 11.000000 \n",
"max 1659.000000 1566.000000 1378.000000 492.000000 32.000000 \n",
"\n",
" Salary \n",
"count 263.000000 \n",
"mean 535.925882 \n",
"std 451.118681 \n",
"min 67.500000 \n",
"25% 190.000000 \n",
"50% 425.000000 \n",
"75% 750.000000 \n",
"max 2460.000000 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.describe()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1M9C2NDVpY9f",
"outputId": "fac8ac8f-7fff-417a-b5b9-19260407e84d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AtBat 0\n",
"Hits 0\n",
"HmRun 0\n",
"Runs 0\n",
"RBI 0\n",
"Walks 0\n",
"Years 0\n",
"CAtBat 0\n",
"CHits 0\n",
"CHmRun 0\n",
"CRuns 0\n",
"CRBI 0\n",
"CWalks 0\n",
"League 0\n",
"Division 0\n",
"PutOuts 0\n",
"Assists 0\n",
"Errors 0\n",
"Salary 59\n",
"NewLeague 0\n",
"dtype: int64\n"
]
}
],
"source": [
"print(df.isnull().sum())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "08xUiroOpcBg"
},
"outputs": [],
"source": [
"# drop missing cases\n",
"df = df.dropna()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "lCyrOVNApedI"
},
"outputs": [],
"source": [
"dummies = pd.get_dummies(df[['League', 'Division','NewLeague']])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nQY7hd8PphJo",
"outputId": "f3f522e8-d044-472d-ca9c-538df73968b3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Index: 263 entries, 1 to 321\n",
"Data columns (total 6 columns):\n",
" # Column Non-Null Count Dtype\n",
"--- ------ -------------- -----\n",
" 0 League_A 263 non-null bool \n",
" 1 League_N 263 non-null bool \n",
" 2 Division_E 263 non-null bool \n",
" 3 Division_W 263 non-null bool \n",
" 4 NewLeague_A 263 non-null bool \n",
" 5 NewLeague_N 263 non-null bool \n",
"dtypes: bool(6)\n",
"memory usage: 3.6 KB\n"
]
}
],
"source": [
"dummies.info()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "B7G4lBs3pmLY",
"outputId": "4e3964ba-c065-4750-f4d8-0cf893ffe7c4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" League_A League_N Division_E Division_W NewLeague_A NewLeague_N\n",
"1 False True False True False True\n",
"2 True False False True True False\n",
"3 False True True False False True\n",
"4 False True True False False True\n",
"5 True False False True True False\n"
]
}
],
"source": [
"print(dummies.head())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "CTlRS474pott"
},
"outputs": [],
"source": [
"y = df['Salary']"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "8UmwTAHcprrq"
},
"outputs": [],
"source": [
"X_numerical = df.drop(['Salary', 'League', 'Division', 'NewLeague'], axis=1).astype('float64')"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "sOACqH5VpuZ2",
"outputId": "c11db38f-e1b6-4bdd-b69e-3ef22b4abe36"
},
"outputs": [
{
"data": {
"text/plain": [
"Index(['AtBat', 'Hits', 'HmRun', 'Runs', 'RBI', 'Walks', 'Years', 'CAtBat',\n",
" 'CHits', 'CHmRun', 'CRuns', 'CRBI', 'CWalks', 'PutOuts', 'Assists',\n",
" 'Errors'],\n",
" dtype='object')"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list_numerical = X_numerical.columns\n",
"list_numerical"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LW5DgH3ZpxSd",
"outputId": "fb64a024-de92-4a35-8b7d-b66492bfac5f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Index: 263 entries, 1 to 321\n",
"Data columns (total 19 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 AtBat 263 non-null float64\n",
" 1 Hits 263 non-null float64\n",
" 2 HmRun 263 non-null float64\n",
" 3 Runs 263 non-null float64\n",
" 4 RBI 263 non-null float64\n",
" 5 Walks 263 non-null float64\n",
" 6 Years 263 non-null float64\n",
" 7 CAtBat 263 non-null float64\n",
" 8 CHits 263 non-null float64\n",
" 9 CHmRun 263 non-null float64\n",
" 10 CRuns 263 non-null float64\n",
" 11 CRBI 263 non-null float64\n",
" 12 CWalks 263 non-null float64\n",
" 13 PutOuts 263 non-null float64\n",
" 14 Assists 263 non-null float64\n",
" 15 Errors 263 non-null float64\n",
" 16 League_N 263 non-null bool \n",
" 17 Division_W 263 non-null bool \n",
" 18 NewLeague_N 263 non-null bool \n",
"dtypes: bool(3), float64(16)\n",
"memory usage: 35.7 KB\n"
]
}
],
"source": [
"# Create all features\n",
"X = pd.concat([X_numerical, dummies[['League_N', 'Division_W', 'NewLeague_N']]], axis=1)\n",
"X.info()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "82pBGm_Zp0j8"
},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=10)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "4-ct1-Vxp2sX",
"outputId": "5416954e-1c53-4be5-9de7-c27416f06152"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" AtBat \n",
" Hits \n",
" HmRun \n",
" Runs \n",
" RBI \n",
" Walks \n",
" Years \n",
" CAtBat \n",
" CHits \n",
" CHmRun \n",
" CRuns \n",
" CRBI \n",
" CWalks \n",
" PutOuts \n",
" Assists \n",
" Errors \n",
" League_N \n",
" Division_W \n",
" NewLeague_N \n",
" \n",
" \n",
" \n",
" \n",
" 260 \n",
" 496.0 \n",
" 119.0 \n",
" 8.0 \n",
" 57.0 \n",
" 33.0 \n",
" 21.0 \n",
" 7.0 \n",
" 3358.0 \n",
" 882.0 \n",
" 36.0 \n",
" 365.0 \n",
" 280.0 \n",
" 165.0 \n",
" 155.0 \n",
" 371.0 \n",
" 29.0 \n",
" True \n",
" True \n",
" True \n",
" \n",
" \n",
" 92 \n",
" 317.0 \n",
" 78.0 \n",
" 7.0 \n",
" 35.0 \n",
" 35.0 \n",
" 32.0 \n",
" 1.0 \n",
" 317.0 \n",
" 78.0 \n",
" 7.0 \n",
" 35.0 \n",
" 35.0 \n",
" 32.0 \n",
" 45.0 \n",
" 122.0 \n",
" 26.0 \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" 137 \n",
" 343.0 \n",
" 103.0 \n",
" 6.0 \n",
" 48.0 \n",
" 36.0 \n",
" 40.0 \n",
" 15.0 \n",
" 4338.0 \n",
" 1193.0 \n",
" 70.0 \n",
" 581.0 \n",
" 421.0 \n",
" 325.0 \n",
" 211.0 \n",
" 56.0 \n",
" 13.0 \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" 90 \n",
" 314.0 \n",
" 83.0 \n",
" 13.0 \n",
" 39.0 \n",
" 46.0 \n",
" 16.0 \n",
" 5.0 \n",
" 1457.0 \n",
" 405.0 \n",
" 28.0 \n",
" 156.0 \n",
" 159.0 \n",
" 76.0 \n",
" 533.0 \n",
" 40.0 \n",
" 4.0 \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
" 100 \n",
" 495.0 \n",
" 151.0 \n",
" 17.0 \n",
" 61.0 \n",
" 84.0 \n",
" 78.0 \n",
" 10.0 \n",
" 5624.0 \n",
" 1679.0 \n",
" 275.0 \n",
" 884.0 \n",
" 1015.0 \n",
" 709.0 \n",
" 1045.0 \n",
" 88.0 \n",
" 13.0 \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun \\\n",
"260 496.0 119.0 8.0 57.0 33.0 21.0 7.0 3358.0 882.0 36.0 \n",
"92 317.0 78.0 7.0 35.0 35.0 32.0 1.0 317.0 78.0 7.0 \n",
"137 343.0 103.0 6.0 48.0 36.0 40.0 15.0 4338.0 1193.0 70.0 \n",
"90 314.0 83.0 13.0 39.0 46.0 16.0 5.0 1457.0 405.0 28.0 \n",
"100 495.0 151.0 17.0 61.0 84.0 78.0 10.0 5624.0 1679.0 275.0 \n",
"\n",
" CRuns CRBI CWalks PutOuts Assists Errors League_N Division_W \\\n",
"260 365.0 280.0 165.0 155.0 371.0 29.0 True True \n",
"92 35.0 35.0 32.0 45.0 122.0 26.0 False False \n",
"137 581.0 421.0 325.0 211.0 56.0 13.0 False False \n",
"90 156.0 159.0 76.0 533.0 40.0 4.0 False True \n",
"100 884.0 1015.0 709.0 1045.0 88.0 13.0 False False \n",
"\n",
" NewLeague_N \n",
"260 True \n",
"92 False \n",
"137 False \n",
"90 False \n",
"100 False "
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "pa7qR4wIp5-D"
},
"outputs": [],
"source": [
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"scaler = StandardScaler().fit(X_train[list_numerical])\n",
"\n",
"X_train[list_numerical] = scaler.transform(X_train[list_numerical])\n",
"\n",
"X_test[list_numerical] = scaler.transform(X_test[list_numerical])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 424
},
"id": "2GSApakCp8y1",
"outputId": "eb220133-210b-4340-bae0-3ed51c3fa588"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" AtBat \n",
" Hits \n",
" HmRun \n",
" Runs \n",
" RBI \n",
" Walks \n",
" Years \n",
" CAtBat \n",
" CHits \n",
" CHmRun \n",
" CRuns \n",
" CRBI \n",
" CWalks \n",
" PutOuts \n",
" Assists \n",
" Errors \n",
" League_N \n",
" Division_W \n",
" NewLeague_N \n",
" \n",
" \n",
" \n",
" \n",
" 260 \n",
" 0.644577 \n",
" 0.257439 \n",
" -0.456963 \n",
" 0.101010 \n",
" -0.763917 \n",
" -0.975959 \n",
" -0.070553 \n",
" 0.298535 \n",
" 0.239063 \n",
" -0.407836 \n",
" 0.011298 \n",
" -0.163736 \n",
" -0.361084 \n",
" -0.482387 \n",
" 1.746229 \n",
" 3.022233 \n",
" True \n",
" True \n",
" True \n",
" \n",
" \n",
" 92 \n",
" -0.592807 \n",
" -0.671359 \n",
" -0.572936 \n",
" -0.778318 \n",
" -0.685806 \n",
" -0.458312 \n",
" -1.306911 \n",
" -1.001403 \n",
" -0.969702 \n",
" -0.746705 \n",
" -0.957639 \n",
" -0.898919 \n",
" -0.844319 \n",
" -0.851547 \n",
" 0.022276 \n",
" 2.574735 \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" 137 \n",
" -0.413075 \n",
" -0.105019 \n",
" -0.688910 \n",
" -0.258715 \n",
" -0.646751 \n",
" -0.081841 \n",
" 1.577925 \n",
" 0.717456 \n",
" 0.706633 \n",
" -0.010542 \n",
" 0.645511 \n",
" 0.259369 \n",
" 0.220252 \n",
" -0.294452 \n",
" -0.434676 \n",
" 0.635577 \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" 90 \n",
" -0.613545 \n",
" -0.558091 \n",
" 0.122907 \n",
" -0.618440 \n",
" -0.256196 \n",
" -1.211253 \n",
" -0.482672 \n",
" -0.514087 \n",
" -0.478077 \n",
" -0.501317 \n",
" -0.602362 \n",
" -0.526826 \n",
" -0.684451 \n",
" 0.786178 \n",
" -0.545452 \n",
" -0.706917 \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
" 100 \n",
" 0.637665 \n",
" 0.982354 \n",
" 0.586803 \n",
" 0.260888 \n",
" 1.227914 \n",
" 1.706394 \n",
" 0.547626 \n",
" 1.267183 \n",
" 1.437305 \n",
" 2.384908 \n",
" 1.535171 \n",
" 2.041811 \n",
" 1.615457 \n",
" 2.504446 \n",
" -0.213124 \n",
" 0.635577 \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" \n",
" \n",
" 274 \n",
" 0.824309 \n",
" 0.733164 \n",
" 0.470829 \n",
" 0.740521 \n",
" 0.954525 \n",
" 0.859335 \n",
" -0.688732 \n",
" -0.824858 \n",
" -0.808834 \n",
" -0.571428 \n",
" -0.787341 \n",
" -0.685866 \n",
" -0.648118 \n",
" 3.427344 \n",
" 0.326910 \n",
" 1.232241 \n",
" True \n",
" False \n",
" True \n",
" \n",
" \n",
" 196 \n",
" 0.423369 \n",
" 0.461321 \n",
" 1.862516 \n",
" 0.500704 \n",
" 1.618469 \n",
" 0.482865 \n",
" 1.165805 \n",
" 1.354814 \n",
" 1.246368 \n",
" 1.625375 \n",
" 1.112362 \n",
" 1.516681 \n",
" 0.681687 \n",
" -1.002566 \n",
" -0.822392 \n",
" -1.303581 \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
" 159 \n",
" 1.474109 \n",
" 1.254197 \n",
" 1.746542 \n",
" 1.140215 \n",
" 2.126191 \n",
" -0.458312 \n",
" -0.894792 \n",
" -0.522636 \n",
" -0.520174 \n",
" -0.068968 \n",
" -0.528958 \n",
" -0.322776 \n",
" -0.662651 \n",
" -0.633407 \n",
" 1.310048 \n",
" 0.933909 \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
" 17 \n",
" -1.470728 \n",
" -1.396275 \n",
" -1.152806 \n",
" -1.217982 \n",
" -1.740306 \n",
" -1.258312 \n",
" -0.482672 \n",
" -0.932153 \n",
" -0.933620 \n",
" -0.770075 \n",
" -0.869554 \n",
" -0.934928 \n",
" -0.818885 \n",
" -0.660255 \n",
" 0.403069 \n",
" 1.083075 \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
" 162 \n",
" -1.643547 \n",
" -1.554850 \n",
" -1.152806 \n",
" -1.657646 \n",
" -1.701250 \n",
" -1.211253 \n",
" -0.894792 \n",
" -1.053127 \n",
" -1.020819 \n",
" -0.805130 \n",
" -1.007554 \n",
" -0.973938 \n",
" -0.895185 \n",
" 0.111623 \n",
" -0.690846 \n",
" -1.005249 \n",
" False \n",
" True \n",
" True \n",
" \n",
" \n",
"
\n",
"
184 rows × 19 columns
\n",
"
"
],
"text/plain": [
" AtBat Hits HmRun Runs RBI Walks Years \\\n",
"260 0.644577 0.257439 -0.456963 0.101010 -0.763917 -0.975959 -0.070553 \n",
"92 -0.592807 -0.671359 -0.572936 -0.778318 -0.685806 -0.458312 -1.306911 \n",
"137 -0.413075 -0.105019 -0.688910 -0.258715 -0.646751 -0.081841 1.577925 \n",
"90 -0.613545 -0.558091 0.122907 -0.618440 -0.256196 -1.211253 -0.482672 \n",
"100 0.637665 0.982354 0.586803 0.260888 1.227914 1.706394 0.547626 \n",
".. ... ... ... ... ... ... ... \n",
"274 0.824309 0.733164 0.470829 0.740521 0.954525 0.859335 -0.688732 \n",
"196 0.423369 0.461321 1.862516 0.500704 1.618469 0.482865 1.165805 \n",
"159 1.474109 1.254197 1.746542 1.140215 2.126191 -0.458312 -0.894792 \n",
"17 -1.470728 -1.396275 -1.152806 -1.217982 -1.740306 -1.258312 -0.482672 \n",
"162 -1.643547 -1.554850 -1.152806 -1.657646 -1.701250 -1.211253 -0.894792 \n",
"\n",
" CAtBat CHits CHmRun CRuns CRBI CWalks PutOuts \\\n",
"260 0.298535 0.239063 -0.407836 0.011298 -0.163736 -0.361084 -0.482387 \n",
"92 -1.001403 -0.969702 -0.746705 -0.957639 -0.898919 -0.844319 -0.851547 \n",
"137 0.717456 0.706633 -0.010542 0.645511 0.259369 0.220252 -0.294452 \n",
"90 -0.514087 -0.478077 -0.501317 -0.602362 -0.526826 -0.684451 0.786178 \n",
"100 1.267183 1.437305 2.384908 1.535171 2.041811 1.615457 2.504446 \n",
".. ... ... ... ... ... ... ... \n",
"274 -0.824858 -0.808834 -0.571428 -0.787341 -0.685866 -0.648118 3.427344 \n",
"196 1.354814 1.246368 1.625375 1.112362 1.516681 0.681687 -1.002566 \n",
"159 -0.522636 -0.520174 -0.068968 -0.528958 -0.322776 -0.662651 -0.633407 \n",
"17 -0.932153 -0.933620 -0.770075 -0.869554 -0.934928 -0.818885 -0.660255 \n",
"162 -1.053127 -1.020819 -0.805130 -1.007554 -0.973938 -0.895185 0.111623 \n",
"\n",
" Assists Errors League_N Division_W NewLeague_N \n",
"260 1.746229 3.022233 True True True \n",
"92 0.022276 2.574735 False False False \n",
"137 -0.434676 0.635577 False False False \n",
"90 -0.545452 -0.706917 False True False \n",
"100 -0.213124 0.635577 False False False \n",
".. ... ... ... ... ... \n",
"274 0.326910 1.232241 True False True \n",
"196 -0.822392 -1.303581 False True False \n",
"159 1.310048 0.933909 False True False \n",
"17 0.403069 1.083075 False True False \n",
"162 -0.690846 -1.005249 False True True \n",
"\n",
"[184 rows x 19 columns]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "zU4A-f7_p_ho",
"outputId": "719670d9-fdc1-4814-e167-0018df873b3e"
},
"outputs": [
{
"data": {
"text/html": [
"Lasso(alpha=1) In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"Lasso(alpha=1)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import Lasso\n",
"\n",
"reg = Lasso(alpha=1)\n",
"reg.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eAW1NYJoqDX3",
"outputId": "ea7314f4-cdff-4e4d-ba4a-b55b186f077a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"R squared training set 60.43\n",
"R squared test set 33.01\n"
]
}
],
"source": [
"print('R squared training set', round(reg.score(X_train, y_train)*100, 2))\n",
"print('R squared test set', round(reg.score(X_test, y_test)*100, 2))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rH-BG-7GqF0o",
"outputId": "043c34fb-deac-4228-e4d7-0181bdc02e94"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE training set 80571.73\n",
"MSE test set 134426.33\n"
]
}
],
"source": [
"from sklearn.metrics import mean_squared_error\n",
"\n",
"# Training data\n",
"pred_train = reg.predict(X_train)\n",
"mse_train = mean_squared_error(y_train, pred_train)\n",
"print('MSE training set', round(mse_train, 2))\n",
"\n",
"# Test data\n",
"pred = reg.predict(X_test)\n",
"mse_test =mean_squared_error(y_test, pred)\n",
"print('MSE test set', round(mse_test, 2))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 476
},
"id": "IbZ_tni5qJof",
"outputId": "f346a19a-23cf-4216-f7bf-83a4959594fd"
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"alphas = np.linspace(0.01,500,100)\n",
"lasso = Lasso(max_iter=10000)\n",
"coefs = []\n",
"\n",
"for a in alphas:\n",
" lasso.set_params(alpha=a)\n",
" lasso.fit(X_train, y_train)\n",
" coefs.append(lasso.coef_)\n",
"\n",
"ax = plt.gca()\n",
"\n",
"ax.plot(alphas, coefs)\n",
"ax.set_xscale('log')\n",
"plt.axis('tight')\n",
"plt.xlabel('alpha')\n",
"plt.ylabel('Standardized Coefficients')\n",
"plt.title('Lasso coefficients as a function of alpha');"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "6cQgBrJSqN2F",
"outputId": "f795eaab-8973-4405-f1a6-d83f991566aa"
},
"outputs": [
{
"data": {
"text/html": [
"LassoCV(cv=5, max_iter=10000, random_state=0) In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"LassoCV(cv=5, max_iter=10000, random_state=0)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import LassoCV\n",
"\n",
"# Lasso with 5 fold cross-validation\n",
"model = LassoCV(cv=5, random_state=0, max_iter=10000)\n",
"\n",
"# Fit model\n",
"model.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pVGC_8SrqQyh",
"outputId": "005b2b45-559c-40aa-e82b-7b59b3017eb0"
},
"outputs": [
{
"data": {
"text/plain": [
"2.3441244939374593"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.alpha_"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "b9-S9TPuqTq-",
"outputId": "2cf57a5e-b558-4a62-e19e-fd85189cbfdd"
},
"outputs": [
{
"data": {
"text/html": [
"Lasso(alpha=2.3441244939374593) In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. "
],
"text/plain": [
"Lasso(alpha=2.3441244939374593)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Set best alpha\n",
"lasso_best = Lasso(alpha=model.alpha_)\n",
"lasso_best.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xxov_yRIqWMs",
"outputId": "a5fdc604-9bc6-45b9-8462-be7210369915"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(-176.45309657050498, 'AtBat'), (271.23333276345323, 'Hits'), (-13.049492223041677, 'HmRun'), (-48.97878412496759, 'Runs'), (-13.83696437015553, 'RBI'), (140.12896436568295, 'Walks'), (-10.616534012348882, 'Years'), (-0.0, 'CAtBat'), (0.0, 'CHits'), (78.65781330867388, 'CHmRun'), (355.66188056426347, 'CRuns'), (60.50548334806944, 'CRBI'), (-262.7512352402544, 'CWalks'), (65.61587416521267, 'PutOuts'), (-0.14505342495227297, 'Assists'), (-1.2293157493169835, 'Errors'), (99.66112742179898, 'League_N'), (-116.86405569164934, 'Division_W'), (-69.87497671182551, 'NewLeague_N')]\n"
]
}
],
"source": [
"# Pair each coefficient with its feature name (iterating over X yields the feature names)\n",
"print(list(zip(lasso_best.coef_, X)))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 645
},
"id": "I65MlKax_YRf",
"outputId": "78875bd6-6789-4105-a2b3-843fbfd19868"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>0</th>\n",
"      <th>1</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>-176.453097</td><td>AtBat</td></tr>\n",
"    <tr><th>1</th><td>271.233333</td><td>Hits</td></tr>\n",
"    <tr><th>2</th><td>-13.049492</td><td>HmRun</td></tr>\n",
"    <tr><th>3</th><td>-48.978784</td><td>Runs</td></tr>\n",
"    <tr><th>4</th><td>-13.836964</td><td>RBI</td></tr>\n",
"    <tr><th>5</th><td>140.128964</td><td>Walks</td></tr>\n",
"    <tr><th>6</th><td>-10.616534</td><td>Years</td></tr>\n",
"    <tr><th>7</th><td>-0.000000</td><td>CAtBat</td></tr>\n",
"    <tr><th>8</th><td>0.000000</td><td>CHits</td></tr>\n",
"    <tr><th>9</th><td>78.657813</td><td>CHmRun</td></tr>\n",
"    <tr><th>10</th><td>355.661881</td><td>CRuns</td></tr>\n",
"    <tr><th>11</th><td>60.505483</td><td>CRBI</td></tr>\n",
"    <tr><th>12</th><td>-262.751235</td><td>CWalks</td></tr>\n",
"    <tr><th>13</th><td>65.615874</td><td>PutOuts</td></tr>\n",
"    <tr><th>14</th><td>-0.145053</td><td>Assists</td></tr>\n",
"    <tr><th>15</th><td>-1.229316</td><td>Errors</td></tr>\n",
"    <tr><th>16</th><td>99.661127</td><td>League_N</td></tr>\n",
"    <tr><th>17</th><td>-116.864056</td><td>Division_W</td></tr>\n",
"    <tr><th>18</th><td>-69.874977</td><td>NewLeague_N</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0 1\n",
"0 -176.453097 AtBat\n",
"1 271.233333 Hits\n",
"2 -13.049492 HmRun\n",
"3 -48.978784 Runs\n",
"4 -13.836964 RBI\n",
"5 140.128964 Walks\n",
"6 -10.616534 Years\n",
"7 -0.000000 CAtBat\n",
"8 0.000000 CHits\n",
"9 78.657813 CHmRun\n",
"10 355.661881 CRuns\n",
"11 60.505483 CRBI\n",
"12 -262.751235 CWalks\n",
"13 65.615874 PutOuts\n",
"14 -0.145053 Assists\n",
"15 -1.229316 Errors\n",
"16 99.661127 League_N\n",
"17 -116.864056 Division_W\n",
"18 -69.874977 NewLeague_N"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(list(zip(lasso_best.coef_, X)))"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GX-W2vgEqZo2",
"outputId": "25375559-dd75-4cbc-e2a8-12d0decb82ec"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"R squared training set 59.18\n"
]
}
],
"source": [
"# score() returns R^2; report it as a percentage for the training set\n",
"print('R squared training set', round(lasso_best.score(X_train, y_train)*100, 2))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "A4DLhPnkqcW5",
"outputId": "c0986990-4699-496d-eba6-d56f03706e78"
},
"outputs": [
{
"data": {
"text/plain": [
"129468.59746481"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test-set mean squared error of the refitted Lasso\n",
"mean_squared_error(y_test, lasso_best.predict(X_test))"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 476
},
"id": "HILTw_wpqeuO",
"outputId": "4c9dad70-e091-4721-8496-e3cd3ffaa1e2"
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Per-fold cross-validation MSE along the alpha path (dotted lines)\n",
"plt.semilogx(model.alphas_, model.mse_path_, \":\")\n",
"# Mean MSE across the folds (solid black line)\n",
"plt.plot(\n",
"    model.alphas_,\n",
"    model.mse_path_.mean(axis=-1),\n",
"    \"k\",\n",
"    label=\"Average across the folds\",\n",
"    linewidth=2,\n",
")\n",
"# Mark the alpha chosen by cross-validation\n",
"plt.axvline(\n",
"    model.alpha_, linestyle=\"--\", color=\"k\", label=\"alpha: CV estimate\"\n",
")\n",
"\n",
"plt.legend()\n",
"plt.xlabel(\"alpha\")\n",
"plt.ylabel(\"Mean squared error\")\n",
"plt.title(\"Mean squared error on each fold\")\n",
"plt.axis(\"tight\")\n",
"\n",
"# Restrict the y-axis so the region around the minimum is readable\n",
"ymin, ymax = 50000, 250000\n",
"plt.ylim(ymin, ymax);"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tC6I9Qe4x8i3",
"outputId": "fd33b24d-15be-447a-9193-2764db9949e8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Ridge Regression:\n",
"Test MSE: 142408.39136111396\n",
"Train MSE: 79210.90860181357\n"
]
}
],
"source": [
"# Ridge regression\n",
"from sklearn.linear_model import Ridge, ElasticNet\n",
"\n",
"ridge_model = Ridge(alpha=0.1)  # alpha sets the L2 regularization strength\n",
"ridge_model.fit(X_train, y_train)\n",
"\n",
"# Predict on both splits to compare test and training error\n",
"ridge_y_pred_test = ridge_model.predict(X_test)\n",
"ridge_y_pred_train = ridge_model.predict(X_train)\n",
"ridge_mse_test = mean_squared_error(y_test, ridge_y_pred_test)\n",
"ridge_mse_train = mean_squared_error(y_train, ridge_y_pred_train)\n",
"\n",
"print(\"\\nRidge Regression:\")\n",
"print(f\"Test MSE: {ridge_mse_test}\")\n",
"print(f\"Train MSE: {ridge_mse_train}\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IYx1zAjWymjM",
"outputId": "7ad46dfe-ef51-4ad2-c02c-3ed39d481251"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Elastic Net Regression:\n",
"Test MSE: 128123.64676256798\n",
"Train MSE: 88255.34545236964\n"
]
}
],
"source": [
"# Elastic Net regression\n",
"# alpha sets the overall penalty strength; l1_ratio=0.5 mixes the L1 and L2 penalties\n",
"elastic_net_model = ElasticNet(alpha=0.1, l1_ratio=0.5)\n",
"elastic_net_model.fit(X_train, y_train)\n",
"\n",
"# Predict on both splits to compare test and training error\n",
"elastic_y_pred_train = elastic_net_model.predict(X_train)\n",
"elastic_net_y_pred = elastic_net_model.predict(X_test)\n",
"elastic_mse_test = mean_squared_error(y_test, elastic_net_y_pred)\n",
"elastic_mse_train = mean_squared_error(y_train, elastic_y_pred_train)\n",
"\n",
"print(\"\\nElastic Net Regression:\")\n",
"print(f\"Test MSE: {elastic_mse_test}\")\n",
"print(f\"Train MSE: {elastic_mse_train}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}