Visualize Decision Tree

The great advantage of the Decision Tree algorithm is that the model structure (the tree) is human-readable. In this notebook, we fit a Decision Tree model and visualize it with the matplotlib package.
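Readability is not limited to plots. scikit-learn can also print a fitted tree as indented if/else rules with `export_text`. The sketch below assumes the copy of Iris bundled with scikit-learn (`load_iris`) rather than the CSV used later, and an arbitrary `max_depth=2` to keep the output short.

```python
# Print a small fitted tree as text rules instead of drawing it.
# Assumption: scikit-learn's bundled Iris copy, not this notebook's CSV file.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
tree = DecisionTreeClassifier(max_depth=2, random_state=42)  # depth 2 is arbitrary
tree.fit(iris.data, iris.target)

# Each "|---" line is one test on a feature; each leaf line shows "class: <label>".
rules = export_text(tree, feature_names=list(iris.feature_names))
print(rules)
```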

# import packages
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

We load the sample Iris data set. It is a small data set with 150 samples: 4 features describing the petals and sepals of the flowers, and a 5th column assigning each sample to one of three classes.

# load example dataset
df = pd.read_csv(
    "https://raw.githubusercontent.com/pplonski/datasets-for-start/master/iris/data.csv",
    skipinitialspace=True,
)
# display first rows
df.head()
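The description above (150 samples, 4 features, 3 classes) can be checked in code. This sketch assumes the copy of Iris bundled with scikit-learn matches the CSV. Note that the bundled copy names its label column `target` and stores classes as integers, unlike the CSV's `class` column.

```python
# Check the size and class balance of the Iris data set.
# Assumption: scikit-learn's bundled copy stands in for the CSV file.
from sklearn.datasets import load_iris

frame = load_iris(as_frame=True).frame   # 4 feature columns + "target"
n_samples, n_columns = frame.shape
n_features = n_columns - 1               # the last column is the label
n_classes = frame["target"].nunique()
print(n_samples, n_features, n_classes)
print(frame["target"].value_counts().sort_index())  # samples per class
```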

We split the data set column-wise into X and y. The X variable holds the input features. The y variable is the target vector that the model learns and then predicts.

# create X columns list and set y column
x_cols = [
    "sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal widght (cm)",
]
y_col = "class"
# set input matrix
X = df[x_cols]
# set target vector
y = df[y_col]
# display data shapes
print(f"X shape is {X.shape}")
print(f"y shape is {y.shape}")
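Selecting columns by name keeps every row and divides the columns between X and y. An equivalent approach, which avoids listing every feature, is to drop the target column. The tiny frame below is hypothetical: its values and class labels are made up for illustration.

```python
# Column-wise split on a small, made-up DataFrame.
import pandas as pd

toy = pd.DataFrame(
    {
        "sepal length (cm)": [5.1, 7.0, 6.3],
        "petal length (cm)": [1.4, 4.7, 6.0],
        "class": ["setosa", "versicolor", "virginica"],  # hypothetical labels
    }
)
y_toy = toy["class"]                 # one column -> Series of shape (3,)
X_toy = toy.drop(columns=["class"])  # remaining columns -> DataFrame of shape (3, 2)
print(X_toy.shape, y_toy.shape)
```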

Let's train the Decision Tree. The first step is to create an object of the DecisionTreeClassifier class. Hyperparameters are set when the object is initialized.

# initialize Decision Tree
my_tree = DecisionTreeClassifier(criterion="gini", random_state=42)
# display model card
my_tree
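Besides `criterion` and `random_state`, other hyperparameters can be passed to the constructor in the same way. The sketch below uses `max_depth` and `min_samples_leaf`, which are real `DecisionTreeClassifier` parameters. The values 3, 5 and 4 are arbitrary choices for illustration, not tuned settings.

```python
# Set hyperparameters at construction time, then inspect or change them.
from sklearn.tree import DecisionTreeClassifier

shallow_tree = DecisionTreeClassifier(
    criterion="gini", max_depth=3, min_samples_leaf=5, random_state=42
)
params = shallow_tree.get_params()  # a plain dict snapshot of the settings
print(params["max_depth"], params["min_samples_leaf"])

# Hyperparameters can also be changed after construction, before fitting.
shallow_tree.set_params(max_depth=4)
```

Limiting `max_depth` also makes the plotted tree smaller and easier to read.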

The my_tree object needs to be fitted to the data. Here we use X and y.

# fit model
my_tree.fit(X, y)
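After fitting, the tree exposes its size and learned classes, and it can predict labels. This sketch assumes scikit-learn's bundled Iris copy, so the classes are the integers 0, 1 and 2 rather than the CSV's class names.

```python
# Inspect a fitted tree and use it for prediction.
# Assumption: scikit-learn's bundled Iris copy stands in for the CSV file.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X_demo, y_demo = load_iris(return_X_y=True)
fitted = DecisionTreeClassifier(criterion="gini", random_state=42)
fitted.fit(X_demo, y_demo)

print("depth:", fitted.get_depth())
print("leaves:", fitted.get_n_leaves())
print("classes:", fitted.classes_)

# Accuracy on the training data itself. An unrestricted tree usually fits it
# almost perfectly, so this is not an estimate of accuracy on new data.
train_acc = fitted.score(X_demo, y_demo)
pred = fitted.predict(X_demo[:3])
```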

Visualize Decision Tree

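The tree structure is drawn with the matplotlib library. Increase the figsize parameter if you would like a larger image. The max_depth=5 argument of plot_tree only limits how many levels are drawn; it does not change the fitted model.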

# create large empty figure
fig = plt.figure(figsize=(25, 20))
# plot tree
_ = plot_tree(
    my_tree,
    feature_names=X.columns.tolist(),
    class_names=np.unique(y).tolist(),
    max_depth=5,
    filled=True,
)
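For a report, the plot can also be drawn on explicit axes and saved instead of shown. This sketch makes three assumptions: it uses the non-interactive `Agg` backend, scikit-learn's bundled Iris copy, and an in-memory PNG buffer in place of a file. The figure size, depth and font size are arbitrary.

```python
# Draw the top of a tree on a chosen figure size and save it as PNG bytes.
import io

import matplotlib
matplotlib.use("Agg")  # no window needed; assumption for headless use
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(12, 8))  # width, height in inches
annotations = plot_tree(
    clf,
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    max_depth=2,   # draw only the top levels; deeper nodes are truncated
    filled=True,
    fontsize=10,
    ax=ax,
)
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
plt.close(fig)
png_bytes = buf.getvalue()
```

Passing `ax` lets the tree share a figure with other plots, and `fontsize` keeps labels legible when figsize is reduced.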

Conclusions

Decision Tree is a very useful algorithm. It can predict values for new samples, and its structure shows why a particular value was predicted. Decision Tree visualization is available in the scikit-learn library through the plot_tree function.
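To explain a single prediction, we can list the tests a sample passes on its way from the root to a leaf. This sketch uses the real `decision_path`, `apply` and `tree_` attributes of a fitted scikit-learn tree. It assumes the bundled Iris copy, and the sample index 100 is an arbitrary choice.

```python
# Explain one prediction by listing the feature tests on its decision path.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

sample = iris.data[100:101]                   # keep 2-D shape (1, n_features)
node_ids = clf.decision_path(sample).indices  # visited node ids, root to leaf
leaf_id = clf.apply(sample)[0]                # id of the leaf it ends in

steps = []
for node in node_ids:
    if node == leaf_id:
        continue                              # a leaf holds no test
    feat = clf.tree_.feature[node]
    thr = clf.tree_.threshold[node]
    value = sample[0, feat]
    op = "<=" if value <= thr else ">"
    steps.append(f"{iris.feature_names[feat]} = {value:.2f} {op} {thr:.2f}")

predicted = iris.target_names[clf.predict(sample)[0]]
print("\n".join(steps))
print("predicted:", predicted)
```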


Packages used in the python-visualize-decision-tree.ipynb

List of packages that need to be installed in your Python environment to run this notebook. Please note that MLJAR Studio automatically installs and imports required modules for you.

pandas>=1.0.0

scikit-learn>=1.5.0

matplotlib>=3.8.4