Save and load Decision Tree
Scikit-learn provides Decision Tree algorithms for classification (DecisionTreeClassifier) and regression (DecisionTreeRegressor). We'll train a classifier on the Iris dataset and save it using pickle. Then we'll load the model and use it for predictions, checking that the trained and loaded models give the same results.
MLJAR Studio is a Python code editor with interactive code recipes and a local AI assistant.
The code recipes UI is displayed at the top of code cells.
Usually, the first code cell imports the required packages. Please note that MLJAR Studio imports all packages automatically, so you don't have to copy and paste any code.
# import packages
import pandas as pd
import pickle
from sklearn.tree import DecisionTreeClassifier
Load sample data
Let's load the Iris dataset into a pandas DataFrame from the URL https://github.com/pplonski/datasets-for-start
# load example dataset
df = pd.read_csv(
    "https://raw.githubusercontent.com/pplonski/datasets-for-start/master/iris/data.csv",
    skipinitialspace=True,
)
# display first rows
df.head()
Select X and y
We need to select the training features for the input matrix X and the target vector y.
# create X columns list and set y column
x_cols = [
    "sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal widght (cm)",
]
y_col = "class"
# set input matrix
X = df[x_cols]
# set target vector
y = df[y_col]
# display data shapes
print(f"X shape is {X.shape}")
print(f"y shape is {y.shape}")
Create Decision Tree object
The first step is to create a Decision Tree model object. In this step, we can set hyperparameters.
# initialize Decision Tree
my_tree = DecisionTreeClassifier(criterion="gini", random_state=42)
# display model card
my_tree
Fit Decision Tree model
The model is trained with the fit() method. Please note that the output box with the model card changes color from orange (unfitted) to blue (fitted).
# fit model
my_tree.fit(X, y)
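The color change is a visual cue inside the notebook. Outside of it, the fitted state can also be checked in code. Below is a minimal sketch, assuming scikit-learn's check_is_fitted helper; the names is_fitted, demo_tree, X_demo and y_demo are illustrative, and the Iris copy bundled with scikit-learn stands in for the CSV file so the snippet runs offline.

```python
from sklearn.datasets import load_iris
from sklearn.exceptions import NotFittedError
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import check_is_fitted


def is_fitted(model):
    """Return True if the estimator has been fitted, False otherwise."""
    try:
        check_is_fitted(model)
        return True
    except NotFittedError:
        return False


# stand-in data: the Iris copy bundled with scikit-learn (assumption)
X_demo, y_demo = load_iris(return_X_y=True)

demo_tree = DecisionTreeClassifier(criterion="gini", random_state=42)
fitted_before = is_fitted(demo_tree)

demo_tree.fit(X_demo, y_demo)
fitted_after = is_fitted(demo_tree)

print(f"Fitted before fit(): {fitted_before}")
print(f"Fitted after fit(): {fitted_after}")
```

A check like this can be useful before saving, so that an untrained model is not written to disk by mistake.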
Save Decision Tree to pickle
The pickle module can save most Python objects to disk. Let's use it to save our Decision Tree model.
# save object to pickle file
with open(r"decision-tree-model.pickle", "wb") as fout:
    pickle.dump(my_tree, fout)
print("Object my_tree saved at decision-tree-model.pickle")
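A pickled scikit-learn model is generally expected to be loaded with the same scikit-learn version that created it. One possible way to make a mismatch visible is to store the version next to the model. The sketch below is an assumption, not part of the original notebook: the payload dictionary, its keys and the file name are hypothetical, and the bundled Iris data stands in for the CSV file.

```python
import pickle
import tempfile
from pathlib import Path

import sklearn
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# stand-in data: the Iris copy bundled with scikit-learn (assumption)
X_demo, y_demo = load_iris(return_X_y=True)
demo_tree = DecisionTreeClassifier(criterion="gini", random_state=42)
demo_tree.fit(X_demo, y_demo)

# bundle the model with the library version used for training
payload = {"model": demo_tree, "sklearn_version": sklearn.__version__}

# hypothetical file name inside a temporary directory
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = Path(tmp_dir) / "decision-tree-with-version.pickle"
    with open(model_path, "wb") as fout:
        pickle.dump(payload, fout, protocol=pickle.HIGHEST_PROTOCOL)
    with open(model_path, "rb") as fin:
        restored = pickle.load(fin)

# compare the stored version with the currently installed one
version_matches = restored["sklearn_version"] == sklearn.__version__
if not version_matches:
    print("Warning: model was saved with a different scikit-learn version")
print(f"Saved with scikit-learn {restored['sklearn_version']}")
```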
Load Decision Tree from pickle
Let's load the model from the pickle file. Please note that we give the loaded object a different name. Now we have two objects with Decision Tree models: my_tree and tree_loaded.
# open pickle file and load
with open(r"decision-tree-model.pickle", "rb") as fin:
    tree_loaded = pickle.load(fin)
# display loaded object
print(tree_loaded)
Compute predictions and compare models
Let's compute predictions from the first model (my_tree) and then from the loaded model (tree_loaded).
# compute predictions with the original model
predicted = my_tree.predict(X)
print("Predictions")
print(predicted)
# predict class probabilities
predicted_proba = my_tree.predict_proba(X)
print("Predicted class probabilities")
print(predicted_proba)
# compute predictions with the loaded model
predicted_from_loaded = tree_loaded.predict(X)
print("Predictions")
print(predicted_from_loaded)
# predict class probabilities
predicted_from_loaded_proba = tree_loaded.predict_proba(X)
print("Predicted class probabilities")
print(predicted_from_loaded_proba)
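Comparing two long printed arrays by eye is error-prone, so the comparison can also be done programmatically. A minimal sketch, assuming NumPy's array_equal and an in-memory pickle round trip (dumps() followed by loads()) in place of the file; the variable names and the bundled Iris data are illustrative stand-ins.

```python
import pickle

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# stand-in data: the Iris copy bundled with scikit-learn (assumption)
X_demo, y_demo = load_iris(return_X_y=True)
original = DecisionTreeClassifier(criterion="gini", random_state=42)
original.fit(X_demo, y_demo)

# in-memory round trip, serializing and deserializing without a file
restored = pickle.loads(pickle.dumps(original))

# compare predicted labels and class probabilities element by element
same_labels = np.array_equal(original.predict(X_demo), restored.predict(X_demo))
same_probabilities = np.array_equal(
    original.predict_proba(X_demo), restored.predict_proba(X_demo)
)

print(f"Same labels: {same_labels}")
print(f"Same probabilities: {same_probabilities}")
```

The same two checks can be applied to the notebook's own objects by passing the outputs of my_tree and tree_loaded to np.array_equal.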
Conclusions
Saving and loading Decision Tree models from the scikit-learn library is straightforward. The pickle module provides the dump() and load() functions. You might want to save a Decision Tree to use it in production: the model file is loaded on the prediction server, and predictions are computed on new data.
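The production scenario can be sketched as two sides: one that trains and saves the model, and one that only loads the file and predicts. This is an illustration under assumptions, not a deployment recipe: the load_model helper is hypothetical, a temporary directory stands in for the server's storage, and the bundled Iris data stands in for the CSV file.

```python
import pickle
import tempfile
from pathlib import Path

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def load_model(path):
    """Load a pickled model. Only unpickle files from a trusted source,
    because loading a pickle can execute arbitrary code."""
    with open(path, "rb") as fin:
        return pickle.load(fin)


# stand-in data: the Iris copy bundled with scikit-learn (assumption)
iris = load_iris()

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = Path(tmp_dir) / "decision-tree-model.pickle"

    # training side: fit the model and write it to disk
    trained = DecisionTreeClassifier(criterion="gini", random_state=42)
    trained.fit(iris.data, iris.target)
    with open(model_path, "wb") as fout:
        pickle.dump(trained, fout)

    # prediction side: only the file is needed to restore the model
    served = load_model(model_path)

# one new flower: sepal length, sepal width, petal length, petal width (cm)
new_rows = [[5.1, 3.5, 1.4, 0.2]]
predicted_label = served.predict(new_rows)[0]
predicted_name = iris.target_names[predicted_label]
print(f"Predicted class: {predicted_name}")
```

Note that new data must have the same features, in the same order, as the data used for training.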
Packages used in the decision-tree-save-and-load.ipynb
List of packages that need to be installed in your Python environment to run this notebook. Please note that MLJAR Studio automatically installs and imports required modules for you.
pandas>=1.0.0
scikit-learn>=1.5.0