Decision trees#

We explore a class of algorithms based on decision trees. Decision trees are extremely intuitive at their core. They encode a series of “if” and “else” choices, similar to how a person makes a decision. The data determines which questions to ask and how to proceed for each answer.

For example, to create a guide for identifying an animal found in nature, you might ask the following series of questions:

  • Is the animal bigger or smaller than a meter long?

    • bigger: does the animal have horns?

      • yes: are the horns longer than ten centimeters?

      • no: does the animal wear a collar?

    • smaller: does the animal have two or four legs?

      • two: does the animal have wings?

      • four: does the animal have a bushy tail?

And so on. This binary splitting of questions forms the essence of a decision tree.
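
Written as plain Python, such a guide is nothing more than a set of nested if/else statements. Here is a purely illustrative sketch (the rules and return values are made up):

def identify_animal(length_m, has_horns, wears_collar, n_legs, has_wings):
    """Toy version of the guide above, encoded as nested if/else rules."""
    if length_m > 1:  # bigger than a meter
        if has_horns:
            return "wild animal with horns"
        return "pet wearing a collar" if wears_collar else "large wild animal"
    # smaller than a meter
    if n_legs == 2:
        return "bird" if has_wings else "small two-legged animal"
    return "small four-legged animal"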

Tree-based models offer several key benefits. First, they require little preprocessing of the data. They work with variables of different types (continuous and discrete) and remain invariant to feature scaling.
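
For instance, fitting the same tree on raw and standardized copies of a dataset should give identical predictions. Here is a minimal sketch on synthetic data (the dataset and variable names such as X_demo are only for illustration):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

# Rescaling the features moves the split thresholds but not the partition itself,
# so the predictions are expected to be identical
X_demo, y_demo = make_classification(n_samples=300, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, random_state=0)
scaler = StandardScaler().fit(X_tr)
tree_raw = DecisionTreeClassifier(random_state=0).fit(X_tr, y_tr)
tree_scaled = DecisionTreeClassifier(random_state=0).fit(scaler.transform(X_tr), y_tr)
print(np.array_equal(tree_raw.predict(X_te), tree_scaled.predict(scaler.transform(X_te))))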

Tree-based models are also “nonparametric”, meaning they do not have a fixed set of parameters to learn. Instead, a tree model becomes more flexible with more data. In other words, the number of free parameters grows with the number of samples and is not fixed like in linear models.
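
A quick way to see this growth (a small check on synthetic data): fit unconstrained trees on subsets of increasing size and count the nodes.

from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# With no depth limit, the tree keeps adding nodes (its "parameters") as it sees more data
X_demo, y_demo = make_classification(n_samples=2000, n_informative=5, random_state=0)
for n_samples in (100, 500, 2000):
    tree_demo = DecisionTreeClassifier(random_state=0).fit(X_demo[:n_samples], y_demo[:n_samples])
    print(n_samples, "samples ->", tree_demo.tree_.node_count, "nodes")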

Decision tree for classification#

Let’s build some intuition about how a decision tree works on a very simple dataset.

# When using JupyterLite, uncomment and install the `skrub` package.
# %pip install skrub
import matplotlib.pyplot as plt
import skrub

skrub.patch_display()  # makes nice display for pandas tables
from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=100, centers=[[0, 0], [1, 1]], random_state=42)
import numpy as np
import pandas as pd

X = pd.DataFrame(X, columns=["Feature #0", "Feature #1"])
class_names = np.array(["class #0", "class #1"])
y = pd.Series(class_names[y], name="Classes").astype("category")
data = pd.concat([X, y], axis=1)
data.plot.scatter(
    x="Feature #0",
    y="Feature #1",
    c="Classes",
    s=50,
    edgecolor="black",
)
plt.show()

Now, we train a decision tree classifier on this dataset. We first split the data into training and testing sets.

from sklearn.model_selection import train_test_split

data_train, data_test, X_train, X_test, y_train, y_test = train_test_split(
    data, X, y, random_state=42
)
from sklearn.tree import DecisionTreeClassifier

tree = DecisionTreeClassifier(max_depth=1)
tree.fit(X_train, y_train)
pred = tree.predict(X_test)
pred
array(['class #1', 'class #0', 'class #0', 'class #0', 'class #1',
       'class #1', 'class #0', 'class #0', 'class #0', 'class #0',
       'class #0', 'class #1', 'class #0', 'class #0', 'class #0',
       'class #1', 'class #0', 'class #1', 'class #0', 'class #0',
       'class #0', 'class #0', 'class #0', 'class #0', 'class #1'],
      dtype=object)

We plot the decision boundaries found using the training data.

from sklearn.inspection import DecisionBoundaryDisplay

display = DecisionBoundaryDisplay.from_estimator(tree, X_train, alpha=0.7)
data_train.plot.scatter(
    x="Feature #0", y="Feature #1", c="Classes", s=50, edgecolor="black", ax=display.ax_
)
plt.show()

Similarly, we get the following classification on the testing set.

display = DecisionBoundaryDisplay.from_estimator(tree, X_test, alpha=0.7)
data_test.plot.scatter(
    x="Feature #0", y="Feature #1", c="Classes", s=50, edgecolor="black", ax=display.ax_
)
plt.show()

We see that the decision boundary found by the decision tree is a simple binary split.

We can also plot the tree structure.

from sklearn.tree import plot_tree

plot_tree(tree, feature_names=X.columns, class_names=class_names, filled=True)
plt.show()
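
The same split can also be read programmatically from the fitted tree. A small sketch (node 0 is the root and, with max_depth=1, the only internal node):

root_feature = tree.tree_.feature[0]
root_threshold = tree.tree_.threshold[0]
print(f"Root split: {X.columns[root_feature]!r} <= {root_threshold:.3f}")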

EXERCISE

  1. Modify the depth of the tree and observe how the partitioning evolves.

  2. What can you conclude about under- and over-fitting of the tree model?

  3. How would you choose the best depth?

Many parameters control the complexity of a tree, but maximum depth is perhaps the easiest to understand. This parameter limits how finely the tree can partition the input space, or how many “if-else” questions it can ask before deciding which class a sample belongs to.

This parameter is crucial to tune for trees and tree-based models. The interactive plot below shows how underfitting and overfitting look for this model. A max_depth of 1 clearly underfits the model, while a depth of 7 or 8 clearly overfits. The maximum possible depth for this dataset is 8, at which point each leaf contains samples from only a single class. We call these leaves “pure.”

In the interactive plot below, blue and red colors indicate the predicted class for each region. The shade of color indicates the predicted probability for that class (darker = higher probability), while yellow regions indicate equal predicted probability for either class.
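
In practice, a common way to pick the depth is cross-validation on the training set. Here is a minimal sketch (the grid of depths is an arbitrary choice):

from sklearn.model_selection import GridSearchCV

param_grid = {"max_depth": [1, 2, 3, 4, 5, 6, 7, 8]}
search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=5)
search.fit(X_train, y_train)
print("Best depth:", search.best_params_["max_depth"])
print("Test accuracy on held-out data:", search.score(X_test, y_test))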

Note about partitioning in decision trees#

In this section, we examine in more detail how a tree selects the best partition. First, we use a real dataset instead of synthetic data.

dataset = pd.read_csv("../datasets/penguins.csv")
dataset = dataset.dropna(subset=["Body Mass (g)"])
dataset.head()

We build a decision tree to classify penguin species using their body mass as a feature. To simplify the problem, we focus only on the Adelie and Gentoo species.

# Only select the column of interest
dataset = dataset[["Body Mass (g)", "Species"]]
# Make the species name more readable
dataset["Species"] = dataset["Species"].apply(lambda x: x.split()[0])
# Only select the Adelie and Gentoo penguins
dataset = dataset.set_index("Species").loc[["Adelie", "Gentoo"], :]
# Sort all penguins by their body mass
dataset = dataset.sort_values(by="Body Mass (g)")
# Convert the dataframe (2D) to a series (1D)
dataset = dataset.squeeze()
dataset
Species
Adelie    2850.0
Adelie    2850.0
Adelie    2900.0
Adelie    2900.0
Adelie    2900.0
           ...  
Gentoo    5950.0
Gentoo    6000.0
Gentoo    6000.0
Gentoo    6050.0
Gentoo    6300.0
Name: Body Mass (g), Length: 274, dtype: float64

First, we examine the body mass distribution for each species.

_, ax = plt.subplots()
dataset.groupby("Species").plot.hist(ax=ax, alpha=0.7, legend=True)
ax.set_ylabel("Frequency")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.show()

Instead of looking at the distribution, we can look at all samples directly.

# When using JupyterLite, uncomment and install the `seaborn` package.
# %pip install seaborn
import seaborn as sns

ax = sns.swarmplot(x=dataset.values, y=[""] * len(dataset), hue=dataset.index)
ax.set_xlabel(dataset.name)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.show()

When we build a tree, we want to find splits that partition the data into groups that are as “unmixed” as possible. Let’s make a first completely random split to highlight the principle.

# seed a random number generator so we all get the same results
rng = np.random.default_rng(42)
random_idx = rng.choice(dataset.size)

ax = sns.swarmplot(x=dataset.values, y=[""] * len(dataset), hue=dataset.index)
ax.set_xlabel(dataset.name)
ax.set_title(f"Body mass threshold: {dataset.iloc[random_idx]} grams")
ax.vlines(dataset.iloc[random_idx], -1, 1, color="red", linestyle="--")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.show()

After the split, we want two partitions where samples come from a single class as much as possible and contain as many samples as possible. Decision trees use a criterion to assess split quality. Entropy describes the class mixture in a partition. Let’s compute the entropy for the full dataset and the sets on each side of the split.

from scipy.stats import entropy

dataset.index.value_counts()

parent_entropy = entropy(dataset.index.value_counts(normalize=True))
parent_entropy

left_entropy = entropy(dataset[:random_idx].index.value_counts(normalize=True))
left_entropy

right_entropy = entropy(dataset[random_idx:].index.value_counts(normalize=True))
right_entropy
np.float64(0.6930191750980527)
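
For reference, scipy computes entropy with the natural logarithm by default: a pure partition has entropy 0, while a perfectly mixed (50/50) partition has entropy ln 2 ≈ 0.693. The value printed above therefore tells us that this partition is almost perfectly mixed. A quick check:

from scipy.stats import entropy

print(entropy([1.0, 0.0]))  # pure partition -> 0.0
print(entropy([0.5, 0.5]))  # 50/50 mixture -> ln(2) ≈ 0.6931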

We assess split quality by combining the entropies. This is called the information gain.

parent_entropy - (left_entropy + right_entropy)
np.float64(-0.005102495972248544)

However, we should normalize the entropies by the number of samples in each set.

def information_gain(labels_parent, labels_left, labels_right):
    # compute the entropies
    entropy_parent = entropy(labels_parent.value_counts(normalize=True))
    entropy_left = entropy(labels_left.value_counts(normalize=True))
    entropy_right = entropy(labels_right.value_counts(normalize=True))

    n_samples_parent = labels_parent.size
    n_samples_left = labels_left.size
    n_samples_right = labels_right.size

    # normalize with the number of samples
    normalized_entropy_left = (n_samples_left / n_samples_parent) * entropy_left
    normalized_entropy_right = (n_samples_right / n_samples_parent) * entropy_right

    return entropy_parent - normalized_entropy_left - normalized_entropy_right


information_gain(dataset.index, dataset[:random_idx].index, dataset[random_idx:].index)
np.float64(0.055599913525391065)

Now we compute the information gain for all possible body mass thresholds.

all_information_gain = pd.Series(
    [
        information_gain(dataset.index, dataset[:idx].index, dataset[idx:].index)
        for idx in range(dataset.size)
    ],
    index=dataset,
)
ax = all_information_gain.plot()
ax.set_ylabel("Information gain")
plt.show()
ax = (all_information_gain * -1).plot(color="red", label="Information gain")
ax = sns.swarmplot(x=dataset.values, y=[""] * len(dataset), hue=dataset.index)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.show()

The maximum information gain corresponds to the split that best partitions our data. Let’s check the corresponding body mass threshold.

all_information_gain.idxmax()

ax = (all_information_gain * -1).plot(color="red", label="Information gain")
ax = sns.swarmplot(x=dataset.values, y=[""] * len(dataset), hue=dataset.index)
ax.vlines(all_information_gain.idxmax(), -1, 1, color="red", linestyle="--")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

plt.show()
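
We can cross-check this threshold with scikit-learn: a depth-1 tree trained with the entropy criterion on the same single feature should choose a comparable split (a small sketch; tree_check is just a throwaway name):

tree_check = DecisionTreeClassifier(max_depth=1, criterion="entropy")
tree_check.fit(dataset.to_numpy().reshape(-1, 1), dataset.index)
# The root threshold should land close to the information-gain maximum above
print("scikit-learn threshold:", tree_check.tree_.threshold[0])
print("best threshold found above:", all_information_gain.idxmax())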

Decision tree for regression#

Decision trees also handle regression problems. We generate a noisy one-dimensional dataset to illustrate.

rnd = np.random.default_rng(42)
x = np.linspace(-3, 3, 100)
y_no_noise = np.sin(4 * x) + x
y = y_no_noise + rnd.normal(size=len(x))
X = x.reshape(-1, 1)
_, ax = plt.subplots()
ax.scatter(X, y, s=50)
ax.set(xlabel="Feature X", ylabel="Target y")
plt.show()
from sklearn.tree import DecisionTreeRegressor

reg = DecisionTreeRegressor(max_depth=2)
reg.fit(X, y)
DecisionTreeRegressor(max_depth=2)
X_test = np.linspace(-3, 3, 1000).reshape((-1, 1))
y_test = reg.predict(X_test)

_, ax = plt.subplots()
ax.plot(X_test.ravel(), y_test, color="tab:blue", label="prediction")
ax.plot(X.ravel(), y, "C7.", label="training data")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.show()

A single decision tree estimates the signal non-parametrically but has some issues. In some regions, the model shows high bias and underfits the data (seen in long flat lines that don’t follow data contours), while in other regions it shows high variance and overfits (seen in narrow spikes influenced by noise in single points).

EXERCISE

  1. Take the above example and repeat the training/testing by changing the tree depth (a starting sketch follows this exercise).

  2. What can you conclude?
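
A possible starting point for this exercise (a sketch; the chosen depths and the use of the noise-free signal as a reference are arbitrary):

from sklearn.metrics import mean_squared_error

for depth in (1, 2, 5, None):
    reg = DecisionTreeRegressor(max_depth=depth).fit(X, y)
    # Compare predictions against the noise-free signal: shallow trees miss its
    # shape (underfitting), while very deep trees chase the noise (overfitting)
    error = mean_squared_error(y_no_noise, reg.predict(X))
    print(f"max_depth={depth}: MSE vs noise-free signal = {error:.3f}")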

Other tree hyperparameters#

The max_depth hyperparameter controls overall tree complexity. This parameter works well when a tree is symmetric. However, trees are not guaranteed to be symmetric. In fact, optimal generalization might require some branches to grow deeper than others.

We build a dataset to illustrate this asymmetry. We generate a dataset with 2 subsets: one where the tree should find a clear separation and another where samples from both classes mix. This means a decision tree needs more splits to properly classify samples from the second subset than from the first subset.

from sklearn.datasets import make_blobs

feature_names = ["Feature #0", "Feature #1"]
target_name = "Class"

# Blobs that will be interlaced
X_1, y_1 = make_blobs(n_samples=300, centers=[[0, 0], [-1, -1]], random_state=0)
# Blobs that will be easily separated
X_2, y_2 = make_blobs(n_samples=300, centers=[[3, 6], [7, 0]], random_state=0)

X = np.concatenate([X_1, X_2], axis=0)
y = np.concatenate([y_1, y_2])
data = np.concatenate([X, y[:, np.newaxis]], axis=1)
data = pd.DataFrame(data, columns=feature_names + [target_name])
data[target_name] = data[target_name].astype(np.int64).astype("category")
data

_, ax = plt.subplots(figsize=(10, 8))
data.plot.scatter(
    x="Feature #0",
    y="Feature #1",
    c="Class",
    s=100,
    edgecolor="black",
    ax=ax,
)
ax.set_title("Synthetic dataset")
plt.show()

First, we train a shallow decision tree with max_depth=2. This depth should suffice to separate the easily separable blobs.

max_depth = 2
tree = DecisionTreeClassifier(max_depth=max_depth)
tree.fit(X, y)

_, ax = plt.subplots(figsize=(10, 8))
DecisionBoundaryDisplay.from_estimator(tree, X, cmap=plt.cm.RdBu, ax=ax)
data.plot.scatter(
    x="Feature #0",
    y="Feature #1",
    c="Class",
    s=100,
    edgecolor="black",
    ax=ax,
)
ax.set_title(f"Decision tree with max-depth of {max_depth}")
plt.show()

As expected, the red blob on top and the blue blob on the right separate perfectly. However, the tree still makes mistakes in the region where the two blobs mix: more splits are needed there. Let’s examine the tree structure.

_, ax = plt.subplots(figsize=(15, 8))
plot_tree(
    tree, feature_names=feature_names, class_names=class_names, filled=True, ax=ax
)
plt.show()

The right branch achieves perfect classification. Now we increase the depth to see how the tree grows.

max_depth = 6
tree = DecisionTreeClassifier(max_depth=max_depth)
tree.fit(X, y)
DecisionTreeClassifier(max_depth=6)
_, ax = plt.subplots(figsize=(10, 8))
DecisionBoundaryDisplay.from_estimator(tree, X, cmap=plt.cm.RdBu, ax=ax)
data.plot.scatter(
    x="Feature #0",
    y="Feature #1",
    c="Class",
    s=100,
    edgecolor="black",
    ax=ax,
)
ax.set_title(f"Decision tree with max-depth of {max_depth}")
plt.show()
_, ax = plt.subplots(figsize=(25, 15))
plot_tree(
    tree, feature_names=feature_names, class_names=class_names, filled=True, ax=ax
)
plt.show()

As expected, the left branch continues to grow while the right branch stops splitting. Setting max_depth cuts the tree horizontally at a specific level, whether or not a branch would benefit from growing further.

The hyperparameters min_samples_leaf, min_samples_split, max_leaf_nodes, and min_impurity_decrease allow asymmetric trees and apply constraints at the leaf or node level. Let’s examine the effect of min_samples_leaf.

min_samples_leaf = 20
tree = DecisionTreeClassifier(min_samples_leaf=min_samples_leaf)
tree.fit(X, y)

_, ax = plt.subplots(figsize=(10, 8))
DecisionBoundaryDisplay.from_estimator(tree, X, cmap=plt.cm.RdBu, ax=ax)
data.plot.scatter(
    x="Feature #0",
    y="Feature #1",
    c="Class",
    s=100,
    edgecolor="black",
    ax=ax,
)
ax.set_title(f"Decision tree with leaf having at least {min_samples_leaf} samples")
plt.show()
_, ax = plt.subplots(figsize=(15, 15))
plot_tree(
    tree, feature_names=feature_names, class_names=class_names, filled=True, ax=ax
)
plt.show()

This hyperparameter ensures that each leaf contains at least a minimum number of samples and blocks any split that would create smaller leaves. Such hyperparameters offer an alternative to the max_depth parameter.
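
To see the difference in practice, we can compare the shapes of the two trees on this dataset (a sketch; the specific values mirror the ones used above):

tree_depth = DecisionTreeClassifier(max_depth=2).fit(X, y)
tree_leaf = DecisionTreeClassifier(min_samples_leaf=20).fit(X, y)
# max_depth caps every branch at the same level, while min_samples_leaf lets the
# mixed region keep splitting as long as each leaf keeps at least 20 samples
for name, clf in [("max_depth=2", tree_depth), ("min_samples_leaf=20", tree_leaf)]:
    print(name, "->", clf.get_depth(), "levels,", clf.get_n_leaves(), "leaves")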