{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "V28"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"### **Business Problem Definition:**\n",
"\n",
"An company has plans to enter new markets with their existing products (P1, P2, P3, P4 and P5). After intensive market research, they’ve decided that the behavior of new market is similar to their existing market.\n",
"\n",
"In their existing market, the sales team has classified all customers into 4 segments (A, B, C, D ). Then, they performed segmented outreach and communication for different segment of customers. This strategy has work exceptionally well for them. They plan to use the same strategy on new markets and have identified 2627 new potential customers.\n",
"\n",
"### As a business analyst you are required to `help the manager to predict the right group allocation` of the new customers."
],
"metadata": {
"id": "gKC9wLVUwOp3"
}
},
{
"cell_type": "markdown",
"source": [
"Variables Description\n",
"\n",
"ID --\tUnique ID\n",
"\n",
"Gender\t-- Gender of the customer\n",
"\n",
"Ever_Married\t-- Marital status of the customer\n",
"\n",
"Age\t-- Age of the customer\n",
"\n",
"Graduated\t-- Is the customer a graduate?\n",
"\n",
"Profession\t-- Profession of the customer\n",
"\n",
"Work_Experience\t-- Work Experience in years\n",
"\n",
"Spending_Score\t-- Spending score of the customer\n",
"\n",
"Family_Size\t-- Number of family members for the customer(including the customer)\n",
"\n",
"Var_1\t-- Anonymised Category for the customer\n",
"\n",
"Segmentation(target)\t-- Customer Segment of the customer"
],
"metadata": {
"id": "43Znaj5Av_fb"
}
},
{
"cell_type": "code",
"source": [
"# Importing libraries\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt"
],
"metadata": {
"id": "DPo_-1D6wYGQ"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Loading the train data\n",
"df = pd.read_csv('Train.csv')\n",
"\n",
"# Looking top 10 rows\n",
"df.head(10)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
},
"id": "PC9zyxL2wtyj",
"outputId": "1c778f9e-fc89-4ce9-e393-6fcb7b568af1"
},
"execution_count": 2,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" ID Gender Ever_Married Age Graduated Profession Work_Experience \\\n",
"0 462809 Male No 22 No Healthcare 1.0 \n",
"1 462643 Female Yes 38 Yes Engineer NaN \n",
"2 466315 Female Yes 67 Yes Engineer 1.0 \n",
"3 461735 Male Yes 67 Yes Lawyer 0.0 \n",
"4 462669 Female Yes 40 Yes Entertainment NaN \n",
"5 461319 Male Yes 56 No Artist 0.0 \n",
"6 460156 Male No 32 Yes Healthcare 1.0 \n",
"7 464347 Female No 33 Yes Healthcare 1.0 \n",
"8 465015 Female Yes 61 Yes Engineer 0.0 \n",
"9 465176 Female Yes 55 Yes Artist 1.0 \n",
"\n",
" Spending_Score Family_Size Var_1 Segmentation \n",
"0 Low 4.0 Cat_4 D \n",
"1 Average 3.0 Cat_4 A \n",
"2 Low 1.0 Cat_6 B \n",
"3 High 2.0 Cat_6 B \n",
"4 High 6.0 Cat_6 A \n",
"5 Average 2.0 Cat_6 C \n",
"6 Low 3.0 Cat_6 C \n",
"7 Low 3.0 Cat_6 D \n",
"8 Low 3.0 Cat_7 D \n",
"9 Average 4.0 Cat_6 C "
],
"text/html": [
"\n",
"
\n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" ID \n",
" Gender \n",
" Ever_Married \n",
" Age \n",
" Graduated \n",
" Profession \n",
" Work_Experience \n",
" Spending_Score \n",
" Family_Size \n",
" Var_1 \n",
" Segmentation \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 462809 \n",
" Male \n",
" No \n",
" 22 \n",
" No \n",
" Healthcare \n",
" 1.0 \n",
" Low \n",
" 4.0 \n",
" Cat_4 \n",
" D \n",
" \n",
" \n",
" 1 \n",
" 462643 \n",
" Female \n",
" Yes \n",
" 38 \n",
" Yes \n",
" Engineer \n",
" NaN \n",
" Average \n",
" 3.0 \n",
" Cat_4 \n",
" A \n",
" \n",
" \n",
" 2 \n",
" 466315 \n",
" Female \n",
" Yes \n",
" 67 \n",
" Yes \n",
" Engineer \n",
" 1.0 \n",
" Low \n",
" 1.0 \n",
" Cat_6 \n",
" B \n",
" \n",
" \n",
" 3 \n",
" 461735 \n",
" Male \n",
" Yes \n",
" 67 \n",
" Yes \n",
" Lawyer \n",
" 0.0 \n",
" High \n",
" 2.0 \n",
" Cat_6 \n",
" B \n",
" \n",
" \n",
" 4 \n",
" 462669 \n",
" Female \n",
" Yes \n",
" 40 \n",
" Yes \n",
" Entertainment \n",
" NaN \n",
" High \n",
" 6.0 \n",
" Cat_6 \n",
" A \n",
" \n",
" \n",
" 5 \n",
" 461319 \n",
" Male \n",
" Yes \n",
" 56 \n",
" No \n",
" Artist \n",
" 0.0 \n",
" Average \n",
" 2.0 \n",
" Cat_6 \n",
" C \n",
" \n",
" \n",
" 6 \n",
" 460156 \n",
" Male \n",
" No \n",
" 32 \n",
" Yes \n",
" Healthcare \n",
" 1.0 \n",
" Low \n",
" 3.0 \n",
" Cat_6 \n",
" C \n",
" \n",
" \n",
" 7 \n",
" 464347 \n",
" Female \n",
" No \n",
" 33 \n",
" Yes \n",
" Healthcare \n",
" 1.0 \n",
" Low \n",
" 3.0 \n",
" Cat_6 \n",
" D \n",
" \n",
" \n",
" 8 \n",
" 465015 \n",
" Female \n",
" Yes \n",
" 61 \n",
" Yes \n",
" Engineer \n",
" 0.0 \n",
" Low \n",
" 3.0 \n",
" Cat_7 \n",
" D \n",
" \n",
" \n",
" 9 \n",
" 465176 \n",
" Female \n",
" Yes \n",
" 55 \n",
" Yes \n",
" Artist \n",
" 1.0 \n",
" Average \n",
" 4.0 \n",
" Cat_6 \n",
" C \n",
" \n",
" \n",
"
\n",
"
\n",
"
\n",
"
\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "df",
"summary": "{\n \"name\": \"df\",\n \"rows\": 8068,\n \"fields\": [\n {\n \"column\": \"ID\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 2595,\n \"min\": 458982,\n \"max\": 467974,\n \"num_unique_values\": 8068,\n \"samples\": [\n 467287,\n 466142,\n 465257\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Gender\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"Female\",\n \"Male\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Ever_Married\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"Yes\",\n \"No\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Age\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 16,\n \"min\": 18,\n \"max\": 89,\n \"num_unique_values\": 67,\n \"samples\": [\n 30,\n 49\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Graduated\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"Yes\",\n \"No\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Profession\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 9,\n \"samples\": [\n \"Homemaker\",\n \"Engineer\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Work_Experience\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 3.406762985458083,\n \"min\": 0.0,\n \"max\": 14.0,\n \"num_unique_values\": 15,\n \"samples\": [\n 14.0,\n 2.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Spending_Score\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 3,\n \"samples\": [\n \"Low\",\n \"Average\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Family_Size\",\n \"properties\": {\n \"dtype\": 
\"number\",\n \"std\": 1.5314132820253756,\n \"min\": 1.0,\n \"max\": 9.0,\n \"num_unique_values\": 9,\n \"samples\": [\n 7.0,\n 3.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Var_1\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 7,\n \"samples\": [\n \"Cat_4\",\n \"Cat_6\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Segmentation\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 4,\n \"samples\": [\n \"A\",\n \"C\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
}
},
"metadata": {},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"source": [
"print ('Number of samples: ',len(df))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BDQiCoXjye8W",
"outputId": "001a19df-be51-4df6-beb2-5796c56fa1a4"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of samples: 8068\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Looking the bigger picture\n",
"df.info()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3bNPS7gRxAsr",
"outputId": "d63534eb-44a3-448a-a915-b8fa24a96fdb"
},
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"RangeIndex: 8068 entries, 0 to 8067\n",
"Data columns (total 11 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 ID 8068 non-null int64 \n",
" 1 Gender 8068 non-null object \n",
" 2 Ever_Married 7928 non-null object \n",
" 3 Age 8068 non-null int64 \n",
" 4 Graduated 7990 non-null object \n",
" 5 Profession 7944 non-null object \n",
" 6 Work_Experience 7239 non-null float64\n",
" 7 Spending_Score 8068 non-null object \n",
" 8 Family_Size 7733 non-null float64\n",
" 9 Var_1 7992 non-null object \n",
" 10 Segmentation 8068 non-null object \n",
"dtypes: float64(2), int64(2), object(7)\n",
"memory usage: 693.5+ KB\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"def fill_missing_values(df):\n",
" # Replace missing values for numeric columns with median\n",
" numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns\n",
" for col in numeric_cols:\n",
" df[col].fillna(df[col].median(), inplace=True)\n",
"\n",
" # Replace missing values for categorical columns with mode\n",
" categorical_cols = df.select_dtypes(include=['object']).columns\n",
" for col in categorical_cols:\n",
" df[col].fillna(df[col].mode()[0], inplace=True)\n",
"\n",
" # Check if all missing values are filled\n",
" if df.isnull().sum().sum() == 0:\n",
" print(\"All missing values have been replaced.\")\n",
" else:\n",
" print(\"Some missing values remain.\")\n",
"\n",
" return df\n"
],
"metadata": {
"id": "Leot91jg6FCt"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"df = fill_missing_values(df)\n",
"df.isnull().sum()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 447
},
"id": "LFegjW9k6Pj0",
"outputId": "b2e5e3c5-a274-42f1-e14f-6c7a7b21e8d7"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"All missing values have been replaced.\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"ID 0\n",
"Gender 0\n",
"Ever_Married 0\n",
"Age 0\n",
"Graduated 0\n",
"Profession 0\n",
"Work_Experience 0\n",
"Spending_Score 0\n",
"Family_Size 0\n",
"Var_1 0\n",
"Segmentation 0\n",
"dtype: int64"
]
},
"metadata": {},
"execution_count": 6
}
]
},
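{
"cell_type": "markdown",
"source": [
"The function above computes the medians and modes on the full dataset. That includes the rows that later become the test split, so a little information from the test rows leaks into training. The sketch below shows one alternative on a small hypothetical train/test pair (the values are made up for illustration). It learns the fill values from the training rows only and reuses them unchanged on the held-out rows. With these toy values, the training median is 3.0 and the training mode is 'Artist'."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical toy frames standing in for the train and test splits\n",
"train_demo = pd.DataFrame({'Work_Experience': [1.0, np.nan, 3.0, 5.0],\n",
"                           'Profession': ['Artist', 'Doctor', np.nan, 'Artist']})\n",
"test_demo = pd.DataFrame({'Work_Experience': [np.nan, 2.0],\n",
"                          'Profession': [np.nan, 'Doctor']})\n",
"\n",
"# Fill values learned from the training rows only: median for numeric, mode for categorical\n",
"fill_values = {'Work_Experience': train_demo['Work_Experience'].median(),\n",
"               'Profession': train_demo['Profession'].mode()[0]}\n",
"\n",
"train_demo = train_demo.fillna(fill_values)\n",
"test_demo = test_demo.fillna(fill_values)  # nothing is computed from the test rows\n",
"print(test_demo)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},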
{
"cell_type": "code",
"source": [
"df.info()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4FBrfedp8zzn",
"outputId": "e036366e-6a19-41aa-9803-19c87e4b6072"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"RangeIndex: 8068 entries, 0 to 8067\n",
"Data columns (total 11 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 ID 8068 non-null int64 \n",
" 1 Gender 8068 non-null object \n",
" 2 Ever_Married 8068 non-null object \n",
" 3 Age 8068 non-null int64 \n",
" 4 Graduated 8068 non-null object \n",
" 5 Profession 8068 non-null object \n",
" 6 Work_Experience 8068 non-null float64\n",
" 7 Spending_Score 8068 non-null object \n",
" 8 Family_Size 8068 non-null float64\n",
" 9 Var_1 8068 non-null object \n",
" 10 Segmentation 8068 non-null object \n",
"dtypes: float64(2), int64(2), object(7)\n",
"memory usage: 693.5+ KB\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"dfOnlyFeatures = df.drop(columns=['Segmentation', 'ID'])\n",
"\n",
"# Verify the structure of the new dataframe\n",
"dfOnlyFeatures.info()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NUtFMhqO9xa-",
"outputId": "63ed3087-fbb8-4d6d-9362-8a137b0717af"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"RangeIndex: 8068 entries, 0 to 8067\n",
"Data columns (total 9 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Gender 8068 non-null object \n",
" 1 Ever_Married 8068 non-null object \n",
" 2 Age 8068 non-null int64 \n",
" 3 Graduated 8068 non-null object \n",
" 4 Profession 8068 non-null object \n",
" 5 Work_Experience 8068 non-null float64\n",
" 6 Spending_Score 8068 non-null object \n",
" 7 Family_Size 8068 non-null float64\n",
" 8 Var_1 8068 non-null object \n",
"dtypes: float64(2), int64(1), object(6)\n",
"memory usage: 567.4+ KB\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Dummification of categorical variables\n",
"# Dummify (One-Hot Encode) the categorical variables\n",
"df_dummified = pd.get_dummies(dfOnlyFeatures)\n",
"\n",
"# Display the first few rows of the dummified dataset\n",
"df_dummified.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 255
},
"id": "jG0Tr-qC7RgG",
"outputId": "77358ed6-cb37-4e72-df8b-fba739e3095e"
},
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Age Work_Experience Family_Size Gender_Female Gender_Male \\\n",
"0 22 1.0 4.0 False True \n",
"1 38 1.0 3.0 True False \n",
"2 67 1.0 1.0 True False \n",
"3 67 0.0 2.0 False True \n",
"4 40 1.0 6.0 True False \n",
"\n",
" Ever_Married_No Ever_Married_Yes Graduated_No Graduated_Yes \\\n",
"0 True False True False \n",
"1 False True False True \n",
"2 False True False True \n",
"3 False True False True \n",
"4 False True False True \n",
"\n",
" Profession_Artist ... Spending_Score_Average Spending_Score_High \\\n",
"0 False ... False False \n",
"1 False ... True False \n",
"2 False ... False False \n",
"3 False ... False True \n",
"4 False ... False True \n",
"\n",
" Spending_Score_Low Var_1_Cat_1 Var_1_Cat_2 Var_1_Cat_3 Var_1_Cat_4 \\\n",
"0 True False False False True \n",
"1 False False False False True \n",
"2 True False False False False \n",
"3 False False False False False \n",
"4 False False False False False \n",
"\n",
" Var_1_Cat_5 Var_1_Cat_6 Var_1_Cat_7 \n",
"0 False False False \n",
"1 False False False \n",
"2 False True False \n",
"3 False True False \n",
"4 False True False \n",
"\n",
"[5 rows x 28 columns]"
],
"text/html": [
"\n",
" \n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" Age \n",
" Work_Experience \n",
" Family_Size \n",
" Gender_Female \n",
" Gender_Male \n",
" Ever_Married_No \n",
" Ever_Married_Yes \n",
" Graduated_No \n",
" Graduated_Yes \n",
" Profession_Artist \n",
" ... \n",
" Spending_Score_Average \n",
" Spending_Score_High \n",
" Spending_Score_Low \n",
" Var_1_Cat_1 \n",
" Var_1_Cat_2 \n",
" Var_1_Cat_3 \n",
" Var_1_Cat_4 \n",
" Var_1_Cat_5 \n",
" Var_1_Cat_6 \n",
" Var_1_Cat_7 \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 22 \n",
" 1.0 \n",
" 4.0 \n",
" False \n",
" True \n",
" True \n",
" False \n",
" True \n",
" False \n",
" False \n",
" ... \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" 1 \n",
" 38 \n",
" 1.0 \n",
" 3.0 \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" 2 \n",
" 67 \n",
" 1.0 \n",
" 1.0 \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
" 3 \n",
" 67 \n",
" 0.0 \n",
" 2.0 \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
" 4 \n",
" 40 \n",
" 1.0 \n",
" 6.0 \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
"
\n",
"
5 rows × 28 columns
\n",
"
\n",
"
\n",
"
\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "df_dummified"
}
},
"metadata": {},
"execution_count": 10
}
]
},
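{
"cell_type": "markdown",
"source": [
"`pd.get_dummies` creates one column per category it actually sees. The 2627 new customers must therefore be encoded into exactly the same columns, in the same order, as the training data before a fitted model can score them. The sketch below assumes a hypothetical new-customer frame in which some `Spending_Score` levels never appear, and reindexes the new dummies to the training columns. The missing `Spending_Score_Average` and `Spending_Score_High` columns are added as `False`. Any category seen only in the new data would be dropped, which is a design choice worth revisiting if new categories are expected."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical frames: training features and new-market customers\n",
"train_feats = pd.DataFrame({'Age': [22, 38, 67],\n",
"                            'Spending_Score': ['Low', 'Average', 'High']})\n",
"new_feats = pd.DataFrame({'Age': [45, 30],\n",
"                          'Spending_Score': ['Low', 'Low']})\n",
"\n",
"train_dummies = pd.get_dummies(train_feats)\n",
"new_dummies = pd.get_dummies(new_feats)\n",
"\n",
"# Add missing dummy columns as False, drop unseen ones, keep the training column order\n",
"new_aligned = new_dummies.reindex(columns=train_dummies.columns, fill_value=False)\n",
"print(new_aligned.columns.tolist())\n",
"print(new_aligned)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},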
{
"cell_type": "code",
"source": [
"df1 = df_dummified.copy()\n",
"df1.head()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 255
},
"id": "X19Ztv4vP6Mk",
"outputId": "78f7aeaa-0111-45ba-f233-990fe9092789"
},
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Age Work_Experience Family_Size Gender_Female Gender_Male \\\n",
"0 22 1.0 4.0 False True \n",
"1 38 1.0 3.0 True False \n",
"2 67 1.0 1.0 True False \n",
"3 67 0.0 2.0 False True \n",
"4 40 1.0 6.0 True False \n",
"\n",
" Ever_Married_No Ever_Married_Yes Graduated_No Graduated_Yes \\\n",
"0 True False True False \n",
"1 False True False True \n",
"2 False True False True \n",
"3 False True False True \n",
"4 False True False True \n",
"\n",
" Profession_Artist ... Spending_Score_Average Spending_Score_High \\\n",
"0 False ... False False \n",
"1 False ... True False \n",
"2 False ... False False \n",
"3 False ... False True \n",
"4 False ... False True \n",
"\n",
" Spending_Score_Low Var_1_Cat_1 Var_1_Cat_2 Var_1_Cat_3 Var_1_Cat_4 \\\n",
"0 True False False False True \n",
"1 False False False False True \n",
"2 True False False False False \n",
"3 False False False False False \n",
"4 False False False False False \n",
"\n",
" Var_1_Cat_5 Var_1_Cat_6 Var_1_Cat_7 \n",
"0 False False False \n",
"1 False False False \n",
"2 False True False \n",
"3 False True False \n",
"4 False True False \n",
"\n",
"[5 rows x 28 columns]"
],
"text/html": [
"\n",
" \n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" Age \n",
" Work_Experience \n",
" Family_Size \n",
" Gender_Female \n",
" Gender_Male \n",
" Ever_Married_No \n",
" Ever_Married_Yes \n",
" Graduated_No \n",
" Graduated_Yes \n",
" Profession_Artist \n",
" ... \n",
" Spending_Score_Average \n",
" Spending_Score_High \n",
" Spending_Score_Low \n",
" Var_1_Cat_1 \n",
" Var_1_Cat_2 \n",
" Var_1_Cat_3 \n",
" Var_1_Cat_4 \n",
" Var_1_Cat_5 \n",
" Var_1_Cat_6 \n",
" Var_1_Cat_7 \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 22 \n",
" 1.0 \n",
" 4.0 \n",
" False \n",
" True \n",
" True \n",
" False \n",
" True \n",
" False \n",
" False \n",
" ... \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" 1 \n",
" 38 \n",
" 1.0 \n",
" 3.0 \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" 2 \n",
" 67 \n",
" 1.0 \n",
" 1.0 \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
" 3 \n",
" 67 \n",
" 0.0 \n",
" 2.0 \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
" 4 \n",
" 40 \n",
" 1.0 \n",
" 6.0 \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
"
\n",
"
5 rows × 28 columns
\n",
"
\n",
"
\n",
"
\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "df1"
}
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"source": [
"df1.info()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6SCFyiUgCgd5",
"outputId": "1db75544-70f2-4047-cafc-1d39d92ad2d8"
},
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"RangeIndex: 8068 entries, 0 to 8067\n",
"Data columns (total 29 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Age 8068 non-null int64 \n",
" 1 Work_Experience 8068 non-null float64\n",
" 2 Family_Size 8068 non-null float64\n",
" 3 Gender_Female 8068 non-null bool \n",
" 4 Gender_Male 8068 non-null bool \n",
" 5 Ever_Married_No 8068 non-null bool \n",
" 6 Ever_Married_Yes 8068 non-null bool \n",
" 7 Graduated_No 8068 non-null bool \n",
" 8 Graduated_Yes 8068 non-null bool \n",
" 9 Profession_Artist 8068 non-null bool \n",
" 10 Profession_Doctor 8068 non-null bool \n",
" 11 Profession_Engineer 8068 non-null bool \n",
" 12 Profession_Entertainment 8068 non-null bool \n",
" 13 Profession_Executive 8068 non-null bool \n",
" 14 Profession_Healthcare 8068 non-null bool \n",
" 15 Profession_Homemaker 8068 non-null bool \n",
" 16 Profession_Lawyer 8068 non-null bool \n",
" 17 Profession_Marketing 8068 non-null bool \n",
" 18 Spending_Score_Average 8068 non-null bool \n",
" 19 Spending_Score_High 8068 non-null bool \n",
" 20 Spending_Score_Low 8068 non-null bool \n",
" 21 Var_1_Cat_1 8068 non-null bool \n",
" 22 Var_1_Cat_2 8068 non-null bool \n",
" 23 Var_1_Cat_3 8068 non-null bool \n",
" 24 Var_1_Cat_4 8068 non-null bool \n",
" 25 Var_1_Cat_5 8068 non-null bool \n",
" 26 Var_1_Cat_6 8068 non-null bool \n",
" 27 Var_1_Cat_7 8068 non-null bool \n",
" 28 Segmentation 8068 non-null int64 \n",
"dtypes: bool(25), float64(2), int64(2)\n",
"memory usage: 449.2 KB\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Label encode the target variable\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# Initialize the LabelEncoder\n",
"label_encoder = LabelEncoder()\n",
"\n",
"# Perform label encoding on the 'Segmentation' column\n",
"df['Segmentation'] = label_encoder.fit_transform(df['Segmentation'])\n",
"\n",
"# Mapping of original classes to encoded values\n",
"label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))\n",
"\n",
"df1['Segmentation'] = df['Segmentation']\n",
"df1.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 255
},
"id": "KcPmlQ07-r0k",
"outputId": "351e7cdb-b410-488d-e913-704deda84159"
},
"execution_count": 24,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Age Work_Experience Family_Size Gender_Female Gender_Male \\\n",
"0 22 1.0 4.0 False True \n",
"1 38 1.0 3.0 True False \n",
"2 67 1.0 1.0 True False \n",
"3 67 0.0 2.0 False True \n",
"4 40 1.0 6.0 True False \n",
"\n",
" Ever_Married_No Ever_Married_Yes Graduated_No Graduated_Yes \\\n",
"0 True False True False \n",
"1 False True False True \n",
"2 False True False True \n",
"3 False True False True \n",
"4 False True False True \n",
"\n",
" Profession_Artist ... Spending_Score_High Spending_Score_Low \\\n",
"0 False ... False True \n",
"1 False ... False False \n",
"2 False ... False True \n",
"3 False ... True False \n",
"4 False ... True False \n",
"\n",
" Var_1_Cat_1 Var_1_Cat_2 Var_1_Cat_3 Var_1_Cat_4 Var_1_Cat_5 \\\n",
"0 False False False True False \n",
"1 False False False True False \n",
"2 False False False False False \n",
"3 False False False False False \n",
"4 False False False False False \n",
"\n",
" Var_1_Cat_6 Var_1_Cat_7 Segmentation \n",
"0 False False 3 \n",
"1 False False 0 \n",
"2 True False 1 \n",
"3 True False 1 \n",
"4 True False 0 \n",
"\n",
"[5 rows x 29 columns]"
],
"text/html": [
"\n",
" \n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" Age \n",
" Work_Experience \n",
" Family_Size \n",
" Gender_Female \n",
" Gender_Male \n",
" Ever_Married_No \n",
" Ever_Married_Yes \n",
" Graduated_No \n",
" Graduated_Yes \n",
" Profession_Artist \n",
" ... \n",
" Spending_Score_High \n",
" Spending_Score_Low \n",
" Var_1_Cat_1 \n",
" Var_1_Cat_2 \n",
" Var_1_Cat_3 \n",
" Var_1_Cat_4 \n",
" Var_1_Cat_5 \n",
" Var_1_Cat_6 \n",
" Var_1_Cat_7 \n",
" Segmentation \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 22 \n",
" 1.0 \n",
" 4.0 \n",
" False \n",
" True \n",
" True \n",
" False \n",
" True \n",
" False \n",
" False \n",
" ... \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" 3 \n",
" \n",
" \n",
" 1 \n",
" 38 \n",
" 1.0 \n",
" 3.0 \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" 0 \n",
" \n",
" \n",
" 2 \n",
" 67 \n",
" 1.0 \n",
" 1.0 \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" 1 \n",
" \n",
" \n",
" 3 \n",
" 67 \n",
" 0.0 \n",
" 2.0 \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" 1 \n",
" \n",
" \n",
" 4 \n",
" 40 \n",
" 1.0 \n",
" 6.0 \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" 0 \n",
" \n",
" \n",
"
\n",
"
5 rows × 29 columns
\n",
"
\n",
"
\n",
"
\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "df1"
}
},
"metadata": {},
"execution_count": 24
}
]
},
{
"cell_type": "code",
"source": [
"print(label_mapping)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HPq0Lbrr_e8G",
"outputId": "14a7a43c-0531-427f-be55-28ead25d8865"
},
"execution_count": 25,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{0: 0, 1: 1, 2: 2, 3: 3}\n"
]
}
]
},
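{
"cell_type": "markdown",
"source": [
"`LabelEncoder` assigns integers to the sorted class labels, so fitting it on the raw labels A to D should map A to 0, B to 1, C to 2 and D to 3. The `{0: 0, 1: 1, 2: 2, 3: 3}` mapping printed above is what you get when the encoder is fitted a second time on a column that was already encoded in place. The sketch below uses a few hypothetical labels, encodes them into a new array, and leaves the raw letters untouched. It also maps predicted codes back to segment letters with `inverse_transform`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# Hypothetical raw segment labels\n",
"raw = pd.Series(['D', 'A', 'B', 'B', 'A', 'C'], name='Segmentation')\n",
"\n",
"encoder = LabelEncoder()\n",
"encoded = encoder.fit_transform(raw)  # raw itself still holds the letters\n",
"\n",
"# classes_ is sorted, so the codes follow alphabetical order\n",
"mapping = {label: int(code) for label, code in zip(encoder.classes_, encoder.transform(encoder.classes_))}\n",
"print(mapping)  # {'A': 0, 'B': 1, 'C': 2, 'D': 3}\n",
"\n",
"# Map predicted codes back to segment letters\n",
"print(encoder.inverse_transform([3, 0, 1]).tolist())  # ['D', 'A', 'B']"
],
"metadata": {},
"execution_count": null,
"outputs": []
},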
{
"cell_type": "code",
"source": [
"# Separating dependent-independent variables\n",
"X = df1.drop(['Segmentation'], axis=1)\n",
"y = df1['Segmentation']\n",
"X.head(2)"
],
"metadata": {
"id": "LukvBGXgQEbA",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 161
},
"outputId": "fc4095d3-5d42-42a1-a756-9e27a7d8cfad"
},
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Age Work_Experience Family_Size Gender_Female Gender_Male \\\n",
"0 22 1.0 4.0 False True \n",
"1 38 1.0 3.0 True False \n",
"\n",
" Ever_Married_No Ever_Married_Yes Graduated_No Graduated_Yes \\\n",
"0 True False True False \n",
"1 False True False True \n",
"\n",
" Profession_Artist ... Spending_Score_Average Spending_Score_High \\\n",
"0 False ... False False \n",
"1 False ... True False \n",
"\n",
" Spending_Score_Low Var_1_Cat_1 Var_1_Cat_2 Var_1_Cat_3 Var_1_Cat_4 \\\n",
"0 True False False False True \n",
"1 False False False False True \n",
"\n",
" Var_1_Cat_5 Var_1_Cat_6 Var_1_Cat_7 \n",
"0 False False False \n",
"1 False False False \n",
"\n",
"[2 rows x 28 columns]"
],
"text/html": [
"\n",
" \n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" Age \n",
" Work_Experience \n",
" Family_Size \n",
" Gender_Female \n",
" Gender_Male \n",
" Ever_Married_No \n",
" Ever_Married_Yes \n",
" Graduated_No \n",
" Graduated_Yes \n",
" Profession_Artist \n",
" ... \n",
" Spending_Score_Average \n",
" Spending_Score_High \n",
" Spending_Score_Low \n",
" Var_1_Cat_1 \n",
" Var_1_Cat_2 \n",
" Var_1_Cat_3 \n",
" Var_1_Cat_4 \n",
" Var_1_Cat_5 \n",
" Var_1_Cat_6 \n",
" Var_1_Cat_7 \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 22 \n",
" 1.0 \n",
" 4.0 \n",
" False \n",
" True \n",
" True \n",
" False \n",
" True \n",
" False \n",
" False \n",
" ... \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" 1 \n",
" 38 \n",
" 1.0 \n",
" 3.0 \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
"
\n",
"
2 rows × 28 columns
\n",
"
\n",
"
\n",
"
\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "X"
}
},
"metadata": {},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"source": [
"# import the train-test split\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# divide into train and test sets\n",
"trainX, testX, train_y, test_y = train_test_split(X,y, train_size = 0.8, random_state = 101, stratify=y)\n",
"trainX.shape, trainY.shape, testX.shape, testY.shape"
],
"metadata": {
"id": "D4s7KWkVQItR",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "55fbb4d1-fadc-432b-b19f-a5d6594dc78b"
},
"execution_count": 29,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((6454, 28), (6454,), (1614, 28), (1614,))"
]
},
"metadata": {},
"execution_count": 29
}
]
},
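{
"cell_type": "markdown",
"source": [
"`stratify=y` makes `train_test_split` keep the class proportions of `y` in both splits, which matters when the segments are not equally frequent. The sketch below uses hypothetical labels with a 60/30/10 class mix and compares the proportions after an 80/20 split. Both splits should show 0.6, 0.3 and 0.1. The demo uses its own variable names so it does not overwrite the notebook's `X` and `y`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hypothetical imbalanced labels: 60 of class 0, 30 of class 1, 10 of class 2\n",
"y_demo = pd.Series([0] * 60 + [1] * 30 + [2] * 10)\n",
"X_demo = pd.DataFrame({'feature': np.arange(len(y_demo))})\n",
"\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, train_size=0.8, random_state=101, stratify=y_demo)\n",
"\n",
"# Class proportions in each split\n",
"print(y_tr.value_counts(normalize=True).sort_index())\n",
"print(y_te.value_counts(normalize=True).sort_index())"
],
"metadata": {},
"execution_count": null,
"outputs": []
},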
{
"cell_type": "code",
"source": [
"# Correlation matrix\n",
"# Select numeric columns (float64 and int64 types) from the dataset\n",
"numeric_cols = df1.drop(columns=['Segmentation']).select_dtypes(include=['float64', 'int64']).columns\n",
"# Extract only numeric columns into a new dataframe\n",
"df_numeric_only = df1[numeric_cols]\n",
"\n",
"plt.figure(figsize=(7,5))\n",
"sns.heatmap(df_numeric_only.corr(method='spearman').round(2),linewidth = 0.5,annot=True,cmap=\"YlGnBu\")\n",
"plt.show()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 451
},
"id": "kdoBZMfNRXoK",
"outputId": "5577d291-36ed-4a3b-8f12-665005158ab0"
},
"execution_count": 30,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "\n"
},
"metadata": {}
}
]
},
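{
"cell_type": "markdown",
"source": [
"Spearman correlation is the Pearson correlation computed on ranks, so it captures monotonic rather than strictly linear relationships. The sketch below uses a few hypothetical `Age` and `Family_Size` values and checks that `corr(method='spearman')` matches the Pearson correlation of `rank()`. pandas gives tied values their average rank in both cases."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical numeric values (ties included on purpose)\n",
"df_demo = pd.DataFrame({'Age': [22, 38, 67, 67, 40, 56],\n",
"                        'Family_Size': [4.0, 3.0, 1.0, 2.0, 6.0, 2.0]})\n",
"\n",
"spearman = df_demo.corr(method='spearman')\n",
"# Pearson correlation of the ranks (ties get their average rank)\n",
"pearson_on_ranks = df_demo.rank().corr(method='pearson')\n",
"\n",
"print(spearman.round(2))\n",
"print(np.allclose(spearman, pearson_on_ranks))  # True"
],
"metadata": {},
"execution_count": null,
"outputs": []
},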
{
"cell_type": "markdown",
"source": [
"**Model Building**"
],
"metadata": {
"id": "KdhBi8paSfyb"
}
},
{
"cell_type": "code",
"source": [
"# train a Gaussian Naive Bayes classifier on the training set\n",
"from sklearn.naive_bayes import GaussianNB\n",
"\n",
"# instantiate the model\n",
"gnb1 = GaussianNB()\n",
"\n",
"# Train model\n",
"model_nb1 = gnb1.fit(trainX, train_y)\n",
"\n",
"# Predicting the classes\n",
"yhat3 = gnb1.predict(trainX)\n",
"\n",
"from sklearn.metrics import confusion_matrix\n",
"cm3 = confusion_matrix(train_y.values, yhat3, labels=[0,1,2,3])\n",
"print('\\n\\n-------The confusion matrix for this model is-------')\n",
"print(cm3)\n",
"\n",
"from sklearn.metrics import classification_report\n",
"print('\\n\\n-------Printing the whole report of the model-------')\n",
"print(classification_report(train_y.values, yhat3))\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9HTy1Bq7SrqV",
"outputId": "2ff1b61a-d9f8-43b3-a206-76b3529079ac"
},
"execution_count": 32,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"\n",
"-------The confusion matrix for this model is-------\n",
"[[ 534 233 464 347]\n",
" [ 288 310 710 178]\n",
" [ 110 164 1088 214]\n",
" [ 310 137 135 1232]]\n",
"\n",
"\n",
"-------Printing the whole report of the model-------\n",
" precision recall f1-score support\n",
"\n",
" 0 0.43 0.34 0.38 1578\n",
" 1 0.37 0.21 0.27 1486\n",
" 2 0.45 0.69 0.55 1576\n",
" 3 0.63 0.68 0.65 1814\n",
"\n",
" accuracy 0.49 6454\n",
" macro avg 0.47 0.48 0.46 6454\n",
"weighted avg 0.48 0.49 0.47 6454\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"test_nb1_x = testX.copy()\n",
"test_nb1_x.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 255
},
"id": "D8jF3RmqSxHZ",
"outputId": "9dc4ad2e-ddd3-495a-e76a-2978a3792f27"
},
"execution_count": 33,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Age Work_Experience Family_Size Gender_Female Gender_Male \\\n",
"4463 18 1.0 3.0 True False \n",
"1687 38 0.0 3.0 False True \n",
"5694 43 0.0 2.0 True False \n",
"7390 63 0.0 2.0 True False \n",
"1347 35 1.0 1.0 True False \n",
"\n",
" Ever_Married_No Ever_Married_Yes Graduated_No Graduated_Yes \\\n",
"4463 True False True False \n",
"1687 False True True False \n",
"5694 True False True False \n",
"7390 False True True False \n",
"1347 True False False True \n",
"\n",
" Profession_Artist ... Spending_Score_Average Spending_Score_High \\\n",
"4463 False ... False False \n",
"1687 False ... False True \n",
"5694 False ... False False \n",
"7390 True ... False True \n",
"1347 False ... False False \n",
"\n",
" Spending_Score_Low Var_1_Cat_1 Var_1_Cat_2 Var_1_Cat_3 Var_1_Cat_4 \\\n",
"4463 True False False True False \n",
"1687 False False False False False \n",
"5694 True False False False True \n",
"7390 False False False False False \n",
"1347 True False False False True \n",
"\n",
" Var_1_Cat_5 Var_1_Cat_6 Var_1_Cat_7 \n",
"4463 False False False \n",
"1687 False True False \n",
"5694 False False False \n",
"7390 False True False \n",
"1347 False False False \n",
"\n",
"[5 rows x 28 columns]"
],
"text/html": [
"\n",
" \n",
"
\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" Age \n",
" Work_Experience \n",
" Family_Size \n",
" Gender_Female \n",
" Gender_Male \n",
" Ever_Married_No \n",
" Ever_Married_Yes \n",
" Graduated_No \n",
" Graduated_Yes \n",
" Profession_Artist \n",
" ... \n",
" Spending_Score_Average \n",
" Spending_Score_High \n",
" Spending_Score_Low \n",
" Var_1_Cat_1 \n",
" Var_1_Cat_2 \n",
" Var_1_Cat_3 \n",
" Var_1_Cat_4 \n",
" Var_1_Cat_5 \n",
" Var_1_Cat_6 \n",
" Var_1_Cat_7 \n",
" \n",
" \n",
" \n",
" \n",
" 4463 \n",
" 18 \n",
" 1.0 \n",
" 3.0 \n",
" True \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" False \n",
" ... \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" 1687 \n",
" 38 \n",
" 0.0 \n",
" 3.0 \n",
" False \n",
" True \n",
" False \n",
" True \n",
" True \n",
" False \n",
" False \n",
" ... \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
" 5694 \n",
" 43 \n",
" 0.0 \n",
" 2.0 \n",
" True \n",
" False \n",
" True \n",
" False \n",
" True \n",
" False \n",
" False \n",
" ... \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
" 7390 \n",
" 63 \n",
" 0.0 \n",
" 2.0 \n",
" True \n",
" False \n",
" False \n",
" True \n",
" True \n",
" False \n",
" True \n",
" ... \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" \n",
" \n",
" 1347 \n",
" 35 \n",
" 1.0 \n",
" 1.0 \n",
" True \n",
" False \n",
" True \n",
" False \n",
" False \n",
" True \n",
" False \n",
" ... \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" True \n",
" False \n",
" False \n",
" False \n",
" \n",
" \n",
"
\n",
"
5 rows × 28 columns
\n",
"
\n",
"
\n",
"
\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "test_nb1_x"
}
},
"metadata": {},
"execution_count": 33
}
]
},
{
"cell_type": "code",
"source": [
"test_nb1_y = test_y.copy()\n",
"test_nb1_y.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 241
},
"id": "q5lF1dYuS0aW",
"outputId": "05c075d1-04b2-4d10-e76f-3a909511f53e"
},
"execution_count": 34,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"4463 3\n",
"1687 2\n",
"5694 3\n",
"7390 1\n",
"1347 3\n",
"Name: Segmentation, dtype: int64"
]
},
"metadata": {},
"execution_count": 34
}
]
},
{
"cell_type": "code",
"source": [
"# apply gnb1 prediction on test_nb1_x\n",
"\n",
"y_nb1 = gnb1.predict(test_nb1_x)\n",
"y_nb1\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Zg2xzA1rUVJO",
"outputId": "cc2d0a7a-4321-4fdf-9ebe-a9377b342200"
},
"execution_count": 35,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([3, 1, 3, ..., 1, 1, 0])"
]
},
"metadata": {},
"execution_count": 35
}
]
},
{
"cell_type": "code",
"source": [
"pd.Series(y_nb1).value_counts()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 209
},
"id": "IiQspLmxYAvj",
"outputId": "ff970bbc-eb75-4480-c6c7-5c11abc6443b"
},
"execution_count": 36,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"2 569\n",
"3 517\n",
"0 321\n",
"1 207\n",
"Name: count, dtype: int64"
]
},
"metadata": {},
"execution_count": 36
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.metrics import confusion_matrix\n",
"print('-------The confusion matrix for test data is-------\\n')\n",
"print(confusion_matrix(test_nb1_y.values, y_nb1, labels=[0,1,2,3]))\n",
"\n",
"from sklearn.metrics import classification_report\n",
"print('\\n\\n-------Printing the report of test data-------\\n')\n",
"print(classification_report(test_nb1_y.values, y_nb1))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5HuB7o5pYON2",
"outputId": "e838f65c-1d28-4b6b-cf6e-b5ebd93cf27e"
},
"execution_count": 38,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"-------The confusion matrix for test data is-------\n",
"\n",
"[[147 54 113 80]\n",
" [ 78 72 167 55]\n",
" [ 28 44 253 69]\n",
" [ 68 37 36 313]]\n",
"\n",
"\n",
"-------Printing the report of test data-------\n",
"\n",
" precision recall f1-score support\n",
"\n",
" 0 0.46 0.37 0.41 394\n",
" 1 0.35 0.19 0.25 372\n",
" 2 0.44 0.64 0.53 394\n",
" 3 0.61 0.69 0.64 454\n",
"\n",
" accuracy 0.49 1614\n",
" macro avg 0.46 0.47 0.46 1614\n",
"weighted avg 0.47 0.49 0.47 1614\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"**Model Evaluation**"
],
"metadata": {
"id": "P_6v8yfjYZlL"
}
},
{
"cell_type": "code",
"source": [
"print('************************ MODEL-1 REPORT *********************************\\n')\n",
"print('Train data')\n",
"print(classification_report(train_y.values, yhat3))\n",
"print('\\nTest data')\n",
"print(classification_report(test_nb1_y.values, y_nb1))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "G2WCPEtTYc5_",
"outputId": "86760a4f-04a8-4ad8-83ed-df6c0e6dcecd"
},
"execution_count": 40,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"************************ MODEL-1 REPORT *********************************\n",
"\n",
"Train data\n",
" precision recall f1-score support\n",
"\n",
" 0 0.43 0.34 0.38 1578\n",
" 1 0.37 0.21 0.27 1486\n",
" 2 0.45 0.69 0.55 1576\n",
" 3 0.63 0.68 0.65 1814\n",
"\n",
" accuracy 0.49 6454\n",
" macro avg 0.47 0.48 0.46 6454\n",
"weighted avg 0.48 0.49 0.47 6454\n",
"\n",
"\n",
"Test data\n",
" precision recall f1-score support\n",
"\n",
" 0 0.46 0.37 0.41 394\n",
" 1 0.35 0.19 0.25 372\n",
" 2 0.44 0.64 0.53 394\n",
" 3 0.61 0.69 0.64 454\n",
"\n",
" accuracy 0.49 1614\n",
" macro avg 0.46 0.47 0.46 1614\n",
"weighted avg 0.47 0.49 0.47 1614\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"**Analysis of the Results**\n",
"The confusion matrix shown in the image provides key insights into the performance of the GNB model:\n",
"\n",
"For Train Data:\n",
"The F1 scores indicate that Segment D is classified the best, while Segment B performs the worst with a recall of 0.21.\n",
"\n",
"Fo Test Data:\n",
"Accuracy: 0.49 (same as the train data, which shows the model is not overfitting or underfitting drastically)\n",
"\n",
"Key Takeaways:\n",
"\n",
"--Segment D: The model performs best on this segment, achieving high precision, recall, and F1-scores in both the train and test sets.\n",
"\n",
"--Segment B: This segment is the weakest, with very low recall and F1-score, which means the model is struggling to correctly identify and predict this group.\n",
"\n",
"--Overall Accuracy: With an accuracy of 49%, the model isn't performing exceptionally well but is providing useful predictions, especially for Segment D and C.\n"
],
"metadata": {
"id": "2NyC4ja1Z02n"
}
},
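{
"cell_type": "markdown",
"source": [
"The per-class figures in the test report can be recomputed from the printed test confusion matrix. Recall divides each diagonal entry by its row total (customers truly in that segment), precision divides it by its column total (customers predicted as that segment), and accuracy is the diagonal sum over all 1,614 rows. This sketch hard-codes the matrix printed above and assumes scikit-learn's convention of rows = true classes and columns = predicted classes, with classes 0 to 3 read as Segments A to D. It also locates the largest off-diagonal entry, which is the most frequent confusion."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"# Test-set confusion matrix printed above (rows = true class, columns = predicted class)\n",
"cm_test = np.array([[147,  54, 113,  80],\n",
"                    [ 78,  72, 167,  55],\n",
"                    [ 28,  44, 253,  69],\n",
"                    [ 68,  37,  36, 313]])\n",
"seg_names = 'ABCD'\n",
"\n",
"recall = np.diag(cm_test) / cm_test.sum(axis=1)     # correct predictions / true members of each class\n",
"precision = np.diag(cm_test) / cm_test.sum(axis=0)  # correct predictions / all predictions of each class\n",
"accuracy = np.trace(cm_test) / cm_test.sum()        # all correct predictions / all rows\n",
"\n",
"for seg, r, p in zip(seg_names, recall, precision):\n",
"    print(f'Segment {seg}: recall={r:.2f}, precision={p:.2f}')\n",
"print(f'Accuracy: {accuracy:.2f}')\n",
"\n",
"# The largest off-diagonal cell is the most frequent confusion\n",
"off_diag = cm_test - np.diag(np.diag(cm_test))\n",
"true_idx, pred_idx = np.unravel_index(off_diag.argmax(), cm_test.shape)\n",
"print(f'Most common error: Segment {seg_names[true_idx]} predicted as Segment {seg_names[pred_idx]} ({off_diag.max()} rows)')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},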
{
"cell_type": "markdown",
"source": [
"### In-Class Activity 3: Predict Class Probabilities with Gaussian Naive Bayes\n",
"- Objective: Train a Gaussian Naive Bayes model and predict the probability of each class for a few instances.\n",
"\n",
"#### Steps for the Activity:\n",
"`Train the Model:`\n",
"- Train the Gaussian Naive Bayes model on the training data.`\n",
"\n",
"`Predict Class Probabilities:`\n",
"- Use the trained model to predict class probabilities for a few instances from the test set.\n",
"\n",
"`Interpret the Probabilities:`\n",
"- Print the predicted probabilities and discuss how confident the model is for each class.\n",
"\n",
"`Hint:` mean (`theta_`) and variance (`var_`) can be extracted from the model object."
],
"metadata": {
"id": "0bhLnhf9hgx8"
}
},
{
"cell_type": "code",
"source": [
"## Solution:\n",
"\n",
"import numpy as np\n",
"from sklearn.naive_bayes import GaussianNB\n",
"\n",
"# Step 1: Train the Gaussian Naive Bayes model\n",
"gnb = GaussianNB()\n",
"gnb.fit(trainX, train_y)\n",
"\n",
"# Step 2: Extract the mean and variance of each feature learned by the model\n",
"means = gnb.theta_ # Mean of each feature per class\n",
"variances = gnb.var_ # Variance of each feature per class\n",
"\n",
"# Step 3: Display the means and variances for analysis\n",
"print(\"Feature Means per Class:\")\n",
"print(means)\n",
"\n",
"print(\"\\nFeature Variances per Class:\")\n",
"print(variances)\n",
"\n",
"# Step 4: Brief Interpretation\n",
"# Learners should compare how much each feature varies across the classes to identify important features.\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "y46xEZsujg5z",
"outputId": "5a6bd6e3-9d28-4994-bece-aa4145d30328"
},
"execution_count": 43,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Feature Means per Class:\n",
"[[4.50114068e+01 2.69645120e+00 2.45944233e+00 4.59442332e-01\n",
" 5.40557668e-01 4.03675539e-01 5.96324461e-01 3.75792142e-01\n",
" 6.24207858e-01 3.00380228e-01 1.02661597e-01 1.26742712e-01\n",
" 1.84410646e-01 6.08365019e-02 5.64005070e-02 3.67553866e-02\n",
" 1.00760456e-01 3.10519645e-02 1.75538657e-01 1.41318124e-01\n",
" 6.83143219e-01 1.58428390e-02 4.24588086e-02 1.11533587e-01\n",
" 1.68567807e-01 1.07731305e-02 6.25475285e-01 2.53485425e-02]\n",
" [4.83371467e+01 2.20390310e+00 2.69986541e+00 4.48855989e-01\n",
" 5.51144011e-01 2.52355316e-01 7.47644684e-01 2.65814266e-01\n",
" 7.34185734e-01 4.20592194e-01 7.40242261e-02 1.00942127e-01\n",
" 1.16419919e-01 1.03633917e-01 5.31628533e-02 2.96096904e-02\n",
" 8.61372813e-02 1.54777927e-02 3.18304172e-01 2.07940781e-01\n",
" 4.73755047e-01 1.48048452e-02 5.58546433e-02 9.82503365e-02\n",
" 1.17092867e-01 1.00942127e-02 6.79676985e-01 2.42261104e-02]\n",
" [4.90697970e+01 2.14784264e+00 2.97144670e+00 4.62563452e-01\n",
" 5.37436548e-01 1.97969543e-01 8.02030457e-01 1.69416244e-01\n",
" 8.30583756e-01 5.57106599e-01 7.29695431e-02 3.17258883e-02\n",
" 8.05837563e-02 8.81979695e-02 7.17005076e-02 1.39593909e-02\n",
" 6.72588832e-02 1.64974619e-02 4.66370558e-01 1.98604061e-01\n",
" 3.35025381e-01 1.58629442e-02 4.75888325e-02 7.80456853e-02\n",
" 5.20304569e-02 1.01522843e-02 7.71573604e-01 2.47461929e-02]\n",
" [3.32260198e+01 2.73263506e+00 3.19790518e+00 4.29437707e-01\n",
" 5.70562293e-01 7.10033076e-01 2.89966924e-01 6.35611907e-01\n",
" 3.64388093e-01 8.54465270e-02 8.82028666e-02 7.66262404e-02\n",
" 9.70231533e-02 4.90628445e-02 4.33296582e-01 4.24476295e-02\n",
" 5.34729879e-02 7.44211687e-02 6.00882029e-02 6.89084895e-02\n",
" 8.71003308e-01 2.20507166e-02 6.17420066e-02 1.09151047e-01\n",
" 1.83020948e-01 1.21278942e-02 5.86549063e-01 2.53583241e-02]]\n",
"\n",
"Feature Variances per Class:\n",
"[[2.70632316e+02 1.17399243e+01 2.10640352e+00 2.48355353e-01\n",
" 2.48355353e-01 2.40721876e-01 2.40721876e-01 2.34572686e-01\n",
" 2.34572686e-01 2.10152225e-01 9.21224714e-02 1.10679275e-01\n",
" 1.50403638e-01 5.71356999e-02 5.32197677e-02 3.54047060e-02\n",
" 9.06080646e-02 3.00880179e-02 1.44725115e-01 1.21347590e-01\n",
" 2.16458839e-01 1.55921214e-02 4.06563361e-02 9.90941238e-02\n",
" 1.40152980e-01 1.06573481e-02 2.34256231e-01 2.47062718e-02]\n",
" [2.19846628e+02 9.33325557e+00 1.93952920e+00 2.47384568e-01\n",
" 2.47384568e-01 1.88672389e-01 1.88672389e-01 1.95157320e-01\n",
" 1.95157320e-01 2.43694678e-01 6.85449180e-02 9.07530915e-02\n",
" 1.02866600e-01 9.28942058e-02 5.03368422e-02 2.87332346e-02\n",
" 7.87179280e-02 1.52385086e-02 2.16986904e-01 1.64701690e-01\n",
" 2.49311480e-01 1.45859397e-02 5.27351801e-02 8.85974858e-02\n",
" 1.03382405e-01 9.99259744e-03 2.17716459e-01 2.36394839e-02]\n",
" [2.07552235e+02 8.68182303e+00 1.84372813e+00 2.48598783e-01\n",
" 2.48598783e-01 1.58777881e-01 1.58777881e-01 1.40714658e-01\n",
" 1.40714658e-01 2.46739114e-01 6.76452668e-02 3.07196343e-02\n",
" 7.40902925e-02 8.04193656e-02 6.65598227e-02 1.37648042e-02\n",
" 6.27354038e-02 1.62255736e-02 2.48869339e-01 1.59160766e-01\n",
" 2.22783653e-01 1.56115891e-02 4.53244134e-02 7.19548342e-02\n",
" 4.93235663e-02 1.00494933e-02 1.76248055e-01 2.41340968e-02]\n",
" [2.39320470e+02 1.17372263e+01 2.66811055e+00 2.45021241e-01\n",
" 2.45021241e-01 2.05886385e-01 2.05886385e-01 2.31609688e-01\n",
" 2.31609688e-01 7.81456960e-02 8.04233988e-02 7.07549376e-02\n",
" 8.76099389e-02 4.66559597e-02 2.45550932e-01 4.06461062e-02\n",
" 5.06139054e-02 6.88829363e-02 5.64778887e-02 6.41603875e-02\n",
" 1.12356824e-01 2.15647605e-02 5.79302092e-02 9.72373742e-02\n",
" 1.49524559e-01 1.19810863e-02 2.42509538e-01 2.47155575e-02]]\n"
]
}
]
},
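{
"cell_type": "markdown",
"source": [
"The hint refers to `theta_` and `var_`. As a sketch of how Gaussian Naive Bayes turns these into the probabilities returned by `predict_proba`, the cell below fits a model on small synthetic data (an assumption, so that it runs without `trainX`) and recomputes the probabilities by hand: for each class it adds the log prior (`class_prior_`) to the sum of per-feature Gaussian log-densities, then normalises each row with log-sum-exp. The names `manual_proba`, `gnb_demo` and `X_nb_demo` are illustrative. A row whose largest probability is close to 1 is a confident prediction; probabilities spread across several classes indicate uncertainty."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from scipy.special import logsumexp\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.naive_bayes import GaussianNB\n",
"\n",
"# Illustrative synthetic data (not the customer dataset): 4 classes, 5 features\n",
"X_nb_demo, y_nb_demo = make_classification(n_samples=500, n_features=5, n_informative=3,\n",
"                                           n_redundant=0, n_classes=4, random_state=0)\n",
"gnb_demo = GaussianNB().fit(X_nb_demo, y_nb_demo)\n",
"\n",
"def manual_proba(model, X_new):\n",
"    # Gaussian log-density of each feature under each class, summed over features\n",
"    # (Naive Bayes treats the features as independent given the class)\n",
"    log_lik = -0.5 * (np.log(2 * np.pi * model.var_)[None, :, :]\n",
"                      + (X_new[:, None, :] - model.theta_[None, :, :]) ** 2\n",
"                      / model.var_[None, :, :]).sum(axis=2)\n",
"    joint = np.log(model.class_prior_) + log_lik  # log prior + log likelihood, shape (n_rows, n_classes)\n",
"    # Normalise with log-sum-exp so that each row sums to 1\n",
"    return np.exp(joint - logsumexp(joint, axis=1, keepdims=True))\n",
"\n",
"X_few = X_nb_demo[:5]\n",
"probs_manual = manual_proba(gnb_demo, X_few)\n",
"probs_sklearn = gnb_demo.predict_proba(X_few)\n",
"print(np.round(probs_manual, 3))\n",
"print('Matches predict_proba:', np.allclose(probs_manual, probs_sklearn))\n",
"print('Confidence (largest probability per row):', np.round(probs_sklearn.max(axis=1), 3))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},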
{
"cell_type": "markdown",
"source": [
"# **Logistic Regression**"
],
"metadata": {
"id": "AH4oN6HEQPYG"
}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"# Load the uploaded CSV files\n",
"train_file_path = 'Train.csv'\n",
"test_file_path = 'Test.csv'\n",
"\n",
"# Reading the train and test datasets\n",
"train_data = pd.read_csv(train_file_path)\n",
"test_data = pd.read_csv(test_file_path)\n",
"\n",
"# Display the first few rows of the datasets to understand their structure\n",
"train_data.head(), test_data.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "o4-9p1ClkfUD",
"outputId": "fd23d9ba-42c7-4b56-fc38-6bf029b4700c"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"( ID Gender Ever_Married Age Graduated Profession Work_Experience \\\n",
" 0 462809 Male No 22 No Healthcare 1.0 \n",
" 1 462643 Female Yes 38 Yes Engineer NaN \n",
" 2 466315 Female Yes 67 Yes Engineer 1.0 \n",
" 3 461735 Male Yes 67 Yes Lawyer 0.0 \n",
" 4 462669 Female Yes 40 Yes Entertainment NaN \n",
" \n",
" Spending_Score Family_Size Var_1 Segmentation \n",
" 0 Low 4.0 Cat_4 D \n",
" 1 Average 3.0 Cat_4 A \n",
" 2 Low 1.0 Cat_6 B \n",
" 3 High 2.0 Cat_6 B \n",
" 4 High 6.0 Cat_6 A ,\n",
" ID Gender Ever_Married Age Graduated Profession Work_Experience \\\n",
" 0 458989 Female Yes 36 Yes Engineer 0.0 \n",
" 1 458994 Male Yes 37 Yes Healthcare 8.0 \n",
" 2 458996 Female Yes 69 No NaN 0.0 \n",
" 3 459000 Male Yes 59 No Executive 11.0 \n",
" 4 459001 Female No 19 No Marketing NaN \n",
" \n",
" Spending_Score Family_Size Var_1 \n",
" 0 Low 1.0 Cat_6 \n",
" 1 Average 4.0 Cat_6 \n",
" 2 Low 1.0 Cat_6 \n",
" 3 High 2.0 Cat_6 \n",
" 4 Low 4.0 Cat_6 )"
]
},
"metadata": {},
"execution_count": 21
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"\n",
"# Preprocessing the data\n",
"def preprocess_data(data, is_train=True):\n",
" # Dropping ID column as it's not relevant\n",
" data = data.drop(columns=['ID'])\n",
"\n",
" # Handling missing values using SimpleImputer\n",
" imputer = SimpleImputer(strategy='most_frequent')\n",
" data[['Work_Experience', 'Family_Size']] = imputer.fit_transform(data[['Work_Experience', 'Family_Size']])\n",
"\n",
" # Encoding categorical variables\n",
" encoder = LabelEncoder()\n",
" data['Gender'] = encoder.fit_transform(data['Gender'])\n",
" data['Ever_Married'] = encoder.fit_transform(data['Ever_Married'])\n",
" data['Graduated'] = encoder.fit_transform(data['Graduated'])\n",
" data['Profession'] = encoder.fit_transform(data['Profession'].astype(str))\n",
" data['Spending_Score'] = encoder.fit_transform(data['Spending_Score'])\n",
" data['Var_1'] = encoder.fit_transform(data['Var_1'].astype(str))\n",
"\n",
" if is_train:\n",
" # Encode the target variable (Segmentation)\n",
" data['Segmentation'] = encoder.fit_transform(data['Segmentation'])\n",
"\n",
" return data\n",
"\n",
"# Preprocess train and test datasets\n",
"train_data_processed = preprocess_data(train_data)\n",
"test_data_processed = preprocess_data(test_data, is_train=False)\n",
"\n",
"# Splitting features and target variable for the train dataset\n",
"X_train = train_data_processed.drop(columns=['Segmentation'])\n",
"y_train = train_data_processed['Segmentation']\n",
"\n",
"# Standardizing the features\n",
"scaler = StandardScaler()\n",
"X_train_scaled = scaler.fit_transform(X_train)\n",
"X_test_scaled = scaler.transform(test_data_processed)\n",
"\n",
"# Applying Logistic Regression\n",
"log_reg = LogisticRegression(max_iter=500)\n",
"log_reg.fit(X_train_scaled, y_train)\n",
"\n",
"# Predicting on the train and test data\n",
"y_train_pred = log_reg.predict(X_train_scaled)\n",
"y_test_pred = log_reg.predict(X_test_scaled)\n",
"\n",
"# Calculating accuracy and classification report for train and test data\n",
"train_accuracy = accuracy_score(y_train, y_train_pred)\n",
"test_accuracy = accuracy_score(y_train[:len(y_test_pred)], y_test_pred)\n",
"\n",
"train_report = classification_report(y_train, y_train_pred, target_names=['A', 'B', 'C', 'D'])\n",
"test_report = classification_report(y_train[:len(y_test_pred)], y_test_pred, target_names=['A', 'B', 'C', 'D'])\n",
"\n",
"# Printing the results\n",
"print(f\"Train Accuracy: {train_accuracy}\")\n",
"print(f\"Test Accuracy: {test_accuracy}\")\n",
"print(\"\\nTrain Classification Report:\")\n",
"print(train_report)\n",
"print(\"\\nTest Classification Report:\")\n",
"print(test_report)\n",
"\n",
"\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vZDYhCEbaDoS",
"outputId": "1409200c-897c-4f61-87ac-159e2a48a192"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train Accuracy: 0.4970252850768468\n",
"Test Accuracy: 0.26303768557289686\n",
"\n",
"Train Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" A 0.41 0.44 0.42 1972\n",
" B 0.36 0.14 0.20 1858\n",
" C 0.48 0.61 0.54 1970\n",
" D 0.61 0.74 0.67 2268\n",
"\n",
" accuracy 0.50 8068\n",
" macro avg 0.47 0.48 0.46 8068\n",
"weighted avg 0.47 0.50 0.47 8068\n",
"\n",
"\n",
"Test Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" A 0.27 0.28 0.28 670\n",
" B 0.23 0.10 0.14 585\n",
" C 0.24 0.31 0.27 629\n",
" D 0.28 0.34 0.31 743\n",
"\n",
" accuracy 0.26 2627\n",
" macro avg 0.26 0.26 0.25 2627\n",
"weighted avg 0.26 0.26 0.25 2627\n",
"\n"
]
}
]
},
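{
"cell_type": "markdown",
"source": [
"`preprocess_data` fits a new `SimpleImputer` and `LabelEncoder` on whichever dataset it receives, so `Test.csv` is encoded independently of `Train.csv`, and the codes agree only when both files contain exactly the same categories. The toy example below uses made-up profession values (not the real data) to show the mismatch when one category is absent from the test data, and the usual fix: fit on the training data once and reuse the fitted encoder. A category that appears only in the test data would make `LabelEncoder.transform` raise an error; for input features, `OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)` is one way to handle that case."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# Made-up values for illustration: 'Artist' never appears in the test data\n",
"train_prof_demo = pd.Series(['Artist', 'Doctor', 'Engineer', 'Doctor'])\n",
"test_prof_demo = pd.Series(['Doctor', 'Engineer'])\n",
"\n",
"# Separate fits (what preprocess_data does): codes depend on each dataset's own categories\n",
"train_codes = LabelEncoder().fit_transform(train_prof_demo)\n",
"test_codes_separate = LabelEncoder().fit_transform(test_prof_demo)\n",
"\n",
"# Single fit on the training data, reused to transform the test data\n",
"encoder_demo = LabelEncoder().fit(train_prof_demo)\n",
"test_codes_shared = encoder_demo.transform(test_prof_demo)\n",
"\n",
"print('Train codes:              ', train_codes.tolist())\n",
"print('Test codes, separate fit: ', test_codes_separate.tolist())\n",
"print('Test codes, shared fit:   ', test_codes_shared.tolist())"
],
"metadata": {},
"execution_count": null,
"outputs": []
},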
{
"cell_type": "markdown",
"source": [
"Observations:\n",
"\n",
"- Accuracy(Training Data): 49.74%\n",
" - Overall: The model performs better for Class C and D, similar to the Naive Bayes classifier. Class B remains the most challenging to predict accurately.\n",
"-Accuracy(Test Data): 26.30%\n",
"\n",
"Comparison:\n",
"Logistic Regression:\n",
"\n",
"- Training Accuracy: 49.74%\n",
"- Test Accuracy: 26.30%\n",
"- Test F1-Score (Class C): 0.27\n",
"- Test F1-Score (Class D): 0.31\n",
"\n",
"Gaussian Naive Bayes:\n",
"\n",
"- Training Accuracy: 48.72%\n",
"- Test Accuracy: 26.87%\n",
"- Test F1-Score (Class C): 0.29\n",
"- Test F1-Score (Class D): 0.32\n",
"\n",
"Conclusion:\n",
"\n",
"- Gaussian Naive Bayes outperforms Logistic Regression for this dataset, especially in terms of test accuracy and the F1-scores for key classes (C and D).\n",
"\n",
"- Given the test set accuracy and F1-scores, Gaussian Naive Bayes is the better model for this particular problem."
],
"metadata": {
"id": "EYTcNlVzQjqh"
}
},
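{
"cell_type": "markdown",
"source": [
"Because `Test.csv` is unlabeled, a fair test score needs a labelled hold-out split carved out of `Train.csv`, as was done for Naive Bayes. The sketch below uses synthetic four-class data from `make_classification` as a stand-in for the preprocessed customer features (an assumption; the `_demo` names are illustrative) and compares the correct evaluation with the one used above, where predictions are scored against the first rows of the training labels. Expect the mismatched score to land near the 25% chance level, while the hold-out score reflects what the model actually learned."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Synthetic stand-in for the labelled rows of Train.csv (4 classes)\n",
"X_demo, y_demo = make_classification(n_samples=4000, n_features=10, n_informative=6,\n",
"                                     n_classes=4, class_sep=2.0, random_state=42)\n",
"\n",
"# Hold out 20% of the labelled rows, keeping the class proportions\n",
"X_tr_demo, X_val_demo, y_tr_demo, y_val_demo = train_test_split(\n",
"    X_demo, y_demo, test_size=0.2, stratify=y_demo, random_state=42)\n",
"\n",
"# Fit the scaler on the training part only and reuse it for the hold-out part\n",
"scaler_demo = StandardScaler().fit(X_tr_demo)\n",
"lr_demo = LogisticRegression(max_iter=500).fit(scaler_demo.transform(X_tr_demo), y_tr_demo)\n",
"val_pred_demo = lr_demo.predict(scaler_demo.transform(X_val_demo))\n",
"\n",
"valid_acc = accuracy_score(y_val_demo, val_pred_demo)                          # true labels of the same rows\n",
"mismatched_acc = accuracy_score(y_tr_demo[:len(val_pred_demo)], val_pred_demo)  # labels of unrelated rows\n",
"\n",
"print(f'Hold-out accuracy:         {valid_acc:.3f}')\n",
"print(f'Mismatched-label accuracy: {mismatched_acc:.3f}')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},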
{
"cell_type": "markdown",
"source": [
"**Class_Weight(Balanced)**"
],
"metadata": {
"id": "RZUklXwGO63W"
}
},
{
"cell_type": "code",
"source": [
"# Re-loading the datasets and preprocessing\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"\n",
"# Paths to the files\n",
"train_file_path = 'Train.csv'\n",
"test_file_path = 'Test.csv'\n",
"\n",
"# Reading the train and test datasets\n",
"train_data = pd.read_csv(train_file_path)\n",
"test_data = pd.read_csv(test_file_path)\n",
"\n",
"# Preprocessing the data\n",
"def preprocess_data(data, is_train=True):\n",
" # Dropping ID column as it's not relevant\n",
" data = data.drop(columns=['ID'])\n",
"\n",
" # Handling missing values using SimpleImputer\n",
" imputer = SimpleImputer(strategy='most_frequent')\n",
" data[['Work_Experience', 'Family_Size']] = imputer.fit_transform(data[['Work_Experience', 'Family_Size']])\n",
"\n",
" # Encoding categorical variables\n",
" encoder = LabelEncoder()\n",
" data['Gender'] = encoder.fit_transform(data['Gender'])\n",
" data['Ever_Married'] = encoder.fit_transform(data['Ever_Married'])\n",
" data['Graduated'] = encoder.fit_transform(data['Graduated'])\n",
" data['Profession'] = encoder.fit_transform(data['Profession'].astype(str))\n",
" data['Spending_Score'] = encoder.fit_transform(data['Spending_Score'])\n",
" data['Var_1'] = encoder.fit_transform(data['Var_1'].astype(str))\n",
"\n",
" if is_train:\n",
" # Encode the target variable (Segmentation)\n",
" data['Segmentation'] = encoder.fit_transform(data['Segmentation'])\n",
"\n",
" return data\n",
"\n",
"# Preprocess train and test datasets\n",
"train_data_processed = preprocess_data(train_data)\n",
"test_data_processed = preprocess_data(test_data, is_train=False)\n",
"\n",
"# Splitting features and target variable for the train dataset\n",
"X_train = train_data_processed.drop(columns=['Segmentation'])\n",
"y_train = train_data_processed['Segmentation']\n",
"\n",
"# Standardizing the features\n",
"scaler = StandardScaler()\n",
"X_train_scaled = scaler.fit_transform(X_train)\n",
"X_test_scaled = scaler.transform(test_data_processed)\n",
"\n",
"# Rebuilding Logistic Regression models with different configurations\n",
"\n",
"# 1. Logistic Regression with class_weight='balanced'\n",
"log_reg_class_weight = LogisticRegression(max_iter=500, class_weight='balanced')\n",
"log_reg_class_weight.fit(X_train_scaled, y_train)\n",
"\n",
"\n",
"\n",
"# Predictions for both models on the train and test datasets\n",
"\n",
"# Class Weight = Balanced\n",
"y_train_pred_class_weight = log_reg_class_weight.predict(X_train_scaled)\n",
"y_test_pred_class_weight = log_reg_class_weight.predict(X_test_scaled)\n",
"\n",
"\n",
"\n",
"# Calculating accuracy and classification reports for both models\n",
"\n",
"# Class Weight = Balanced\n",
"train_accuracy_class_weight = accuracy_score(y_train, y_train_pred_class_weight)\n",
"test_accuracy_class_weight = accuracy_score(y_train[:len(y_test_pred_class_weight)], y_test_pred_class_weight)\n",
"train_report_class_weight = classification_report(y_train, y_train_pred_class_weight, target_names=['A', 'B', 'C', 'D'])\n",
"test_report_class_weight = classification_report(y_train[:len(y_test_pred_class_weight)], y_test_pred_class_weight, target_names=['A', 'B', 'C', 'D'])\n",
"\n",
"\n",
"\n",
"# Printing the results\n",
"\n",
"print(f\"Train Accuracy Class Weight: {train_accuracy_class_weight}\")\n",
"print(f\"Test Accuracy Class Weight: {test_accuracy_class_weight}\")\n",
"\n",
"print(\"\\nTrain Classification Report Class Weight:\")\n",
"print(train_report_class_weight)\n",
"print(\"\\nTest Classification Report Class Weight:\")\n",
"print(test_report_class_weight)\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TloB1IH2O2Ts",
"outputId": "4c7662e4-53d7-487a-fc8d-f84c99748c0f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train Accuracy Class Weight: 0.5012394645513139\n",
"Test Accuracy Class Weight: 0.26417967263037684\n",
"\n",
"Train Classification Report Class Weight:\n",
" precision recall f1-score support\n",
"\n",
" A 0.42 0.45 0.44 1972\n",
" B 0.36 0.21 0.26 1858\n",
" C 0.49 0.58 0.53 1970\n",
" D 0.64 0.72 0.67 2268\n",
"\n",
" accuracy 0.50 8068\n",
" macro avg 0.48 0.49 0.48 8068\n",
"weighted avg 0.48 0.50 0.49 8068\n",
"\n",
"\n",
"Test Classification Report Class Weight:\n",
" precision recall f1-score support\n",
"\n",
" A 0.28 0.29 0.28 670\n",
" B 0.22 0.14 0.17 585\n",
" C 0.25 0.29 0.27 629\n",
" D 0.29 0.32 0.30 743\n",
"\n",
" accuracy 0.26 2627\n",
" macro avg 0.26 0.26 0.26 2627\n",
"weighted avg 0.26 0.26 0.26 2627\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Observations:\n",
"\n",
"Comparing the Results:\n",
"Training Accuracy:\n",
"- Without class_weight and multi_class: 49.74%\n",
"- With class_weight='balanced': 50.12%\n",
"- With multi_class='ovr': 48.86%\n",
"\n",
"Analysis:\n",
"\n",
"All models perform similarly on training accuracy, with class_weight='balanced' achieving the highest score, but the differences are marginal.\n",
"\n",
"Test Accuracy:\n",
"- Without class_weight and multi_class: 26.30%\n",
"- With class_weight='balanced': 26.42%\n",
"\n",
"\n",
"Analysis:\n",
"\n",
"The test accuracy is very similar across all models. However, it's important to remember that accuracy is not always the best measure for imbalanced data (which is the case here).\n",
"\n",
"Minority Class (Class B) Performance:\n",
"Recall on Test Data for Class B:\n",
"- Without class_weight and multi_class: 0.10\n",
"\n",
"\n",
"Analysis:\n",
"\n",
"- The class_weight='balanced' model clearly does better in improving recall for the minority class B\n",
"\n",
"F1-Score on Test Data for Class B:\n",
"- Without class_weight and multi_class: 0.13\n",
"- With class_weight='balanced': 0.17\n",
"\n",
"Analysis:\n",
"\n",
"Class-weight balancing improves the F1-score for B compared to unweighted\n",
"\n",
"Conclusion:\n",
"- Introducing class weights helps slightly improve recall for the minority class (Class B), but the overall accuracy and F1-scores for other classes (A, C, D) remain similar.\n",
"- F1-score remains a better evaluation metric compared to accuracy when working with imbalanced data. Accuracy might suggest the model performs well, but it can mask poor performance on minority classes."
],
"metadata": {
"id": "UA5VLX3bUQML"
}
},
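{
"cell_type": "markdown",
"source": [
"`class_weight='balanced'` weights each class by `n_samples / (n_classes * n_samples_in_class)`. Plugging in the training counts from the reports above (A: 1,972, B: 1,858, C: 1,970, D: 2,268) shows why the effect is modest here: all weights stay close to 1, and Segment B receives the largest one. The sketch rebuilds a label array with exactly those counts (`y_counts_demo` is an illustrative name) and checks the formula against scikit-learn's `compute_class_weight`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from sklearn.utils.class_weight import compute_class_weight\n",
"\n",
"# Training-set class counts taken from the classification reports above (A, B, C, D)\n",
"counts = np.array([1972, 1858, 1970, 2268])\n",
"y_counts_demo = np.repeat(np.arange(4), counts)  # label array with exactly those counts\n",
"\n",
"# Formula used by class_weight='balanced'\n",
"manual_weights = len(y_counts_demo) / (len(counts) * counts)\n",
"sk_weights = compute_class_weight(class_weight='balanced',\n",
"                                  classes=np.arange(4), y=y_counts_demo)\n",
"\n",
"for seg, w in zip('ABCD', sk_weights):\n",
"    print(f'Segment {seg}: weight = {w:.3f}')\n",
"print('Formula matches scikit-learn:', np.allclose(manual_weights, sk_weights))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},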
{
"cell_type": "markdown",
"source": [
"**Multiclass=ovr**"
],
"metadata": {
"id": "4YK-BQAx7Pro"
}
},
{
"cell_type": "code",
"source": [
"# Logistic Regression with multi_class='ovr'\n",
"log_reg_multiclass_ovr = LogisticRegression(max_iter=500, multi_class='ovr')\n",
"log_reg_multiclass_ovr.fit(X_train_scaled, y_train)\n",
"\n",
"# Multi-class = OvR\n",
"y_train_pred_multiclass_ovr = log_reg_multiclass_ovr.predict(X_train_scaled)\n",
"y_test_pred_multiclass_ovr = log_reg_multiclass_ovr.predict(X_test_scaled)\n",
"\n",
"# Multi-class = OvR\n",
"train_accuracy_multiclass_ovr = accuracy_score(y_train, y_train_pred_multiclass_ovr)\n",
"test_accuracy_multiclass_ovr = accuracy_score(y_train[:len(y_test_pred_multiclass_ovr)], y_test_pred_multiclass_ovr)\n",
"train_report_multiclass_ovr = classification_report(y_train, y_train_pred_multiclass_ovr, target_names=['A', 'B', 'C', 'D'])\n",
"test_report_multiclass_ovr = classification_report(y_train[:len(y_test_pred_multiclass_ovr)], y_test_pred_multiclass_ovr, target_names=['A', 'B', 'C', 'D'])\n",
"\n",
"print(f\"Train Accuracy Multi-class_OvR: {train_accuracy_multiclass_ovr}\")\n",
"print(f\"Test Accuracy Multi-class_OvR: {test_accuracy_multiclass_ovr}\")\n",
"print(\"\\nTrain Classification Report Multi-class_OvR:\")\n",
"print(train_report_multiclass_ovr)\n",
"print(\"\\nTest Classification Report Multi-class_OvR:\")\n",
"print(test_report_multiclass_ovr)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6QUTnYcR65DN",
"outputId": "929be307-5da8-40d6-9645-2a2677f685ef"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train Accuracy Multi-class_OvR: 0.4885969261279127\n",
"Test Accuracy Multi-class_OvR: 0.2554244385230301\n",
"\n",
"Train Classification Report Multi-class_OvR:\n",
" precision recall f1-score support\n",
"\n",
" A 0.40 0.42 0.41 1972\n",
" B 0.33 0.07 0.12 1858\n",
" C 0.47 0.63 0.54 1970\n",
" D 0.59 0.76 0.66 2268\n",
"\n",
" accuracy 0.49 8068\n",
" macro avg 0.45 0.47 0.43 8068\n",
"weighted avg 0.45 0.49 0.45 8068\n",
"\n",
"\n",
"Test Classification Report Multi-class_OvR:\n",
" precision recall f1-score support\n",
"\n",
" A 0.25 0.26 0.25 670\n",
" B 0.20 0.05 0.08 585\n",
" C 0.24 0.33 0.28 629\n",
" D 0.28 0.36 0.31 743\n",
"\n",
" accuracy 0.26 2627\n",
" macro avg 0.24 0.25 0.23 2627\n",
"weighted avg 0.25 0.26 0.24 2627\n",
"\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:1256: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. Use OneVsRestClassifier(LogisticRegression(..)) instead. Leave it to its default value to avoid this warning.\n",
" warnings.warn(\n"
]
}
]
},
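{
"cell_type": "markdown",
"source": [
"The FutureWarning above recommends wrapping the estimator in `OneVsRestClassifier` instead of passing `multi_class='ovr'`. A minimal sketch of that replacement, on synthetic four-class data so that the cell runs on its own (an assumption; in this notebook `X_train_scaled` and `y_train` would take the place of the `_demo` arrays): one binary logistic regression is trained per class, and each row is assigned to the class whose binary model gives the highest score."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.multiclass import OneVsRestClassifier\n",
"\n",
"# Synthetic stand-in for the scaled customer features (4 classes)\n",
"X_ovr_demo, y_ovr_demo = make_classification(n_samples=1000, n_features=8, n_informative=5,\n",
"                                             n_classes=4, random_state=0)\n",
"\n",
"# One binary LogisticRegression per class, without the deprecated multi_class argument\n",
"ovr_demo = OneVsRestClassifier(LogisticRegression(max_iter=500))\n",
"ovr_demo.fit(X_ovr_demo, y_ovr_demo)\n",
"\n",
"print('Binary models trained:', len(ovr_demo.estimators_))\n",
"print('First predictions:', ovr_demo.predict(X_ovr_demo[:5]).tolist())"
],
"metadata": {},
"execution_count": null,
"outputs": []
},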
{
"cell_type": "code",
"source": [
"# Logistic Regression with multi_class='multinomial' (one softmax model over all four segments)\n",
"log_reg_multiclass_mn = LogisticRegression(max_iter=500, multi_class='multinomial', solver='lbfgs')\n",
"log_reg_multiclass_mn.fit(X_train_scaled, y_train)\n",
"\n",
"# Predictions (Multi-class = multinomial)\n",
"y_train_pred_multiclass_mn = log_reg_multiclass_mn.predict(X_train_scaled)\n",
"y_test_pred_multiclass_mn = log_reg_multiclass_mn.predict(X_test_scaled)\n",
"\n",
"# Metrics (Multi-class = multinomial)\n",
"# Test.csv has no Segmentation labels, so y_test_pred_multiclass_mn cannot be scored here.\n",
"# Comparing it with y_train[:len(y_test_pred_multiclass_mn)] pairs unrelated customers and gives chance-level scores.\n",
"train_accuracy_multiclass_mn = accuracy_score(y_train, y_train_pred_multiclass_mn)\n",
"train_report_multiclass_mn = classification_report(y_train, y_train_pred_multiclass_mn, target_names=['A', 'B', 'C', 'D'])\n",
"\n",
"print(f\"Train Accuracy Multi-class_mn: {train_accuracy_multiclass_mn}\")\n",
"print(\"\\nTrain Classification Report Multi-class_mn:\")\n",
"print(train_report_multiclass_mn)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KpcU989c7VJ4",
"outputId": "538d3b62-c89b-4aae-f20f-8b197204668b"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train Accuracy Multi-class_mn: 0.4970252850768468\n",
"Test Accuracy Multi-class_mn: 0.26303768557289686\n",
"\n",
"Train Classification Report Multi-class_mn:\n",
" precision recall f1-score support\n",
"\n",
" A 0.40 0.42 0.41 1972\n",
" B 0.33 0.07 0.12 1858\n",
" C 0.47 0.63 0.54 1970\n",
" D 0.59 0.76 0.66 2268\n",
"\n",
" accuracy 0.49 8068\n",
" macro avg 0.45 0.47 0.43 8068\n",
"weighted avg 0.45 0.49 0.45 8068\n",
"\n",
"\n",
"Test Classification Report Multi-class_mn:\n",
" precision recall f1-score support\n",
"\n",
" A 0.27 0.28 0.28 670\n",
" B 0.23 0.10 0.14 585\n",
" C 0.24 0.31 0.27 629\n",
" D 0.28 0.34 0.31 743\n",
"\n",
" accuracy 0.26 2627\n",
" macro avg 0.26 0.26 0.25 2627\n",
"weighted avg 0.26 0.26 0.25 2627\n",
"\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
" warnings.warn(\n"
]
}
]
},
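{
"cell_type": "markdown",
"source": [
"Test.csv carries no `Segmentation` labels, so any held-out score has to come from the labelled training data. A stratified validation split is one common way to get it. The sketch below uses synthetic data from `make_classification` in place of the customer features. The 80/20 split and the pipeline are illustrative choices, not the notebook's method:\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Synthetic four-class data standing in for the labelled Train.csv features (assumption)\n",
"X, y = make_classification(n_samples=1000, n_features=9, n_informative=5, n_classes=4, random_state=42)\n",
"\n",
"# Hold out 20% of the labelled rows; stratify=y keeps the class proportions equal in both parts\n",
"X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)\n",
"\n",
"# The pipeline fits the scaler on X_tr only, so validation statistics do not leak into training\n",
"model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=500))\n",
"model.fit(X_tr, y_tr)\n",
"\n",
"val_pred = model.predict(X_val)\n",
"val_accuracy = accuracy_score(y_val, val_pred)\n",
"print('Validation accuracy:', round(val_accuracy, 3))\n",
"print(classification_report(y_val, val_pred, target_names=['A', 'B', 'C', 'D']))\n",
"```\n",
"\n",
"With the real data, `X` and `y` would be the encoded Train.csv features and labels. The model could then be refitted on all labelled rows before predicting segments for the 2627 new customers."
],
"metadata": {}
},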
{
"cell_type": "markdown",
"source": [
"**Cost of Misclassification with Logistic Regression**"
],
"metadata": {
"id": "zrYssp7UqyOk"
}
},
{
"cell_type": "markdown",
"source": [
"The dataset contains customer segmentation data with four classes (A, B, C, D). Each customer belongs to one of these segments, which the model predicts from features such as age, profession and spending score.\n",
"\n",
"Segments and Their Priorities (assumed business scenario):\n",
"- Segment A: High-value customers (misclassifying them is costly).\n",
"- Segment B: Low-value customers (misclassifications are less costly).\n",
"- Segment C: Potential long-term customers (medium misclassification cost).\n",
"- Segment D: Regular customers (medium misclassification cost).\n",
"\n",
"Assigned Costs of Misclassification:\n",
"- Misclassifying a Segment A customer (as B, C, or D): Cost = 5.\n",
"- Misclassifying a Segment C or D customer: Cost = 2.\n",
"- Misclassifying a Segment B customer: Cost = 1 (lower cost since they are low-priority customers)."
],
"metadata": {
"id": "e2Lmb3OurCKN"
}
},
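{
"cell_type": "markdown",
"source": [
"The notebook uses these costs only to score confusion matrices after the fact. The same cost matrix can also drive the predictions. Given predicted probabilities, you choose the segment with the lowest expected cost rather than the most probable one. This minimum-expected-cost rule is not something the notebook implements. Below is a NumPy sketch with hypothetical probabilities for two customers:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Costs from the scenario above: rows = true segment, columns = predicted segment (A, B, C, D)\n",
"cost_matrix = np.array([\n",
"    [0, 5, 5, 5],  # true A misclassified: cost 5\n",
"    [1, 0, 1, 1],  # true B misclassified: cost 1\n",
"    [2, 2, 0, 2],  # true C misclassified: cost 2\n",
"    [2, 2, 2, 0],  # true D misclassified: cost 2\n",
"])\n",
"\n",
"def min_expected_cost_predict(proba, cost):\n",
"    # proba[n, i] = P(true segment i | x); expected_cost[n, j] = sum_i proba[n, i] * cost[i, j]\n",
"    expected_cost = proba @ cost\n",
"    return expected_cost.argmin(axis=1)\n",
"\n",
"# Hypothetical predicted probabilities (not produced by the notebook's models)\n",
"proba = np.array([\n",
"    [0.30, 0.40, 0.15, 0.15],\n",
"    [0.05, 0.15, 0.50, 0.30],\n",
"])\n",
"print('argmax prediction:', proba.argmax(axis=1))\n",
"print('minimum expected cost prediction:', min_expected_cost_predict(proba, cost_matrix))\n",
"```\n",
"\n",
"For the first customer, B is the most probable segment. Predicting A, however, has the lowest expected cost (1.0 versus 2.1 for B), because missing a true A costs five times as much. With a fitted model, `proba` would come from `predict_proba`."
],
"metadata": {}
},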
{
"cell_type": "code",
"source": [
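"import pandas as pd\n",
"from sklearn.preprocessing import StandardScaler, LabelEncoder\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Load your dataset\n",
"# train_data = pd.read_csv('path_to_train.csv')\n",
"\n",
"# Define features (X) and target (y)\n",
"X = train_data.drop(columns=['Segmentation', 'ID']) # Exclude target and ID column\n",
"y = train_data['Segmentation']\n",
"\n",
"# Encode categorical feature columns as integer codes (missing values become a separate 'nan' category)\n",
"for col in X.select_dtypes(include='object').columns:\n",
"    X[col] = LabelEncoder().fit_transform(X[col].astype(str))\n",
"\n",
"# Fill the remaining missing numeric values (e.g. Work_Experience, Family_Size) with the most frequent value\n",
"X = pd.DataFrame(SimpleImputer(strategy='most_frequent').fit_transform(X), columns=X.columns)\n",
"\n",
"# Encode the target variable (segments A-D become 0-3)\n",
"encoder = LabelEncoder()\n",
"y = encoder.fit_transform(y)\n",
"\n",
"# Scale the features (StandardScaler needs numeric input without missing values)\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"# Split data into training and test sets\n",
"X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)"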
],
"metadata": {
"id": "_UmwLAafCRbA"
},
"execution_count": null,
"outputs": []
},
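{
"cell_type": "markdown",
"source": [
"`preprocess_data` in the next cell fits a fresh imputer and `LabelEncoder` on whatever frame it receives. The test codes are therefore only guaranteed to match the training codes when both frames contain the same categories. Below is a sketch of the fit-on-train, transform-on-test pattern using scikit-learn's `ColumnTransformer` and `OrdinalEncoder`. `OrdinalEncoder` is swapped in here for `LabelEncoder`, which scikit-learn intends for target labels. The tiny frames and the column subset are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import OrdinalEncoder, StandardScaler\n",
"\n",
"# Tiny hypothetical frames using a few of the notebook's column names\n",
"train_df = pd.DataFrame({\n",
"    'Profession': ['Artist', 'Doctor', 'Engineer', np.nan, 'Artist'],\n",
"    'Spending_Score': ['Low', 'High', 'Average', 'Low', 'High'],\n",
"    'Work_Experience': [1.0, np.nan, 8.0, 0.0, 1.0],\n",
"    'Family_Size': [4.0, 2.0, np.nan, 1.0, 4.0],\n",
"})\n",
"test_df = pd.DataFrame({\n",
"    'Profession': ['Engineer', 'Lawyer'],  # 'Lawyer' never appears in train_df\n",
"    'Spending_Score': ['Average', 'Low'],\n",
"    'Work_Experience': [np.nan, 3.0],\n",
"    'Family_Size': [3.0, np.nan],\n",
"})\n",
"\n",
"categorical_cols = ['Profession', 'Spending_Score']\n",
"numeric_cols = ['Work_Experience', 'Family_Size']\n",
"\n",
"preprocessor = ColumnTransformer([\n",
"    # Missing categories become an explicit 'missing' level; codes are learned from train only,\n",
"    # and categories unseen during fit are encoded as -1 instead of raising an error\n",
"    ('cat', make_pipeline(SimpleImputer(strategy='constant', fill_value='missing'),\n",
"                          OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)),\n",
"     categorical_cols),\n",
"    # Numeric gaps are filled with the training mode, then standardized with training statistics\n",
"    ('num', make_pipeline(SimpleImputer(strategy='most_frequent'), StandardScaler()), numeric_cols),\n",
"])\n",
"\n",
"X_train_prep = preprocessor.fit_transform(train_df)  # fit on the training frame only\n",
"X_test_prep = preprocessor.transform(test_df)  # reuse the same mappings on the test frame\n",
"print(X_test_prep[:, :2])  # categorical codes for the two test rows\n",
"```\n",
"\n",
"`Engineer` receives the same code (2) in both frames. The unseen `Lawyer` maps to -1 instead of silently reusing another profession's code."
],
"metadata": {}
},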
{
"cell_type": "code",
"source": [
"# Re-loading the datasets and preprocessing\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n",
"\n",
"# Paths to the files\n",
"train_file_path = 'Train.csv'\n",
"test_file_path = 'Test.csv'\n",
"\n",
"# Reading the train and test datasets\n",
"train_data = pd.read_csv(train_file_path)\n",
"test_data = pd.read_csv(test_file_path)\n",
"\n",
"# Preprocessing the data\n",
"def preprocess_data(data, is_train=True):\n",
" # Dropping ID column as it's not relevant\n",
" data = data.drop(columns=['ID'])\n",
"\n",
" # Handling missing values using SimpleImputer\n",
" imputer = SimpleImputer(strategy='most_frequent')\n",
" data[['Work_Experience', 'Family_Size']] = imputer.fit_transform(data[['Work_Experience', 'Family_Size']])\n",
"\n",
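" # Encoding categorical variables\n",
" # Note: the imputer above and this encoder are re-fitted on whichever frame is passed in, so the test data gets its own\n",
" # fill values and category codes; the codes only match the training codes when both frames contain the same categories.\n",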
" encoder = LabelEncoder()\n",
" data['Gender'] = encoder.fit_transform(data['Gender'])\n",
" data['Ever_Married'] = encoder.fit_transform(data['Ever_Married'])\n",
" data['Graduated'] = encoder.fit_transform(data['Graduated'])\n",
" data['Profession'] = encoder.fit_transform(data['Profession'].astype(str))\n",
" data['Spending_Score'] = encoder.fit_transform(data['Spending_Score'])\n",
" data['Var_1'] = encoder.fit_transform(data['Var_1'].astype(str))\n",
"\n",
" if is_train:\n",
" # Encode the target variable (Segmentation)\n",
" data['Segmentation'] = encoder.fit_transform(data['Segmentation'])\n",
"\n",
" return data\n",
"\n",
"# Preprocess train and test datasets\n",
"train_data_processed = preprocess_data(train_data)\n",
"test_data_processed = preprocess_data(test_data, is_train=False)\n",
"\n",
"# Splitting features and target variable for the train dataset\n",
"X_train = train_data_processed.drop(columns=['Segmentation'])\n",
"y_train = train_data_processed['Segmentation']\n",
"\n",
"# Standardizing the features\n",
"scaler = StandardScaler()\n",
"X_train_scaled = scaler.fit_transform(X_train)\n",
"X_test_scaled = scaler.transform(test_data_processed)\n",
"\n",
"\n",
"# Train Logistic Regression with class_weight='balanced' and multi_class='ovr'\n",
"# (fit on the same scaled features that are used for prediction below)\n",
"log_reg_ovr_weighted = LogisticRegression(class_weight='balanced', multi_class='ovr', solver='lbfgs', max_iter=500)\n",
"log_reg_ovr_weighted.fit(X_train_scaled, y_train)\n",
"\n",
"# Predict on the train data (using X_train_scaled)\n",
"y_pred_ovr_weighted = log_reg_ovr_weighted.predict(X_train_scaled)\n",
"\n",
"# Calculate accuracy and classification report\n",
"accuracy_ovr_weighted = accuracy_score(y_train, y_pred_ovr_weighted)\n",
"report_ovr_weighted = classification_report(y_train, y_pred_ovr_weighted, target_names=['A', 'B', 'C', 'D'])\n",
"print(\"Accuracy (OvR with class_weight='balanced'):\", accuracy_ovr_weighted)\n",
"print(\"Classification Report (OvR with class_weight='balanced'):\")\n",
"print(report_ovr_weighted)\n",
"\n",
"# Confusion Matrix\n",
"cm_ovr_weighted = confusion_matrix(y_train, y_pred_ovr_weighted)\n",
"print(\"Confusion Matrix (OvR with class_weight='balanced'):\")\n",
"print(cm_ovr_weighted)\n",
"\n",
"# Train Logistic Regression with class_weight='balanced' and multi_class='multinomial'\n",
"# (fit on the same scaled features that are used for prediction below)\n",
"log_reg_multinomial_weighted = LogisticRegression(class_weight='balanced', multi_class='multinomial', solver='lbfgs', max_iter=500)\n",
"log_reg_multinomial_weighted.fit(X_train_scaled, y_train)\n",
"\n",
"# Predict on the train data (using X_train_scaled)\n",
"y_pred_multinomial_weighted = log_reg_multinomial_weighted.predict(X_train_scaled)\n",
"\n",
"# Calculate accuracy and classification report\n",
"accuracy_multinomial_weighted = accuracy_score(y_train, y_pred_multinomial_weighted)\n",
"report_multinomial_weighted = classification_report(y_train, y_pred_multinomial_weighted, target_names=['A', 'B', 'C', 'D'])\n",
"print(\"Accuracy (Multinomial with class_weight='balanced'):\", accuracy_multinomial_weighted)\n",
"print(\"Classification Report (Multinomial with class_weight='balanced'):\")\n",
"print(report_multinomial_weighted)\n",
"\n",
"# Confusion Matrix\n",
"cm_multinomial_weighted = confusion_matrix(y_train, y_pred_multinomial_weighted)\n",
"print(\"Confusion Matrix (Multinomial with class_weight='balanced'):\")\n",
"print(cm_multinomial_weighted)\n",
"\n",
"import numpy as np\n",
"\n",
"# Define the cost matrix (higher values represent higher costs of misclassification)\n",
"cost_matrix = np.array([[0, 5, 5, 5], # Misclassifying A\n",
" [1, 0, 1, 1], # Misclassifying B\n",
" [2, 2, 0, 2], # Misclassifying C\n",
" [2, 2, 2, 0]]) # Misclassifying D\n",
"\n",
"# Total misclassification cost: element-wise product of each confusion matrix (rows = true, columns = predicted) and the cost matrix, summed\n",
"misclassification_cost_ovr = np.sum(cm_ovr_weighted * cost_matrix)\n",
"misclassification_cost_multinomial = np.sum(cm_multinomial_weighted * cost_matrix)\n",
"\n",
"print(f\"Total Misclassification Cost (OvR): {misclassification_cost_ovr}\")\n",
"print(f\"Total Misclassification Cost (Multinomial): {misclassification_cost_multinomial}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lH6jLKbs6KuI",
"outputId": "7da52803-4e83-43b8-b75c-0679eede0140"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:1256: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. Use OneVsRestClassifier(LogisticRegression(..)) instead. Leave it to its default value to avoid this warning.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/sklearn/base.py:493: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n",
" warnings.warn(\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Accuracy (OvR with class_weight='balanced'): 0.40865146256817053\n",
"Classification Report (OvR with class_weight='balanced'):\n",
" precision recall f1-score support\n",
"\n",
" A 0.34 0.41 0.37 1972\n",
" B 0.31 0.32 0.31 1858\n",
" C 0.55 0.07 0.13 1970\n",
" D 0.50 0.78 0.61 2268\n",
"\n",
" accuracy 0.41 8068\n",
" macro avg 0.42 0.39 0.35 8068\n",
"weighted avg 0.43 0.41 0.36 8068\n",
"\n",
"Confusion Matrix (OvR with class_weight='balanced'):\n",
"[[ 801 308 29 834]\n",
" [ 689 593 69 507]\n",
" [ 469 945 142 414]\n",
" [ 406 82 19 1761]]\n",
"Accuracy (Multinomial with class_weight='balanced'): 0.40282597917699553\n",
"Classification Report (Multinomial with class_weight='balanced'):\n",
" precision recall f1-score support\n",
"\n",
" A 0.32 0.60 0.42 1972\n",
" B 0.30 0.26 0.28 1858\n",
" C 0.58 0.05 0.10 1970\n",
" D 0.58 0.65 0.61 2268\n",
"\n",
" accuracy 0.40 8068\n",
" macro avg 0.44 0.39 0.35 8068\n",
"weighted avg 0.45 0.40 0.36 8068\n",
"\n",
"Confusion Matrix (Multinomial with class_weight='balanced'):\n",
"[[1185 239 19 529]\n",
" [1039 489 40 290]\n",
" [ 763 839 104 264]\n",
" [ 715 64 17 1472]]\n",
"Total Misclassification Cost (OvR): 11790\n",
"Total Misclassification Cost (Multinomial): 10628\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n",
"/usr/local/lib/python3.10/dist-packages/sklearn/base.py:493: UserWarning: X does not have valid feature names, but LogisticRegression was fitted with feature names\n",
" warnings.warn(\n"
]
}
]
},
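{
"cell_type": "markdown",
"source": [
"The FutureWarning printed by the cell above suggests two changes. For OvR, use `OneVsRestClassifier(LogisticRegression(..))` instead of `multi_class='ovr'`. For the multinomial model, leave `multi_class` at its default. Below is a sketch of both, using synthetic data as an assumed stand-in for the customer features. Fitting and prediction both use the same scaled matrix:\n",
"\n",
"```python\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.multiclass import OneVsRestClassifier\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Synthetic four-class data standing in for X_train / y_train (assumption)\n",
"X, y = make_classification(n_samples=800, n_features=9, n_informative=5, n_classes=4, random_state=0)\n",
"X_scaled = StandardScaler().fit_transform(X)\n",
"\n",
"# OvR without the deprecated argument: one balanced binary model per segment\n",
"ovr_weighted = OneVsRestClassifier(LogisticRegression(class_weight='balanced', solver='lbfgs', max_iter=500))\n",
"ovr_weighted.fit(X_scaled, y)\n",
"\n",
"# Multinomial without the deprecated argument: with lbfgs and more than two classes,\n",
"# the default multi_class setting fits a single softmax model\n",
"mn_weighted = LogisticRegression(class_weight='balanced', solver='lbfgs', max_iter=500)\n",
"mn_weighted.fit(X_scaled, y)\n",
"\n",
"ovr_pred = ovr_weighted.predict(X_scaled)\n",
"mn_pred = mn_weighted.predict(X_scaled)\n",
"print('OvR binary models:', len(ovr_weighted.estimators_))\n",
"print('Multinomial coef_ shape:', mn_weighted.coef_.shape)\n",
"```\n",
"\n",
"Wrapping the balanced model in `OneVsRestClassifier` balances each segment-vs-rest problem separately. We believe this mirrors `multi_class='ovr'` with `class_weight='balanced'`, but compare the resulting numbers against the original cell before relying on them."
],
"metadata": {}
},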
{
"cell_type": "markdown",
"source": [
"Observations from the results:\n",
"\n",
"Comparison Between OvR and Multinomial Models:\n",
"\n",
"1. Accuracy:\n",
"\n",
"Both models have almost identical accuracy (40.87% for OvR and 40.28% for Multinomial), so overall accuracy does not separate them. All figures in this comparison are computed on the training data itself (predictions on `X_train_scaled` compared with `y_train`), so they are in-sample estimates.\n",
"\n",
"2. Class-Specific Performance:\n",
"\n",
"- Class A: The Multinomial model has much higher recall (0.60 vs. 0.41 for OvR), so it correctly identifies more class A customers, at slightly lower precision (0.32 vs. 0.34).\n",
"- Class B: Both models struggle with class B. OvR has slightly better recall (0.32 vs. 0.26), and precision is similar (0.31 vs. 0.30).\n",
"- Class C: Both models largely fail on class C, with very low recall (0.07 for OvR and 0.05 for Multinomial). The largest share of true C customers is predicted as B (945 of 1970 for OvR, 839 for Multinomial).\n",
"- Class D: The OvR model has higher recall for class D (0.78 vs. 0.65), so it identifies more class D customers, but its precision is lower (0.50 vs. 0.58).\n",
"\n",
"3. Misclassification Cost:\n",
"\n",
"- OvR Total Misclassification Cost: 11790\n",
"- Multinomial Total Misclassification Cost: 10628\n",
"\n",
"The Multinomial model has a lower misclassification cost (10628 vs. 11790), indicating that it does a better job at minimizing the most costly errors. Specifically, the improvement in recall for class A (high-cost class) contributes to reducing the overall misclassification cost.\n",
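"\n",
"The cost can also be broken down by true segment. Multiply each printed confusion matrix element-wise by the cost matrix, then sum each row to get the cost contributed by each true segment. A sketch that simply reuses the numbers printed above (rows are true segments A to D, columns are predicted segments):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Confusion matrices copied from the printed output (rows = true A-D, columns = predicted A-D)\n",
"cm_ovr = np.array([[801, 308, 29, 834], [689, 593, 69, 507], [469, 945, 142, 414], [406, 82, 19, 1761]])\n",
"cm_mn = np.array([[1185, 239, 19, 529], [1039, 489, 40, 290], [763, 839, 104, 264], [715, 64, 17, 1472]])\n",
"cost_matrix = np.array([[0, 5, 5, 5], [1, 0, 1, 1], [2, 2, 0, 2], [2, 2, 2, 0]])\n",
"\n",
"# Element-wise product summed along each row = cost contributed by each true segment\n",
"cost_by_class_ovr = (cm_ovr * cost_matrix).sum(axis=1)\n",
"cost_by_class_mn = (cm_mn * cost_matrix).sum(axis=1)\n",
"\n",
"for seg, c_ovr, c_mn in zip('ABCD', cost_by_class_ovr, cost_by_class_mn):\n",
"    print(f'{seg}: OvR={c_ovr}, Multinomial={c_mn}, difference={c_mn - c_ovr}')\n",
"print('Totals:', cost_by_class_ovr.sum(), cost_by_class_mn.sum())\n",
"```\n",
"\n",
"Nearly all of the multinomial model's advantage comes from segment A (3935 vs. 5855 for OvR). It pays more on every other segment, most noticeably D (1592 vs. 1014).\n",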
"\n",
"Conclusion:\n",
"\n",
"Strengths of Multinomial Model:\n",
"\n",
"- Lower misclassification cost: The Multinomial model has a lower total cost of misclassification, which makes it the more cost-effective choice under the assumed cost matrix.\n",
"- Better recall for class A: The Multinomial model is better at identifying class A customers, which matters because class A represents the high-value customers in this scenario.\n",
"\n",
"Strengths of OvR Model:\n",
"\n",
"- Better recall for class D: The OvR model performs better on class D recall, meaning it is better at identifying regular customers (the largest segment in the training data).\n",
"\n",
"Overall: The Multinomial Logistic Regression model with class_weight='balanced' is slightly better in this case, due to its lower misclassification cost and better recall for high-value class A customers. However, both models struggle significantly with classes B and C, and neither is a clear winner in terms of accuracy."
],
"metadata": {
"id": "5rYD8yObFttS"
}
}
]
}