Visualize Decision Tree

A key advantage of the Decision Tree algorithm is that its structure is human-readable. In this notebook, we fit a Decision Tree model with Python's scikit-learn and visualize its structure with matplotlib.

This notebook was created with MLJAR Studio

MLJAR Studio is a Python code editor with interactive code recipes and a local AI assistant.
The code recipe UI is displayed at the top of code cells.

Documentation

At the top of the notebook is a list of the modules required to run the notebook's code. Please note that you don't have to copy the import code manually, because MLJAR Studio automatically imports the required packages when you use code recipes. Each code recipe has a small side note with the name of the cookbook it comes from, so you can easily navigate to the recipe.

# import packages
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

We load the sample Iris data set. It is a small data set with 150 samples: 4 feature columns describe the sepals and petals of each flower, and a 5th column assigns each sample to one of three classes.

# load example dataset
df = pd.read_csv(
    "https://raw.githubusercontent.com/pplonski/datasets-for-start/master/iris/data.csv",
    skipinitialspace=True,
)
# display first rows
df.head()
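If the GitHub URL is not reachable, the shape described above can be checked offline against the copy of Iris bundled with scikit-learn. This is a minimal sketch, assuming `sklearn.datasets.load_iris` as a stand-in source: the bundled copy names its target column `target` and stores integer class codes, so its columns are not the ones used in this notebook.

```python
from sklearn.datasets import load_iris

# Offline stand-in for the CSV: Iris as bundled with scikit-learn (assumption)
iris = load_iris(as_frame=True)
frame = iris.frame  # 4 feature columns + an integer "target" column

n_samples, n_columns = frame.shape
n_classes = frame["target"].nunique()
print(f"{n_samples} samples, {n_columns - 1} features, {n_classes} classes")
print(frame.head())
```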

We split the data set by columns into X and y. The X variable holds the input features. The y variable is the target vector that the model will learn and then predict.

# create X columns list and set y column
x_cols = [
    "sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal widght (cm)",
]
y_col = "class"
# set input matrix
X = df[x_cols]
# set target vector
y = df[y_col]
# display data shapes
print(f"X shape is {X.shape}")
print(f"y shape is {y.shape}")
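Listing the feature columns explicitly, as above, requires every name to match the CSV header character for character. One alternative, sketched here, is to take every column except the target with `DataFrame.drop`. The sketch runs on scikit-learn's bundled Iris copy (an assumption), so the target column is `target` rather than `class`.

```python
from sklearn.datasets import load_iris

frame = load_iris(as_frame=True).frame
target_col = "target"  # the notebook's CSV uses "class" instead

# every column except the target becomes an input feature
X = frame.drop(columns=[target_col])
y = frame[target_col]
print(f"X shape is {X.shape}")  # X shape is (150, 4)
print(f"y shape is {y.shape}")  # y shape is (150,)
```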

Let's train a Decision Tree. The first step is to create an object of the DecisionTreeClassifier class. Hyperparameters are set when the object is initialized.

# initialize Decision Tree
my_tree = DecisionTreeClassifier(criterion="gini", random_state=42)
# display model card
my_tree
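Besides `criterion` and `random_state`, `DecisionTreeClassifier` accepts other hyperparameters at initialization, for example `max_depth` and `min_samples_leaf`, which limit how large the tree grows and therefore how readable its plot is. The values below are illustrative, not tuned; `get_params()` returns the settings stored on the object.

```python
from sklearn.tree import DecisionTreeClassifier

# illustrative values, not tuned for Iris
small_tree = DecisionTreeClassifier(
    criterion="entropy",   # split quality measured by information gain
    max_depth=3,           # the tree is at most 3 levels deep
    min_samples_leaf=5,    # every leaf must contain at least 5 samples
    random_state=42,
)
params = small_tree.get_params()
print(params["criterion"], params["max_depth"], params["min_samples_leaf"])
```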

The my_tree object needs to be fitted to data. Here we use X and y.

# fit model
my_tree.fit(X, y)
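After fitting, the model can predict classes for rows, and the size of the fitted tree can be inspected with `get_depth()` and `get_n_leaves()`. The sketch below repeats the fit on scikit-learn's bundled Iris copy (an assumption, so classes are integer codes) and predicts the first three rows. It reports training accuracy only, which is an optimistic estimate because the same rows were used for fitting.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True, as_frame=True)
model = DecisionTreeClassifier(criterion="gini", random_state=42).fit(X, y)

# size of the fitted tree
print("depth:", model.get_depth(), "leaves:", model.get_n_leaves())

# predict classes for a few rows (here: rows already seen during fitting)
predictions = model.predict(X.head(3))
print("predictions:", predictions)

# accuracy on the training data itself, an optimistic estimate
train_accuracy = model.score(X, y)
print(f"training accuracy: {train_accuracy:.3f}")
```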

Visualize Decision Tree

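The tree structure is displayed with the matplotlib library. Increase the figsize parameter if you would like a larger image. The max_depth argument of plot_tree limits the depth of the drawing, so branches below that depth are shown truncated.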

# create large empty figure
fig = plt.figure(figsize=(25, 20))
# plot tree
_ = plot_tree(
    my_tree,
    feature_names=X.columns.tolist(),
    class_names=np.unique(y).tolist(),
    max_depth=5,
    filled=True,
)
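The same structure can also be printed as plain text, which is handy when a large figure is hard to read. This is a sketch with scikit-learn's `export_text` function, fitted on the bundled Iris copy as an assumption; its `max_depth` argument truncates the printed report, similar to `max_depth` in `plot_tree`.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris(as_frame=True)
X, y = iris.data, iris.target
model = DecisionTreeClassifier(criterion="gini", random_state=42).fit(X, y)

# each split is printed as a "<=" / ">" pair of lines; leaves show "class: ..."
report = export_text(
    model,
    feature_names=X.columns.tolist(),
    max_depth=3,
)
print(report)
```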

Conclusions

A Decision Tree can be used to predict classes for new samples. What is more, its structure explains why a value was predicted: each prediction follows a path of feature-threshold tests from the root to a leaf. Decision Tree visualization is available in the scikit-learn library through the plot_tree function.
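To see why a particular value was predicted, the path of one sample through the tree can be traced with `decision_path` and the arrays stored in the fitted `tree_` attribute. This is a minimal sketch on scikit-learn's bundled Iris copy (an assumption); row index 100 is an arbitrary choice.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris(as_frame=True)
X, y = iris.data, iris.target
model = DecisionTreeClassifier(criterion="gini", random_state=42).fit(X, y)

sample = X.iloc[[100]]                       # one row, kept as a DataFrame
node_indicator = model.decision_path(sample)  # sparse (1, n_nodes) indicator
leaf_id = model.apply(sample)[0]              # id of the leaf reached
feature = model.tree_.feature                 # feature index tested at each node
threshold = model.tree_.threshold             # threshold used at each node

# node ids visited by this sample, from the root down to the leaf
node_ids = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]

rules = []
for node_id in node_ids:
    if node_id == leaf_id:
        continue  # leaves have no test
    name = X.columns[feature[node_id]]
    value = sample.iloc[0, feature[node_id]]
    sign = "<=" if value <= threshold[node_id] else ">"
    rules.append(f"{name} = {value:.2f} {sign} {threshold[node_id]:.2f}")

prediction = model.predict(sample)[0]
print("\n".join(rules))
print("predicted class:", iris.target_names[prediction])
```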

Packages used in the python-visualize-decision-tree.ipynb

List of packages that need to be installed in your Python environment to run this notebook. Please note that MLJAR Studio automatically installs and imports required modules for you.

pandas>=1.0.0
scikit-learn>=1.5.0
matplotlib>=3.8.4