Extract Rules from Decision Tree in 3 Ways with Scikit-Learn and Python
Extracting rules from a Decision Tree helps you understand how samples propagate through the tree during prediction. It is also needed if you want to re-implement a Decision Tree without Scikit-learn, or in a language other than Python. Decision Trees are easy to port to any programming language because, in the end, they are just a set of if-else statements. I've seen many examples of moving scikit-learn Decision Trees into C, C++, Java, or even SQL.
In this post, I will show you 3 ways to get decision rules from a Decision Tree (for both classification and regression tasks) with the following approaches:
- built-in text representation,
- converting a Decision Tree to code (in any programming language),
- converting a Decision Tree to a set of human-readable rules (my favourite approach).
If you would like to visualize your Decision Tree model, then you should see my article Visualize a Decision Tree in 5 Ways with Scikit-Learn and Python.
If you want to train a Decision Tree and other ML algorithms (Random Forest, Neural Networks, Xgboost, CatBoost, LightGBM) in an automated way, you should check our open-source AutoML Python package on GitHub: mljar-supervised.
Train Decision Tree on Classification Task
Let's train a DecisionTreeClassifier on the iris dataset. I will use the default hyper-parameters for the classifier, except for max_depth=3 (I don't want trees that are too deep, for readability reasons).
from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
# Prepare the data
iris = datasets.load_iris()
X = iris.data
y = iris.target
# Fit the classifier with max_depth=3
clf = DecisionTreeClassifier(max_depth=3, random_state=1234)
model = clf.fit(X, y)
Scikit-Learn Built-in Text Representation
Scikit-Learn provides the export_text() function in the sklearn.tree module. It returns the text representation of the decision rules.
# get the text representation
text_representation = tree.export_text(clf)
print(text_representation)
The output:
|--- feature_2 <= 2.45
| |--- class: 0
|--- feature_2 > 2.45
| |--- feature_3 <= 1.75
| | |--- feature_2 <= 4.95
| | | |--- class: 1
| | |--- feature_2 > 4.95
| | | |--- class: 2
| |--- feature_3 > 1.75
| | |--- feature_2 <= 4.85
| | | |--- class: 2
| | |--- feature_2 > 4.85
| | | |--- class: 2
You can pass the feature names as an argument to get a more readable text representation:
text_representation = tree.export_text(clf, feature_names=iris.feature_names)
print(text_representation)
The output, with our feature names instead of the generic feature_0, feature_1, ...:
|--- petal length (cm) <= 2.45
| |--- class: 0
|--- petal length (cm) > 2.45
| |--- petal width (cm) <= 1.75
| | |--- petal length (cm) <= 4.95
| | | |--- class: 1
| | |--- petal length (cm) > 4.95
| | | |--- class: 2
| |--- petal width (cm) > 1.75
| | |--- petal length (cm) <= 4.85
| | | |--- class: 2
| | |--- petal length (cm) > 4.85
| | | |--- class: 2
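As a side note, export_text() also accepts a few optional arguments. For example, show_weights=True adds the (weighted) number of samples from each class at every leaf; a minimal tweak of the call above:
# text representation with class counts shown at the leaves
text_representation = tree.export_text(
    clf, feature_names=iris.feature_names, show_weights=True
)
print(text_representation)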
Extract Code Rules
There isn't any built-in method for extracting if-else code rules from a Scikit-Learn tree, so we need to write it ourselves. The code below is based on a StackOverflow answer, updated to Python 3.
import numpy as np
from sklearn.tree import _tree
def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    # map every node to its feature name (leaves get "undefined!")
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    # clean the names used in the function signature,
    # e.g. "petal length (cm)" -> "petal_length"
    feature_names = [f.replace(" ", "_")[:-5] for f in feature_names]
    print("def predict({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # internal node: print the split condition and visit both children
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, np.round(threshold, 2)))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, np.round(threshold, 2)))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # leaf node: print the value stored in the leaf
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)
The above code recursively walks through the nodes in the tree and prints out the decision rules. The rules are presented as a Python function. The predict() code below was generated with tree_to_code(). You can easily adapt the above code to produce decision rules in any programming language; one possible variant is sketched after the generated code below.
def predict(sepal_length, sepal_width, petal_length, petal_width):
if petal length (cm) <= 2.45:
return [[50. 0. 0.]]
else: # if petal length (cm) > 2.45
if petal width (cm) <= 1.75:
if petal length (cm) <= 4.95:
return [[ 0. 47. 1.]]
else: # if petal length (cm) > 4.95
return [[0. 2. 4.]]
else: # if petal width (cm) > 1.75
if petal length (cm) <= 4.85:
return [[0. 1. 2.]]
else: # if petal length (cm) > 4.85
return [[ 0. 0. 43.]]
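For example, here is a minimal sketch of such an adaptation. The tree_to_python() name and the feature-name clean-up are my own illustrative choices (the clean-up is tailored to the iris names): it returns the generated code as a string and uses sanitized identifiers inside the rule bodies, so the resulting predict() function is valid, runnable Python.
def tree_to_python(tree, feature_names):
    # sanitize feature names into valid Python identifiers,
    # e.g. "petal length (cm)" -> "petal_length" (illustrative, tuned for iris)
    clean_names = [
        "_".join(f.replace("(cm)", "").split()) for f in feature_names
    ]
    tree_ = tree.tree_
    lines = ["def predict({}):".format(", ".join(clean_names))]

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # internal node: emit the split condition and visit both children
            name = clean_names[tree_.feature[node]]
            threshold = np.round(tree_.threshold[node], 2)
            lines.append(f"{indent}if {name} <= {threshold}:")
            recurse(tree_.children_left[node], depth + 1)
            lines.append(f"{indent}else:  # if {name} > {threshold}")
            recurse(tree_.children_right[node], depth + 1)
        else:
            # leaf node: emit the class counts stored in the leaf
            lines.append(f"{indent}return {tree_.value[node].tolist()}")

    recurse(0, 1)
    return "\n".join(lines)

print(tree_to_python(clf, iris.feature_names))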
Extract Human-Readable Rules
The code rules from the previous example are more computer-friendly than human-friendly. Let's update the code to obtain easy-to-read text rules.
def get_rules(tree, feature_names, class_names):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
paths = []
path = []
def recurse(node, path, paths):
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
p1, p2 = list(path), list(path)
p1 += [f"({name} <= {np.round(threshold, 3)})"]
recurse(tree_.children_left[node], p1, paths)
p2 += [f"({name} > {np.round(threshold, 3)})"]
recurse(tree_.children_right[node], p2, paths)
else:
path += [(tree_.value[node], tree_.n_node_samples[node])]
paths += [path]
recurse(0, path, paths)
# sort by samples count
samples_count = [p[-1][1] for p in paths]
ii = list(np.argsort(samples_count))
paths = [paths[i] for i in reversed(ii)]
rules = []
for path in paths:
rule = "if "
for p in path[:-1]:
if rule != "if ":
rule += " and "
rule += str(p)
rule += " then "
if class_names is None:
rule += "response: "+str(np.round(path[-1][0][0][0],3))
else:
classes = path[-1][0][0]
l = np.argmax(classes)
rule += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
rule += f" | based on {path[-1][1]:,} samples"
rules += [rule]
return rules
Run the function with the clf classifier:
rules = get_rules(clf, iris.feature_names, iris.target_names)
for r in rules:
print(r)
The output produced by get_rules():
if (petal length (cm) <= 2.45) then class: setosa (proba: 100.0%) | based on 50 samples
if (petal length (cm) > 2.45) and (petal width (cm) <= 1.75) and (petal length (cm) <= 4.95) then class: versicolor (proba: 97.92%) | based on 48 samples
if (petal length (cm) > 2.45) and (petal width (cm) > 1.75) and (petal length (cm) > 4.85) then class: virginica (proba: 100.0%) | based on 43 samples
if (petal length (cm) > 2.45) and (petal width (cm) <= 1.75) and (petal length (cm) > 4.95) then class: virginica (proba: 66.67%) | based on 6 samples
if (petal length (cm) > 2.45) and (petal width (cm) > 1.75) and (petal length (cm) <= 4.85) then class: virginica (proba: 66.67%) | based on 3 samples
The rules are sorted by the number of training samples assigned to each rule. For each rule, there is information about the predicted class name and the probability of the prediction.
Extract Rules in Regression Task
Let's check the rules for a DecisionTreeRegressor. I will use the boston dataset to train the model, again with max_depth=3.
from sklearn import datasets
from sklearn.tree import DecisionTreeRegressor
from sklearn import tree
# Prepare the data
boston = datasets.load_boston()
X = boston.data
y = boston.target
# Fit the regressor, set max_depth = 3
regr = DecisionTreeRegressor(max_depth=3, random_state=1234)
model = regr.fit(X, y)
# Print rules
rules = get_rules(regr, boston.feature_names, None)
for r in rules:
print(r)
The printed rules:
if (RM <= 6.941) and (LSTAT <= 14.4) and (DIS > 1.385) then response: 22.905 | based on 250 samples
if (RM <= 6.941) and (LSTAT > 14.4) and (CRIM <= 6.992) then response: 17.138 | based on 101 samples
if (RM <= 6.941) and (LSTAT > 14.4) and (CRIM > 6.992) then response: 11.978 | based on 74 samples
if (RM > 6.941) and (RM <= 7.437) and (NOX <= 0.659) then response: 33.349 | based on 43 samples
if (RM > 6.941) and (RM > 7.437) and (PTRATIO <= 19.65) then response: 45.897 | based on 29 samples
if (RM <= 6.941) and (LSTAT <= 14.4) and (DIS <= 1.385) then response: 45.58 | based on 5 samples
if (RM > 6.941) and (RM <= 7.437) and (NOX > 0.659) then response: 14.4 | based on 3 samples
if (RM > 6.941) and (RM > 7.437) and (PTRATIO > 19.65) then response: 21.9 | based on 1 samples
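A note if you run this on a recent Scikit-Learn release: load_boston() was deprecated and has been removed (since version 1.2), so the snippet above will fail there. A similar regression example can be built, for instance, with the California housing dataset:
from sklearn.datasets import fetch_california_housing
from sklearn.tree import DecisionTreeRegressor

# alternative regression example for Scikit-Learn versions without load_boston
housing = fetch_california_housing()
X = housing.data
y = housing.target

regr = DecisionTreeRegressor(max_depth=3, random_state=1234)
model = regr.fit(X, y)

rules = get_rules(regr, housing.feature_names, None)
for r in rules:
    print(r)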
Summary
There are many ways to present a Decision Tree. It can be visualized as a graph or converted to a text representation. In the MLJAR AutoML we are using the dtreeviz visualization and a text representation in a human-friendly format. If you would like to train a Decision Tree (or other ML algorithms), you can try MLJAR AutoML: https://github.com/mljar/mljar-supervised.