Visualize Decision Tree
A key advantage of the Decision Tree algorithm is that its structure is human-readable. In this notebook, we fit a Decision Tree model with Python's scikit-learn and visualize it with matplotlib. The resulting plot shows the splits the model uses to make predictions.
MLJAR Studio is a Python code editor with interactive code recipes and a local AI assistant.
The code recipes UI is displayed at the top of code cells.
At the top of the notebook, there is a list of the modules required to run the notebook's code. Note that you don't have to copy the import code manually, because MLJAR Studio automatically imports the required packages when you use code recipes. Each code recipe has a small side note with the name of the cookbook used, so you can easily navigate to the recipe.
# import packages
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
We load the sample Iris data set. It is a small data set with 150 samples: four features describing the petals and sepals of the flowers, and a fifth column assigning each sample to one of three classes.
# load example dataset
df = pd.read_csv(
"https://raw.githubusercontent.com/pplonski/datasets-for-start/master/iris/data.csv",
skipinitialspace=True,
)
# display first rows
df.head()
We split the data set by columns into X and y. X holds the input features. y is the target vector that the model will learn and then predict.
# create X columns list and set y column
x_cols = [
"sepal length (cm)",
"sepal width (cm)",
"petal length (cm)",
"petal width (cm)",
]
y_col = "class"
# set input matrix
X = df[x_cols]
# set target vector
y = df[y_col]
# display data shapes
print(f"X shape is {X.shape}")
print(f"y shape is {y.shape}")
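If the CSV file cannot be downloaded, an equivalent X and y can be built offline. The sketch below assumes that the copy of Iris bundled with scikit-learn (`load_iris`) is an acceptable substitute for the file above. The `iris` variable is introduced here for illustration, and the class label strings come from scikit-learn, so they may differ from those in the CSV.

```python
import pandas as pd
from sklearn.datasets import load_iris

# load the copy of Iris bundled with scikit-learn (assumed substitute for the CSV)
iris = load_iris(as_frame=True)

# input matrix: four numeric feature columns
X = iris.data

# target vector: map the integer codes 0..2 to class names
y = pd.Series(iris.target_names[iris.target.to_numpy()], name="class")

# display data shapes
print(f"X shape is {X.shape}")
print(f"y shape is {y.shape}")
```

The rest of the notebook only needs X to be a DataFrame of features and y to be a vector of labels, so either loading route should work.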
Let's train the Decision Tree. The first step is to create an object of the DecisionTreeClassifier class. We set hyperparameters when initializing the object.
# initialize Decision Tree
my_tree = DecisionTreeClassifier(criterion="gini", random_state=42)
# display model card
my_tree
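To illustrate setting hyperparameters at initialization, here is a minimal sketch with a few commonly used ones. The specific values and the `small_tree` name are hypothetical choices for illustration, not settings used in this notebook.

```python
from sklearn.tree import DecisionTreeClassifier

# hypothetical settings chosen for illustration, not taken from the notebook
small_tree = DecisionTreeClassifier(
    criterion="entropy",  # split quality measure; "gini" is the default
    max_depth=3,          # stop growing after 3 levels of splits
    min_samples_leaf=5,   # every leaf must hold at least 5 training samples
    random_state=42,      # make tie-breaking between equal splits repeatable
)

# get_params() returns the hyperparameters the object was created with
params = small_tree.get_params()
print(params["criterion"], params["max_depth"], params["min_samples_leaf"])
```

Limiting `max_depth` or raising `min_samples_leaf` usually yields a smaller tree, which is also easier to read once plotted.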
The my_tree object needs to be fitted on data. Here we use the X and y data.
# fit model
my_tree.fit(X, y)
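After fitting, it can be useful to check how large the tree became before plotting it. The following is a sketch only: it assumes scikit-learn's bundled Iris data as a stand-in for the CSV, and the names `data`, `X_demo`, `y_demo` and `demo_tree` are introduced here for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# stand-in data: the Iris copy bundled with scikit-learn (an assumption)
data = load_iris(as_frame=True)
X_demo, y_demo = data.data, data.target

# same hyperparameters as in the notebook
demo_tree = DecisionTreeClassifier(criterion="gini", random_state=42)
demo_tree.fit(X_demo, y_demo)

# size of the fitted tree
depth = demo_tree.get_depth()
n_leaves = demo_tree.get_n_leaves()

# accuracy measured on the training data itself
train_accuracy = demo_tree.score(X_demo, y_demo)

print(f"depth: {depth}, leaves: {n_leaves}, training accuracy: {train_accuracy:.3f}")
```

Because the accuracy is computed on the same data used for fitting, it is an optimistic estimate and says little about performance on unseen samples.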
Visualize Decision Tree
The tree structure is displayed using the matplotlib library. Change the figsize parameter if you would like a larger image.
# create large empty figure
fig = plt.figure(figsize=(25, 20))
# plot tree
_ = plot_tree(
my_tree,
feature_names=X.columns.tolist(),
class_names=np.unique(y).tolist(),
max_depth=5,
filled=True,
)
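When an image is inconvenient, for example in a terminal or a log file, scikit-learn can also render the same kind of structure as plain text with `export_text`. This is a sketch under stated assumptions: it uses scikit-learn's bundled Iris data instead of the CSV and a deliberately shallow tree (`max_depth=2`) to keep the output short; `data`, `text_tree` and `rules` are illustrative names.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# stand-in data: the Iris copy bundled with scikit-learn (an assumption)
data = load_iris(as_frame=True)

# shallow tree so that the printed rules stay short
text_tree = DecisionTreeClassifier(criterion="gini", max_depth=2, random_state=42)
text_tree.fit(data.data, data.target)

# one line per node; indentation reflects the depth in the tree
rules = export_text(text_tree, feature_names=data.data.columns.tolist())
print(rules)
```

Each printed line is either a test on a feature or a leaf with the predicted class, so the text carries the same split information as the plot without the colors.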
Conclusions
Decision Tree is a very useful algorithm. It can be used to predict new values and, what is more, its tree structure lets you explain why a given value was predicted. Decision Tree visualization is available in the scikit-learn library.
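As a rough illustration of predicting a new value and explaining it, the sketch below classifies one flower and lists the tests it passed on the way from the root to a leaf. It assumes scikit-learn's bundled Iris data rather than the CSV; the measurements in `new_flower` are a hypothetical sample, and all variable names are introduced here for illustration.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# stand-in data: the Iris copy bundled with scikit-learn (an assumption)
data = load_iris(as_frame=True)
X_demo = data.data
y_demo = pd.Series(data.target_names[data.target.to_numpy()], name="class")

demo_tree = DecisionTreeClassifier(criterion="gini", random_state=42)
demo_tree.fit(X_demo, y_demo)

# one hypothetical new flower, with the same columns as the training data
new_flower = pd.DataFrame([[5.1, 3.5, 1.4, 0.2]], columns=X_demo.columns)
predicted = demo_tree.predict(new_flower)[0]
print(f"predicted class: {predicted}")

# ids of the nodes the sample visits, from the root down to a leaf
node_ids = demo_tree.decision_path(new_flower).indices

explanation = []
for node in node_ids:
    feature_idx = demo_tree.tree_.feature[node]
    if feature_idx < 0:  # a negative index marks a leaf, which has no test
        continue
    name = X_demo.columns[feature_idx]
    threshold = demo_tree.tree_.threshold[node]
    value = new_flower.iloc[0, feature_idx]
    sign = "<=" if value <= threshold else ">"
    explanation.append(f"{name} = {value} {sign} {threshold:.2f}")

print("\n".join(explanation))
```

Each line of the explanation corresponds to one internal node on the path, which is the same route you would follow by eye in the plotted tree.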
Packages used in the python-visualize-decision-tree.ipynb
List of packages that need to be installed in your Python environment to run this notebook. Please note that MLJAR Studio automatically installs and imports required modules for you.
pandas>=1.0.0
scikit-learn>=1.5.0