Decision Trees Example
Let’s walk through the steps to build a simple decision tree using the Gini Index for classification. We'll work with an example dataset to classify whether a person will buy a product based on Age and Income.
Example Dataset
ID | Age | Income | Will Buy (Target) |
---|---|---|---|
1 | 25 | 40000 | No |
2 | 45 | 80000 | Yes |
3 | 35 | 60000 | Yes |
4 | 50 | 120000 | Yes |
5 | 23 | 35000 | No |
Step 1: Select the Best Split
Gini Index Formula
The Gini Index measures the impurity of a node. Lower values indicate purer nodes, so a split that produces lower impurity is a better split.
( Gini = 1 - \sum_{i=1}^{k} p_i^2 )
Where:
- ( p_i ) is the proportion of samples at the node that belong to class ( i ).
- ( k ) is the total number of classes.
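As a rough illustration, the impurity of a node can be computed directly from its labels, assuming the standard definition (one minus the sum of squared class proportions). The function name `gini` and the convention for empty nodes are choices made for this sketch, not part of the original walkthrough.

```python
from collections import Counter

def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum(p_i ** 2)."""
    n = len(labels)
    if n == 0:
        return 0.0  # assumption for this sketch: an empty node counts as pure
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())

# "Will Buy" labels for IDs 1-5 of the example dataset
root = ["No", "Yes", "Yes", "Yes", "No"]
print(gini(root))           # impurity before any split (about 0.48)
print(gini(["No", "No"]))   # a pure node has impurity 0
```

Before any split, the example dataset holds 2 "No" and 3 "Yes" samples, so its impurity is 1 - (0.4^2 + 0.6^2), roughly 0.48.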
Calculate Gini Index for a Split
Assume we split the dataset at the threshold Age ≤ 30.
- Left Node (Age ≤ 30):
  - Data: ID 1, ID 5
  - "Will Buy" = No (2 samples, 100%)
  - ( Gini = 1 - (1^2) = 0 )
- Right Node (Age > 30):
  - Data: ID 2, ID 3, ID 4
  - "Will Buy" = Yes (3 samples, 100%)
  - ( Gini = 1 - (1^2) = 0 )
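Weighted Gini for the split:
( Gini_{split} = \frac{2}{5}(0) + \frac{3}{5}(0) = 0 )
Each node's Gini is weighted by its share of the samples (2 of 5 on the left, 3 of 5 on the right). This split separates the classes perfectly, and since 0 is the lowest possible value, no other split can score better.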
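To see how a threshold gets selected, the sketch below scores several candidate Age thresholds by weighted Gini, where each child node's impurity is weighted by its share of the samples. Using midpoints between consecutive sorted ages as candidates is an assumption of this sketch (a common convention), and the helper names are hypothetical.

```python
def gini(labels):
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def weighted_gini(left, right):
    """Impurity of a split: each child's Gini weighted by its sample share."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)

def split_score(values, labels, threshold):
    left = [y for x, y in zip(values, labels) if x <= threshold]
    right = [y for x, y in zip(values, labels) if x > threshold]
    return weighted_gini(left, right)

ages = [25, 45, 35, 50, 23]
labels = ["No", "Yes", "Yes", "Yes", "No"]

# Candidate thresholds: midpoints between consecutive distinct ages (assumption)
sorted_ages = sorted(set(ages))
candidates = [(a + b) / 2 for a, b in zip(sorted_ages, sorted_ages[1:])]

for t in candidates:
    print(t, round(split_score(ages, labels, t), 3))

best = min(candidates, key=lambda t: split_score(ages, labels, t))
print("best threshold:", best)
```

Among these candidates, only the threshold at 30 reaches a weighted Gini of 0; the others leave at least one mixed node.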
Step 2: Split the Dataset
Using Age ≤ 30, divide the dataset into two groups:
- Left Node (Age ≤ 30): ID 1, ID 5
- Right Node (Age > 30): ID 2, ID 3, ID 4
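The partition itself is a simple filter. A minimal sketch in plain Python, assuming the rows are stored as tuples (the `split_rows` helper and the column index constant are hypothetical names):

```python
# Rows of the example dataset as (ID, Age, Income, Will Buy)
dataset = [
    (1, 25, 40000, "No"),
    (2, 45, 80000, "Yes"),
    (3, 35, 60000, "Yes"),
    (4, 50, 120000, "Yes"),
    (5, 23, 35000, "No"),
]

AGE = 1  # position of Age within each tuple

def split_rows(rows, column, threshold):
    """Send rows with value <= threshold left, the rest right."""
    left = [r for r in rows if r[column] <= threshold]
    right = [r for r in rows if r[column] > threshold]
    return left, right

left_node, right_node = split_rows(dataset, AGE, 30)
print([r[0] for r in left_node])    # IDs in the left node
print([r[0] for r in right_node])   # IDs in the right node
```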
Step 3: Repeat
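Recursively apply the splitting process to each child node until a stopping condition is met (e.g., maximum depth, minimum number of samples, or a node that is already pure). In this example both nodes are pure after the first split, so no further splitting is needed.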
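The recursion can be sketched in a few dozen lines. This is an illustrative toy, not how any particular library implements it: the nested-dict tree representation, the midpoint thresholds, and the parameter names `max_depth` and `min_samples` are all assumptions of this sketch.

```python
from collections import Counter

def gini(labels):
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_split(rows, labels):
    """Return (score, feature_index, threshold) with the lowest weighted Gini, or None."""
    best = None
    for f in range(len(rows[0])):
        values = sorted({row[f] for row in rows})
        # Candidate thresholds: midpoints between consecutive distinct values
        for lo, hi in zip(values, values[1:]):
            t = (lo + hi) / 2
            left = [y for row, y in zip(rows, labels) if row[f] <= t]
            right = [y for row, y in zip(rows, labels) if row[f] > t]
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
            if best is None or score < best[0]:
                best = (score, f, t)
    return best

def build(rows, labels, depth=0, max_depth=2, min_samples=2):
    majority = Counter(labels).most_common(1)[0][0]
    # Stopping conditions: pure node, depth limit reached, or too few samples
    if len(set(labels)) == 1 or depth >= max_depth or len(rows) < min_samples:
        return {"leaf": majority}
    split = best_split(rows, labels)
    if split is None:  # all rows identical, so there is nothing to split on
        return {"leaf": majority}
    _, f, t = split
    left = [(r, y) for r, y in zip(rows, labels) if r[f] <= t]
    right = [(r, y) for r, y in zip(rows, labels) if r[f] > t]
    return {
        "feature": f,
        "threshold": t,
        "left": build([r for r, _ in left], [y for _, y in left],
                      depth + 1, max_depth, min_samples),
        "right": build([r for r, _ in right], [y for _, y in right],
                       depth + 1, max_depth, min_samples),
    }

rows = [[25, 40000], [45, 80000], [35, 60000], [50, 120000], [23, 35000]]  # [Age, Income]
labels = ["No", "Yes", "Yes", "Yes", "No"]
tree = build(rows, labels)
print(tree)
```

On this dataset a split on Income at 50000 separates the classes just as well as Age at 30. The sketch keeps the first perfect split it finds, which is the Age split because Age is the first feature scanned.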
Step 4: Assign Leaf Values
- Left Node: All samples are "No," so the leaf value is "No."
- Right Node: All samples are "Yes," so the leaf value is "Yes."
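Leaf assignment and prediction can be illustrated with a hand-written version of the tree from this walkthrough. Taking the majority class as the leaf value is the usual convention for classification; the nested-dict layout and the `leaf_value` and `predict` helpers are hypothetical names used only for this sketch.

```python
from collections import Counter

def leaf_value(labels):
    """Majority class among the samples that reached the leaf."""
    return Counter(labels).most_common(1)[0][0]

# The tree from the walkthrough, written by hand as nested dicts
tree = {
    "feature": "Age",
    "threshold": 30,
    "left": {"leaf": leaf_value(["No", "No"])},            # IDs 1 and 5
    "right": {"leaf": leaf_value(["Yes", "Yes", "Yes"])},  # IDs 2, 3 and 4
}

def predict(node, sample):
    """Walk from the root to a leaf, going left when value <= threshold."""
    while "leaf" not in node:
        branch = "left" if sample[node["feature"]] <= node["threshold"] else "right"
        node = node[branch]
    return node["leaf"]

print(predict(tree, {"Age": 28, "Income": 52000}))
print(predict(tree, {"Age": 41, "Income": 52000}))
```

When a leaf is not pure, the same `leaf_value` function returns the most common class; a tie falls to whichever class appears first in the list.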
What is Gini Index?
The Gini Index measures how often a randomly chosen element would be misclassified if it were labeled at random according to the class distribution in the node. For a two-class problem such as this one, it ranges from:
- 0 (perfect purity): All elements in a node belong to one class.
- 0.5 (maximum impurity): The two classes are evenly distributed.
With ( k ) classes, the maximum is ( 1 - 1/k ), reached when all classes are equally likely.
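The "randomly labeled" interpretation can be checked with a small simulation: draw an element, draw a label independently from the same class distribution, and count how often they disagree. The trial count and seed below are arbitrary choices for this sketch.

```python
import random

def gini(labels):
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def simulated_error(labels, trials=100_000, seed=0):
    """Fraction of trials where a random label disagrees with a random element."""
    rng = random.Random(seed)
    wrong = 0
    for _ in range(trials):
        true_label = rng.choice(labels)  # pick a random element from the node
        guess = rng.choice(labels)       # label it at random, same distribution
        wrong += true_label != guess
    return wrong / trials

node = ["No", "Yes", "Yes", "Yes", "No"]  # the unsplit example dataset
print(gini(node))             # closed-form impurity
print(simulated_error(node))  # should land close to the value above
```

The two numbers should agree to within sampling noise, which is why the Gini Index is often described as an expected misclassification rate.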
Code Implementation
Here’s the Python code to demonstrate this using sklearn:
from sklearn.tree import DecisionTreeClassifier, plot_tree
import pandas as pd
import matplotlib.pyplot as plt
# Example Dataset
data = {
    "Age": [25, 45, 35, 50, 23],
    "Income": [40000, 80000, 60000, 120000, 35000],
    "Will_Buy": ["No", "Yes", "Yes", "Yes", "No"],
}
df = pd.DataFrame(data)
# Features (X) and Target (y)
X = df[["Age", "Income"]]
y = df["Will_Buy"]
# Decision Tree with Gini Index
clf = DecisionTreeClassifier(criterion="gini", max_depth=2, random_state=42)
clf.fit(X, y)
# Visualize the Decision Tree
plt.figure(figsize=(10, 6))
plot_tree(clf, filled=True, feature_names=["Age", "Income"], class_names=["No", "Yes"])
plt.show()
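If a plot is not convenient, the fitted tree can also be inspected as text and used for prediction. The sketch below refits the same model on plain lists so that it runs on its own; the two new samples are made up for illustration.

```python
from sklearn.tree import DecisionTreeClassifier, export_text

# Same data as above, as [Age, Income] rows
X = [[25, 40000], [45, 80000], [35, 60000], [50, 120000], [23, 35000]]
y = ["No", "Yes", "Yes", "Yes", "No"]

clf = DecisionTreeClassifier(criterion="gini", max_depth=2, random_state=42)
clf.fit(X, y)

# Text rendering of the learned rules
print(export_text(clf, feature_names=["Age", "Income"]))

# The data is separated by a single split, so the tree stops at depth 1
print(clf.get_depth(), clf.get_n_leaves())

# Hypothetical new customers: a young low earner and an older high earner
new_people = [[22, 30000], [48, 90000]]
print(clf.predict(new_people))
```

Because Age and Income each separate this tiny dataset perfectly, the fitted tree may split on either feature. The two new samples were chosen so that the prediction is the same in both cases.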
Key Points to Remember
- The Gini Index helps decide the best split by minimizing impurity.
- A split with a Gini Index of 0 indicates perfect separation of classes.
- The process repeats recursively until a stopping condition is met, and each leaf node is then assigned a class label (for classification, the majority class of its samples).