Decision Trees Example

Let’s walk through the steps to build a simple decision tree using the Gini Index for classification. We'll work with an example dataset to classify whether a person will buy a product based on Age and Income.


Example Dataset

| ID | Age | Income | Will Buy (Target) |
|----|-----|--------|-------------------|
| 1  | 25  | 40000  | No  |
| 2  | 45  | 80000  | Yes |
| 3  | 35  | 60000  | Yes |
| 4  | 50  | 120000 | Yes |
| 5  | 23  | 35000  | No  |

Step 1: Select the Best Split

Gini Index Formula

The Gini Index measures the impurity of a node. Lower values indicate better splits.

Gini = 1 - \sum_{i=1}^{k} p_i^2

Where:

  • p_i is the probability of a sample at the node belonging to class i.
  • k is the total number of classes.
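
To make the formula concrete, here is a minimal sketch of computing a node's Gini Index from its class labels (the helper name gini_index is illustrative, not part of any library):

from collections import Counter

def gini_index(labels):
    """Gini Index of a node, computed from its list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    # p_i is the fraction of samples in class i; Gini = 1 - sum(p_i^2)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

print(gini_index(["No", "No"]))   # 0.0  (pure node)
print(gini_index(["Yes", "No"]))  # 0.5  (maximum impurity for two classes)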

Calculate Gini Index for a Split

Assume we split the dataset at the threshold Age ≤ 30.

  • Left Node (Age ≤ 30):

    • Data: ID 1, ID 5
    • "Will Buy" = No (2 samples, 100%)
    • Gini = 1 - 1^2 = 0
  • Right Node (Age > 30):

    • Data: ID 2, ID 3, ID 4
    • "Will Buy" = Yes (3 samples, 100%)
    • Gini = 1 - 1^2 = 0

Weighted Gini for the split:

Gini_{split} = \frac{2}{5}(0) + \frac{3}{5}(0) = 0

This split perfectly separates the data, making it the best choice.
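
You can verify this calculation in a few lines of Python (a sketch; the gini_index helper from above is repeated so the snippet runs on its own):

from collections import Counter

def gini_index(labels):
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

ages = [25, 45, 35, 50, 23]
will_buy = ["No", "Yes", "Yes", "Yes", "No"]

# Partition the labels by the candidate threshold Age <= 30
left = [y for age, y in zip(ages, will_buy) if age <= 30]
right = [y for age, y in zip(ages, will_buy) if age > 30]

# Weight each node's Gini by its share of the samples
n = len(will_buy)
weighted = len(left) / n * gini_index(left) + len(right) / n * gini_index(right)
print(weighted)  # 0.0 -- the split separates the classes perfectly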


Step 2: Split the Dataset

Using Age ≤ 30, divide the dataset into two groups:

  • Left Node (Age ≤ 30): ID 1, ID 5
  • Right Node (Age > 30): ID 2, ID 3, ID 4
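
In pandas, this partition is a pair of boolean-mask selections (a sketch using the same example dataset):

import pandas as pd

df = pd.DataFrame({
    "ID": [1, 2, 3, 4, 5],
    "Age": [25, 45, 35, 50, 23],
    "Income": [40000, 80000, 60000, 120000, 35000],
    "Will_Buy": ["No", "Yes", "Yes", "Yes", "No"],
})

left = df[df["Age"] <= 30]   # IDs 1 and 5, all "No"
right = df[df["Age"] > 30]   # IDs 2, 3, and 4, all "Yes"
print(left["ID"].tolist(), right["ID"].tolist())  # [1, 5] [2, 3, 4]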

Step 3: Repeat

Recursively apply the splitting process to each child node until a stopping condition is met (e.g., maximum depth or a minimum number of samples per node).
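
A compact sketch of this recursion might look like the following (names such as build_tree are illustrative, not sklearn's internals); it combines the split search, the stopping conditions, and the leaf assignment described in Step 4 below:

from collections import Counter

def gini(labels):
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def build_tree(rows, labels, depth=0, max_depth=2, min_samples=2):
    # Stop when the node is pure, the depth limit is reached,
    # or there are too few samples left to split
    if len(set(labels)) == 1 or depth == max_depth or len(labels) < min_samples:
        return Counter(labels).most_common(1)[0][0]  # leaf: majority class
    best = None  # (weighted Gini, feature index, threshold)
    for f in range(len(rows[0])):
        for t in sorted({row[f] for row in rows}):
            left = [y for row, y in zip(rows, labels) if row[f] <= t]
            right = [y for row, y in zip(rows, labels) if row[f] > t]
            if not left or not right:
                continue
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
            if best is None or score < best[0]:
                best = (score, f, t)
    if best is None:  # no valid split exists
        return Counter(labels).most_common(1)[0][0]
    _, f, t = best
    left_data = [(row, y) for row, y in zip(rows, labels) if row[f] <= t]
    right_data = [(row, y) for row, y in zip(rows, labels) if row[f] > t]
    return {
        "feature": f,
        "threshold": t,
        "left": build_tree([r for r, _ in left_data], [y for _, y in left_data], depth + 1, max_depth),
        "right": build_tree([r for r, _ in right_data], [y for _, y in right_data], depth + 1, max_depth),
    }

rows = [[25, 40000], [45, 80000], [35, 60000], [50, 120000], [23, 35000]]
print(build_tree(rows, ["No", "Yes", "Yes", "Yes", "No"]))
# {'feature': 0, 'threshold': 25, 'left': 'No', 'right': 'Yes'}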


Step 4: Assign Leaf Values

  • Left Node: All samples are "No," so the leaf value is "No."
  • Right Node: All samples are "Yes," so the leaf value is "Yes."

What is Gini Index?

The Gini Index measures how often a randomly chosen element would be misclassified if it were labeled at random according to the distribution of classes in the node. It ranges from:

  • 0 (perfect purity): All elements in a node belong to one class.
  • 0.5 (maximum impurity for two classes): Classes are evenly distributed.
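
For example, a two-class node split 50/50 has Gini = 1 - (0.5^2 + 0.5^2) = 0.5, while a pure node has Gini = 1 - 1^2 = 0.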

Code Implementation

Here’s the Python code to demonstrate this using sklearn:

from sklearn.tree import DecisionTreeClassifier, plot_tree
import pandas as pd
import matplotlib.pyplot as plt

# Example Dataset
data = {
    "Age": [25, 45, 35, 50, 23],
    "Income": [40000, 80000, 60000, 120000, 35000],
    "Will_Buy": ["No", "Yes", "Yes", "Yes", "No"],
}

df = pd.DataFrame(data)

# Features (X) and Target (y)
X = df[["Age", "Income"]]
y = df["Will_Buy"]

# Decision Tree with Gini Index
clf = DecisionTreeClassifier(criterion="gini", max_depth=2, random_state=42)
clf.fit(X, y)

# Visualize the Decision Tree
plt.figure(figsize=(10, 6))
plot_tree(clf, filled=True, feature_names=["Age", "Income"], class_names=["No", "Yes"])
plt.show()
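
Note that in this tiny dataset both Age and Income separate the classes perfectly, so the plotted tree may place either feature at the root; after the first split both children are pure, so the tree needs only one internal node even with max_depth=2.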

Key Points to Remember

  • The Gini Index helps decide the best split by minimizing impurity.
  • A split with a Gini Index of 0 indicates perfect separation of classes.
  • The process repeats recursively, assigning class labels or values to the leaf nodes.