A great advantage of the Decision Tree algorithm is that the model structure (the tree) is human-readable. In this notebook, we will fit a Decision Tree model and visualize it using scikit-learn's plot_tree function together with the matplotlib package.

# import packages
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

We load the sample Iris data set. It is a small data set with 150 samples: 4 features describing the petals and sepals of the flowers, and a fifth column assigning each sample to one of three classes.

# load example dataset
df = pd.read_csv(
# display first rows
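
If the CSV file is not at hand, the same data can be loaded from scikit-learn's bundled copy of Iris. This is only a sketch under that assumption: the name `iris_df` and the mapping of integer codes to species names are illustrative choices, not part of the original notebook.

```python
import pandas as pd
from sklearn.datasets import load_iris

# Assumption: use scikit-learn's bundled Iris data instead of a CSV file
iris = load_iris(as_frame=True)

# iris.frame holds the four feature columns plus an integer "target" column
iris_df = iris.frame.rename(columns={"target": "class"})

# Map integer codes (0, 1, 2) to species names so "class" is human-readable
iris_df["class"] = iris_df["class"].map(dict(enumerate(iris.target_names)))

# Display the shape and the first rows
print(iris_df.shape)
print(iris_df.head())
```

The column names in this variant follow scikit-learn's spelling, so any list of feature columns would have to match them exactly.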

We split the data set by columns into X and y. The X variable holds the input features. The y variable is the target vector, which the model will learn and then predict.

# create X columns list (names must match the columns of df exactly) and set y column
x_cols = [
    "sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal widght (cm)",
]
y_col = "class"
# set input matrix
X = df[x_cols]
# set target vector
y = df[y_col]
# display data shapes
print(f"X shape is {X.shape}")
print(f"y shape is {y.shape}")

Let's train the Decision Tree. The first step is to create an object of the DecisionTreeClassifier class. We set the hyperparameters during object initialization.

# initialize Decision Tree
my_tree = DecisionTreeClassifier(criterion="gini", random_state=42)
# display model card
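
As a rough illustration of setting hyperparameters at initialization, the sketch below passes a few more constructor arguments and reads them back with `get_params`. The specific values (`max_depth=3`, `min_samples_leaf=5`) and the name `small_tree` are hypothetical and chosen only for the example.

```python
from sklearn.tree import DecisionTreeClassifier

# Hypothetical settings, chosen only for illustration
small_tree = DecisionTreeClassifier(
    criterion="entropy",   # split quality measure; "gini" is the default
    max_depth=3,           # limit tree depth to keep the plot readable
    min_samples_leaf=5,    # require at least 5 samples in each leaf
    random_state=42,       # make the result reproducible
)

# get_params returns the hyperparameters as a plain dict
params = small_tree.get_params()
print(params["criterion"], params["max_depth"], params["min_samples_leaf"])
```

A shallower tree is usually easier to read when plotted, at the cost of a coarser fit.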

The my_tree object needs to be fitted to data. Here we use the X and y data.

# fit model
my_tree.fit(X, y)
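
To see what fitting produces, here is a self-contained sketch that fits a tree with the same settings and then inspects it. It assumes scikit-learn's bundled Iris data; the names `X_demo`, `y_demo`, and `demo_tree` are illustrative. Because the score is computed on the training data itself, it says nothing about performance on unseen samples.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Assumption: bundled Iris data stands in for the notebook's X and y
X_demo, y_demo = load_iris(return_X_y=True, as_frame=True)

# Same settings as in the notebook
demo_tree = DecisionTreeClassifier(criterion="gini", random_state=42)
demo_tree.fit(X_demo, y_demo)

# Size of the fitted tree
print("depth:", demo_tree.get_depth())
print("leaves:", demo_tree.get_n_leaves())

# Accuracy on the training data (not a measure of generalization)
train_acc = demo_tree.score(X_demo, y_demo)
print("training accuracy:", train_acc)

# Predicted class codes for the first five samples
preds = demo_tree.predict(X_demo.head())
print(preds)
```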

Visualize Decision Tree

The tree structure is displayed using the matplotlib library. Change the figsize parameter if you would like a larger image.

# create large empty figure
fig = plt.figure(figsize=(25, 20))
# plot tree
_ = plot_tree(
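
For reference, a complete plotting call might look like the sketch below. It assumes scikit-learn's bundled Iris data and a freshly fitted classifier named `clf`; the arguments `feature_names`, `class_names`, and `filled` are one reasonable choice, not necessarily the ones used in this notebook.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Assumption: bundled Iris data and a freshly fitted tree for the demo
iris = load_iris()
clf = DecisionTreeClassifier(criterion="gini", random_state=42)
clf.fit(iris.data, iris.target)

# Create a large figure and draw the tree on its axes
fig, ax = plt.subplots(figsize=(25, 20))
annotations = plot_tree(
    clf,
    feature_names=list(iris.feature_names),  # label splits with feature names
    class_names=list(iris.target_names),     # label nodes with the majority class
    filled=True,                             # color nodes by majority class
    ax=ax,
)
```

Each box in the resulting figure shows the split condition, the impurity, the number of samples, and the class distribution at that node.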


The Decision Tree is a very useful algorithm. It can be used to predict new values, and the path from the root to a leaf shows why a given value was predicted. Decision Tree visualization is available in the scikit-learn library.
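
When a figure is not convenient, the same structure can be printed as text with scikit-learn's `export_text`. The sketch below assumes the bundled Iris data and caps the depth at 2 (a hypothetical setting) so that the printed rules stay short.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Assumption: bundled Iris data; max_depth=2 only keeps the output short
iris = load_iris()
clf = DecisionTreeClassifier(criterion="gini", max_depth=2, random_state=42)
clf.fit(iris.data, iris.target)

# Render the fitted tree as indented if/else style rules
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

Reading the rules from top to bottom for a given sample traces the conditions that lead to its predicted class.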

