Train a Decision Tree on the Iris data set

Python is a great choice for Machine Learning projects because of its rich ecosystem of ML packages. The scikit-learn package provides an implementation of the Decision Tree algorithm. Let's train a Decision Tree classifier on the Iris data set.

This notebook was created with MLJAR Studio

MLJAR Studio is a Python code editor with interactive code recipes and a local AI assistant.
The code recipes UI is displayed at the top of the code cells.

Documentation

The data set consists of 50 samples from each of three species of Iris (Iris setosa, Iris virginica and Iris versicolor). There are 4 features measured for each sample: the length and the width of the sepals and petals.
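The same data set also ships with scikit-learn (`sklearn.datasets.load_iris`), which makes it easy to check the description above without downloading the CSV. This is a minimal sketch. It assumes the bundled copy contains the same measurements as the CSV used later in this notebook, although its layout differs: the target is an integer column named `target`, mapped to species names here.

```python
from sklearn.datasets import load_iris

# load the copy of Iris bundled with scikit-learn as a pandas DataFrame
iris = load_iris(as_frame=True)
frame = iris.frame  # 4 feature columns + integer "target" column

# map integer targets (0, 1, 2) to species names for readability
frame["species"] = frame["target"].map(dict(enumerate(iris.target_names)))

print(frame.shape)                      # rows x columns (incl. target and species)
print(frame["species"].value_counts())  # samples per species
print(list(iris.feature_names))         # the 4 measured features
```

Each species should appear 50 times, and the four feature names cover the length and width of the sepals and petals.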

Contents:

  1. Load Iris data set.
  2. Split data into train and test subsets.
  3. Select X and y for model training.
  4. Create Decision Tree model.
  5. Fit Decision Tree classifier using train data.
  6. Compute predictions on test data.
  7. Compute prediction accuracy by comparing ground truth labels with predicted labels.

Please note that MLJAR Studio automatically imports required packages for you :)

# import packages
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# load example dataset
df = pd.read_csv(
    "https://raw.githubusercontent.com/pplonski/datasets-for-start/master/iris/data.csv",
    skipinitialspace=True,
)
# display first rows
df.head()

The loaded data set is split into training and testing subsets. We use 75% of the data for training and the remaining 25% for testing.

# split data
train, test = train_test_split(df, train_size=0.75, shuffle=True, random_state=42)
# display data shapes
print(f"All data shape {df.shape}")
print(f"Train shape {train.shape}")
print(f"Test shape {test.shape}")
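With 150 rows and `train_size=0.75`, scikit-learn computes 0.75 * 150 = 112.5 and rounds it down, so 112 rows go to training and the remaining 38 go to testing. The sketch below checks this on the copy of Iris bundled with scikit-learn, which stands in for the CSV here. It also shows the optional `stratify` argument, which this notebook does not use. Stratifying keeps the three species in nearly equal proportions in both subsets.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# bundled Iris copy used as a stand-in for the CSV loaded above
data = load_iris(as_frame=True).frame

# same call as in the notebook: 75% train, shuffled, fixed seed
train_plain, test_plain = train_test_split(
    data, train_size=0.75, shuffle=True, random_state=42
)

# optional: stratify on the target so each species keeps its share
train_strat, test_strat = train_test_split(
    data, train_size=0.75, shuffle=True, random_state=42,
    stratify=data["target"],
)

print(train_plain.shape, test_plain.shape)
print("test class counts without stratify:")
print(test_plain["target"].value_counts().sort_index())
print("test class counts with stratify:")
print(test_strat["target"].value_counts().sort_index())
```

Without `stratify`, the class counts in the 38-row test set depend on the random shuffle. With it, each species gets 12 or 13 test rows.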

Select X (input features) and y (target) for the train and test data sets.

# create X columns list and set y column
x_cols = [
    "sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal widght (cm)",
]
y_col = "class"
# set input matrix
X_train = train[x_cols]
# set target vector
y_train = train[y_col]
# display data shapes
print(f"X_train shape is {X_train.shape}")
print(f"y_train shape is {y_train.shape}")
# reuse x_cols and y_col defined above for the test set
# set input matrix
X_test = test[x_cols]
# set target vector
y_test = test[y_col]
# display data shapes
print(f"X_test shape is {X_test.shape}")
print(f"y_test shape is {y_test.shape}")
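Listing feature names by hand is error-prone, because `df[list_of_names]` raises a `KeyError` if any name is misspelled. One alternative is to take every column except the target with `DataFrame.drop`. The sketch below applies it to a small hypothetical DataFrame; the values and class labels are made up for illustration. It assumes that every non-target column is a feature. That appears to hold for this Iris CSV, but not for data with ID or date columns.

```python
import pandas as pd

# hypothetical mini table with an Iris-like layout (values are illustrative)
frame = pd.DataFrame({
    "sepal length (cm)": [5.1, 7.0, 6.3],
    "sepal width (cm)": [3.5, 3.2, 3.3],
    "petal length (cm)": [1.4, 4.7, 6.0],
    "petal width (cm)": [0.2, 1.4, 2.5],
    "class": ["setosa", "versicolor", "virginica"],
})

y_col = "class"
# every column except the target becomes an input feature
X = frame.drop(columns=[y_col])
y = frame[y_col]

print(list(X.columns))
print(X.shape, y.shape)
```

Because the feature list is derived from the data, the same code works unchanged for the train and test subsets.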

Initialize the Decision Tree model, fit it on the training data, and compute predictions on the test data.

# initialize Decision Tree
my_tree = DecisionTreeClassifier(criterion="gini", random_state=42)
# display model card
my_tree
# fit model
my_tree.fit(X_train, y_train)
# compute prediction
predicted = my_tree.predict(X_test)
print("Predictions")
print(predicted)
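The model above uses `criterion="gini"`. When growing the tree, scikit-learn picks at each node the split that most reduces the Gini impurity G = 1 - sum(p_k^2), weighted over the child nodes, where p_k is the share of class k among the node's samples. A perfectly balanced three-class node, such as the root of the full Iris data (50/50/50), has G = 1 - 3 * (1/3)^2 = 2/3. A pure node has G = 0.

The sketch below computes this by hand and compares it with the root impurity stored in a fitted tree (`tree_.impurity`). It then prints the learned rules with `sklearn.tree.export_text`. It fits on the full bundled Iris copy rather than on the notebook's training split, so the rules may differ slightly from `my_tree`.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text


def gini(class_counts):
    """Gini impurity 1 - sum(p_k^2) for a node with the given class counts."""
    counts = np.asarray(class_counts, dtype=float)
    proportions = counts / counts.sum()
    return 1.0 - np.sum(proportions ** 2)


# root node of the full Iris data: 50 samples of each species
print(gini([50, 50, 50]))  # approximately 0.667
# a pure node (single class) has zero impurity
print(gini([50, 0, 0]))

iris = load_iris()
tree = DecisionTreeClassifier(criterion="gini", random_state=42)
tree.fit(iris.data, iris.target)

# impurity of node 0 (the root) as stored by scikit-learn
print(tree.tree_.impurity[0])

# human-readable if/else rules of the fitted tree, first two levels only
print(export_text(tree, feature_names=list(iris.feature_names), max_depth=2))
```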

# predict class probabilities
predicted_proba = my_tree.predict_proba(X_test)
print("Predicted class probabilities")
print(predicted_proba)
# compute metric
metric_accuracy = accuracy_score(y_test, predicted)
print(f"Accuracy: {metric_accuracy}")
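`predict_proba` returns one column per class, in the order of the model's `classes_` attribute. Each row sums to 1, and `predict` returns the class with the highest probability. Accuracy is the fraction of test samples whose predicted label equals the true label. The sketch below checks these relationships, and adds a confusion matrix to show which classes get mixed up. It assumes the bundled Iris copy and the same 75/25 split settings, so its numbers may differ from the CSV-based run above.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, train_size=0.75, shuffle=True, random_state=42
)
model = DecisionTreeClassifier(criterion="gini", random_state=42).fit(X_tr, y_tr)

proba = model.predict_proba(X_te)  # shape (n_test, n_classes)
pred = model.predict(X_te)

# each row of probabilities sums to 1
print(np.allclose(proba.sum(axis=1), 1.0))
# predict() picks the column with the highest probability
print(np.array_equal(pred, model.classes_[proba.argmax(axis=1)]))

# accuracy = share of matching labels, same as accuracy_score
manual_accuracy = np.mean(pred == y_te)
print(manual_accuracy, accuracy_score(y_te, pred))

# rows: true class, columns: predicted class; off-diagonal cells are errors
print(confusion_matrix(y_te, pred))
```

For a fully grown tree, most leaves are pure, so the predicted probabilities are usually 0 or 1 rather than graded confidences.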

Conclusions

The accuracy of the trained Decision Tree is 1.0, which means that all test samples are correctly classified. Such a good result is possible because the Iris data set is small and easy: its classes are mostly separable using the petal measurements alone. In real-life projects with complex data, the accuracy will usually be below 1.0. What is more, an accuracy of 1.0 often means that something is wrong. A common cause is a data leak, where information about the test samples reaches the model during training and makes an overfitted model look perfect on the test set. What is a good accuracy in Machine Learning? It depends on the use case and data complexity. In some cases, for example with financial data, an accuracy of 0.51 can be acceptable. On the other hand, in a medical context even 0.9 accuracy may not be enough.
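One way to judge whether a perfect score on a single split is trustworthy is to compare it against a trivial baseline and against cross-validation. It also helps to look at the gap between training accuracy and cross-validated accuracy, since a large gap points to overfitting. The sketch below uses the bundled Iris copy, scikit-learn's `DummyClassifier` as the baseline, and 5-fold `cross_val_score`. The fold count and the baseline strategy are illustrative choices, not part of the original notebook.

```python
from sklearn.datasets import load_iris
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# baseline: always predict the most frequent class seen in training
baseline = DummyClassifier(strategy="most_frequent")
baseline_scores = cross_val_score(baseline, X, y, cv=5, scoring="accuracy")

# the same tree settings as in the notebook, evaluated on 5 folds
tree = DecisionTreeClassifier(criterion="gini", random_state=42)
tree_scores = cross_val_score(tree, X, y, cv=5, scoring="accuracy")

print(f"baseline accuracy: {baseline_scores.mean():.3f}")
print(f"tree accuracy per fold: {tree_scores}")
print(f"tree accuracy: {tree_scores.mean():.3f} +/- {tree_scores.std():.3f}")

# a fully grown tree usually fits its own training data (almost) perfectly,
# so compare this with the cross-validated score to gauge overfitting
train_accuracy = tree.fit(X, y).score(X, y)
print(f"train accuracy: {train_accuracy:.3f}")
```

A model is only useful if it clearly beats the baseline. The spread across folds shows how much a single train/test split can move the score.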


Packages used in the train-decision-tree-iris-dataset.ipynb

List of packages that need to be installed in your Python environment to run this notebook. Please note that MLJAR Studio automatically installs and imports required modules for you.

pandas>=1.0.0

scikit-learn>=1.5.0