Next Tech

Classification and Regression Analysis with Decision Trees

Lorraine · Originally published at next.tech · 8 min read

A decision tree is a supervised machine learning model used to predict a target by learning decision rules from features. As the name suggests, we can think of this model as breaking down our data by making a decision based on asking a series of questions.

Let's consider the following example in which we use a decision tree to decide upon an activity on a particular day:

Based on the features in our training set, the decision tree model learns a series of questions to infer the class labels of the samples. Because every prediction can be traced back through these human-readable questions, decision trees are attractive models if we care about interpretability.

Although the preceding figure illustrates the concept of a decision tree with categorical features and a categorical target (classification), the same concept applies when the features are real numbers, and also when the target is a real number rather than a class label (regression).

In this tutorial, we will discuss how to build a decision tree model with Python’s scikit-learn library. We will cover:

  • The fundamental concepts of decision trees
  • The mathematics behind the decision tree learning algorithm
  • Information gain and impurity measures
  • Classification trees
  • Regression trees

Let’s get started!

This tutorial is adapted from Next Tech’s Python Machine Learning series which takes you through machine learning and deep learning algorithms with Python from 0 to 100. It includes an in-browser sandboxed environment with all the necessary software and libraries pre-installed, and projects using public datasets. You can get started here!

The Fundamentals of Decision Trees

A decision tree is constructed by recursive partitioning: starting from the root node (the first parent), each node can be split into left and right child nodes. These child nodes can in turn be split, becoming parent nodes of their own children.

For example, looking at the image above, the root node is Work to do? and splits into the child nodes Stay in and Outlook based on whether or not there is work to do. The Outlook node further splits into three child nodes.

So, how do we know what the optimal splitting point is at each node?

Starting from the root, the data is split on the feature that results in the largest Information Gain (IG) (explained in more detail below). In an iterative process, we then repeat this splitting procedure at each child node until the leaves are pure, i.e. all samples at each leaf belong to the same class.

In practice, this can result in a very deep tree with many nodes, which can easily lead to overfitting. Thus, we typically want to prune the tree by setting a limit for the maximal depth of the tree.
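To see what a depth limit does, the sketch below grows one tree with no depth limit and another with max_depth=2, then compares their depth and number of leaves. It assumes scikit-learn 0.21 or later (for get_depth and get_n_leaves) and borrows the Iris petal features used later in this tutorial; the value 2 is an arbitrary choice for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Iris petal length and petal width (columns 2 and 3) as features
X, y = load_iris(return_X_y=True)
X = X[:, [2, 3]]

# No depth limit: keep splitting until leaves are pure (or cannot be split further)
full_tree = DecisionTreeClassifier(criterion='entropy', random_state=1).fit(X, y)

# Depth-limited (pre-pruned) tree; max_depth=2 is an illustrative choice
pruned_tree = DecisionTreeClassifier(criterion='entropy', max_depth=2,
                                     random_state=1).fit(X, y)

for name, model in [('unlimited', full_tree), ('max_depth=2', pruned_tree)]:
    print(f'{name}: depth={model.get_depth()}, leaves={model.get_n_leaves()}, '
          f'training accuracy={model.score(X, y):.3f}')
```

The unlimited tree typically ends up deeper with many small leaves that fit the training data closely, which is exactly the overfitting risk described above.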

Maximizing Information Gain

In order to split the nodes at the most informative features, we need to define an objective function that we want to optimize via the tree learning algorithm. Here, our objective function is to maximize the information gain at each split, which we define as follows:

$$IG(D_p, f) = I(D_p) - \frac{N_{left}}{N_p} I(D_{left}) - \frac{N_{right}}{N_p} I(D_{right})$$

Here, f is the feature on which to perform the split; D_p, D_left, and D_right are the datasets of the parent and child nodes; I is the impurity measure; N_p is the total number of samples at the parent node; and N_left and N_right are the numbers of samples in the child nodes.

We will discuss impurity measures for classification and regression decision trees in more detail in our examples below. But for now, just understand that information gain is simply the difference between the impurity of the parent node and the weighted sum of the child node impurities, where each child is weighted by its share of the parent's samples. The lower the impurity of the child nodes, the larger the information gain.

Note that the above equation is for binary decision trees, in which each parent node is split into exactly two child nodes. For a multiway split into more than two child nodes, you would simply subtract the weighted impurity of every child node in the same way.
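The information gain formula translates almost directly into code. The following is a minimal NumPy sketch, assuming entropy (defined in the Classification Trees section below) as the impurity measure I; the helper names entropy_impurity and information_gain are our own, not part of any library. It accepts any number of child nodes, so it also covers the multiway case.

```python
import numpy as np

def entropy_impurity(labels):
    """Entropy of a node, computed from the class labels of its samples."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()                  # class proportions, all > 0 here
    # -sum p*log2(p) written as sum p*log2(1/p) to avoid returning -0.0
    return float(np.sum(p * np.log2(1 / p)))

def information_gain(parent, children, impurity=entropy_impurity):
    """IG = I(D_p) - sum_j (N_j / N_p) * I(D_j) over all child nodes D_j."""
    n_parent = len(parent)
    weighted_children = sum(len(child) / n_parent * impurity(child)
                            for child in children)
    return impurity(parent) - weighted_children

parent = np.array([0, 0, 0, 0, 1, 1, 1, 1])
perfect = [parent[:4], parent[4:]]                      # each child is pure
useless = [parent[[0, 1, 4, 5]], parent[[2, 3, 6, 7]]]  # children mirror the parent

print(information_gain(parent, perfect))   # 1.0
print(information_gain(parent, useless))   # 0.0
```

A split that leaves the class mix unchanged gains nothing, while a split that produces pure children recovers the full parent entropy.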

Classification Trees

We will start by talking about classification decision trees (also known as classification trees). For this example, we will be using the Iris dataset, a classic in the field of machine learning. It contains the measurements of 150 Iris flowers from three different species: Setosa, Versicolor, and Virginica. These species will be our targets, and our goal is to predict which species an Iris flower belongs to. Each flower is described by four measurements in centimeters (sepal length, sepal width, petal length, and petal width); for simplicity, we will use only the petal length and width as the features of our model.

Let’s first import the dataset and assign the features as X and the target as y:

from sklearn import datasets

iris = datasets.load_iris()                        # Load iris dataset

X = iris.data[:, [2, 3]]                           # Assign matrix X (petal length, petal width)
y = iris.target                                    # Assign vector y

Using scikit-learn, we will now train a decision tree with a maximum depth of 4. The code is as follows:

from sklearn.tree import DecisionTreeClassifier    # Import decision tree classifier model

tree = DecisionTreeClassifier(criterion='entropy', # Initialize and fit classifier
    max_depth=4, random_state=1)
tree.fit(X, y)

Notice that we set criterion='entropy'. The criterion is the impurity measure, denoted I in the previous section. Entropy and Gini impurity (scikit-learn's default) are the most commonly used impurity measures for classification. Entropy is defined as:

$$I_H(t) = -\sum_{i=1}^{c} p(i|t) \log_2 p(i|t)$$

Here, p(i|t) is the proportion of the samples at a particular node t that belong to class i, and c is the number of classes (terms with p(i|t) = 0 are taken to be 0). The entropy is therefore 0 if all samples at a node belong to the same class, and it is maximal if we have a uniform class distribution.
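As a quick numerical check of these two extremes, the sketch below implements the entropy formula for a vector of class proportions and cross-checks it against SciPy's scipy.stats.entropy (assuming SciPy is installed, as it is alongside scikit-learn). The helper name node_entropy is our own.

```python
import numpy as np
from scipy.stats import entropy as scipy_entropy

def node_entropy(proportions):
    """I_H(t) = -sum_i p(i|t) log2 p(i|t), with 0 * log2(0) taken as 0."""
    p = np.asarray(proportions, dtype=float)
    p = p[p > 0]                                # drop empty classes (0 * log 0 := 0)
    return float(np.sum(p * np.log2(1 / p)))    # same as -sum p*log2(p)

print(node_entropy([1.0, 0.0, 0.0]))            # pure node: 0.0
print(node_entropy([1/3, 1/3, 1/3]))            # uniform over 3 classes: log2(3) ≈ 1.585
print(scipy_entropy([1/3, 1/3, 1/3], base=2))   # cross-check with SciPy
```

With three equally sized classes, as at the root of the Iris tree (50 samples per species), the entropy reaches its maximum of log2(3).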

For a more visual understanding of entropy, let’s plot the entropy of a two-class node as a function of p(i=1|t), the proportion of samples at the node that belong to class 1, over the range [0, 1]. The code is as follows:

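import numpy as np
import matplotlib.pyplot as plt

def entropy(p):
    # Entropy of a two-class node, where p is the proportion of class 1
    return - p * np.log2(p) - (1 - p) * np.log2(1 - p)

x = np.arange(0.0, 1.0, 0.01)                          # Proportions p(i=1|t) from 0 to 0.99
e = [entropy(p) if p != 0 else np.nan for p in x]      # Calculate entropy (log2(0) is undefined, so skip p = 0)

plt.plot(x, e, label='entropy', color='r')             # Plot impurity indices
for level in [0.5, 1.0]:                               # Reference lines (renamed so the target y is not overwritten)
    plt.axhline(y=level, linewidth=1,
                color='k', linestyle='--')
plt.xlabel('p(i=1|t)')
plt.ylabel('Impurity Index')
plt.legend()
plt.show()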

As you can see, entropy approaches 0 as p(i=1|t) approaches 0 or 1, i.e. as the node becomes pure. If the classes are distributed uniformly with p(i=1|t) = 0.5, entropy reaches its maximum of 1.

Now, returning to our Iris example, we will visualize our trained classification tree and see how entropy decides each split.

A nice feature of scikit-learn is that it can export a trained decision tree in Graphviz's .dot format, which we can then render with GraphViz. In addition, we will use pydotplus, a Python interface to Graphviz's dot language, to convert the .dot data into a decision tree image file.

You can install pydotplus and Graphviz by executing the following commands in your terminal (the second command assumes a Debian- or Ubuntu-based system):

pip3 install pydotplus
sudo apt install graphviz

The following code will create an image of our decision tree in PNG format:

from pydotplus.graphviz import graph_from_dot_data
from sklearn.tree import export_graphviz

dot_data = export_graphviz(                           # Create dot data as a string
    tree, out_file=None, filled=True, rounded=True,
    class_names=['Setosa', 'Versicolor', 'Virginica'],
    feature_names=['petal length', 'petal width'])

graph = graph_from_dot_data(dot_data)                 # Create graph from dot data
graph.write_png('tree.png')                           # Write graph to PNG image
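If installing Graphviz is not an option, scikit-learn also provides sklearn.tree.export_text (available from version 0.21, which we assume here), which prints the same splits as indented text. A minimal sketch that refits the same Iris tree so it runs on its own:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

X, y = load_iris(return_X_y=True)
X = X[:, [2, 3]]                                   # petal length, petal width

tree = DecisionTreeClassifier(criterion='entropy', max_depth=4, random_state=1)
tree.fit(X, y)

# Text rendering of the learned rules: one line per split or leaf
report = export_text(tree, feature_names=['petal length', 'petal width'])
print(report)
```

Each leaf line reports the predicted class index (0 = Setosa, 1 = Versicolor, 2 = Virginica), so the printout can be read alongside tree.png.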

Looking at the resulting decision tree figure saved in the image file tree.png, we can now nicely trace back the splits that the decision tree determined from our training dataset. We started with 150 samples at the root and split them into two child nodes with 50 and 100 samples using a threshold on one of the petal measurements. After the first split, we can see that the left child node is already pure and only contains samples from the Setosa class (entropy = 0). The further splits on the right are then used to separate the samples from the Versicolor and Virginica classes.

Looking at the entropy of the final (leaf) nodes, we see that the decision tree with a depth of 4 does a very good job of separating the flower classes, at least on the training data it was fit to.
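To check the leaf entropies numerically instead of reading them off the figure, the fitted estimator exposes its low-level tree structure through the tree_ attribute. The sketch below refits the same model so it runs standalone, then lists each leaf's sample count and entropy; leaves are the nodes whose children_left entry is -1. These are training-set numbers, not an estimate of performance on new flowers.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X = X[:, [2, 3]]                                   # petal length, petal width
tree = DecisionTreeClassifier(criterion='entropy', max_depth=4,
                              random_state=1).fit(X, y)

structure = tree.tree_
is_leaf = structure.children_left == -1            # leaves have no children
for node_id in np.flatnonzero(is_leaf):
    print(f'leaf {node_id}: samples={structure.n_node_samples[node_id]}, '
          f'entropy={structure.impurity[node_id]:.3f}')

print('training accuracy:', tree.score(X, y))
```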

Regression Trees

We will be using the Boston Housing dataset for our regression example. This is another very popular dataset which contains information about houses in the suburbs of Boston. There are 506 samples and 14 attributes (13 features plus the target). For simplicity and visualization purposes, we will only use two: MEDV (median value of owner-occupied homes in $1000s) as the target and LSTAT (percentage of lower status of the population) as the feature.

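Let’s first load the dataset and store the LSTAT feature and the MEDV target in a pandas DataFrame. Note that datasets.load_boston was deprecated in scikit-learn 1.0 and removed in version 1.2, so the code below reads the original data file instead, following the approach given in scikit-learn's deprecation notice (assuming that URL is still reachable). On older scikit-learn versions, boston = datasets.load_boston() provides the same boston.data and boston.target arrays.

import numpy as np
import pandas as pd

data_url = 'http://lib.stat.cmu.edu/datasets/boston'
raw_df = pd.read_csv(data_url, sep=r'\s+', skiprows=22, header=None)  # Each record spans two lines
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])    # 13 feature columns
target = raw_df.values[1::2, 2]                                        # Target MEDV

df = pd.DataFrame(data[:, 12], columns=['LSTAT'])   # Create DataFrame using only the LSTAT feature
df['MEDV'] = target                                 # Create new column with the target MEDV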

Let’s use the DecisionTreeRegressor implemented in scikit-learn to train a regression tree:

from sklearn.tree import DecisionTreeRegressor    # Import decision tree regression model

X = df[['LSTAT']].values                          # Assign matrix X
y = df['MEDV'].values                             # Assign vector y

sort_idx = X.flatten().argsort()                  # Sort X and y by ascending values of X
X = X[sort_idx]
y = y[sort_idx]

tree = DecisionTreeRegressor(criterion='squared_error',  # Initialize and fit regressor of depth 3
                             max_depth=3)                 # (use criterion='mse' on scikit-learn < 1.0)
tree.fit(X, y)

Notice that our criterion is different from the one we used for our classification tree. Entropy is a useful impurity measure for classification, but to use a decision tree for regression we need an impurity metric that is suitable for continuous targets. We therefore use the mean squared error (MSE) of a node as its impurity, so that the information gain compares the parent's MSE with the weighted MSE of the child nodes:

$$I(t) = MSE(t) = \frac{1}{N_t} \sum_{i \in D_t} \left(y^{(i)} - \hat{y}_t\right)^2$$

Here, N_t is the number of training samples at node t, D_t is the training subset at node t, y^(i) is the true target value, and ŷ_t is the predicted target value (the sample mean of the targets at node t):

$$\hat{y}_t = \frac{1}{N_t} \sum_{i \in D_t} y^{(i)}$$
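To make the regression criterion concrete, here is a from-scratch sketch of how a single split could be chosen on one feature: try the midpoint between every pair of consecutive distinct x values and keep the threshold with the smallest weighted child MSE (equivalently, the largest information gain). The helpers mse_impurity and best_mse_split are our own illustration, not scikit-learn's implementation, and the toy data are made up; we compare the result with the root threshold of a depth-1 DecisionTreeRegressor.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def mse_impurity(y):
    """I(t) = (1/N_t) * sum (y_i - y_hat_t)^2, where y_hat_t is the node mean."""
    return float(np.mean((y - y.mean()) ** 2))

def best_mse_split(x, y):
    """Return (threshold, weighted child MSE) minimising the weighted MSE."""
    order = np.argsort(x)
    x, y = x[order], y[order]
    best = (None, np.inf)
    for i in range(1, len(x)):
        if x[i] == x[i - 1]:
            continue                               # no threshold between equal values
        threshold = (x[i] + x[i - 1]) / 2
        left, right = y[:i], y[i:]
        weighted = (len(left) * mse_impurity(left)
                    + len(right) * mse_impurity(right)) / len(y)
        if weighted < best[1]:
            best = (threshold, weighted)
    return best

# Toy data (made up for illustration): a noisy step between x = 4 and x = 5
x = np.arange(10, dtype=float)
y = np.array([1.0, 1.2, 0.9, 1.1, 1.0, 5.0, 5.1, 4.9, 5.2, 5.0])

threshold, weighted_mse = best_mse_split(x, y)
stump = DecisionTreeRegressor(max_depth=1).fit(x.reshape(-1, 1), y)
print(threshold, stump.tree_.threshold[0])         # both 4.5
```

A full regression tree simply repeats this search over every feature at every node, recursively, until the depth limit is reached.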

Now, let’s plot the relationship between MEDV and LSTAT to see what the fit of a regression tree looks like:

import matplotlib.pyplot as plt

plt.figure(figsize=(16, 8))
plt.scatter(X, y, c='steelblue',                  # Plot actual target against features
            edgecolor='white', s=70)
plt.plot(X, tree.predict(X),                      # Plot predicted target against features
         color='black', lw=2)
plt.xlabel('% lower status of the population [LSTAT]')
plt.ylabel('Price in $1000s [MEDV]')
plt.show()

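As we can see in the resulting plot, the decision tree of depth 3 captures the general trend in the data. The fitted line is a step function, because every sample that lands in the same leaf receives the same prediction: the mean target value of that leaf.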
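Because a tree of depth 3 has at most 2^3 = 8 leaves, its predictions can take at most 8 distinct values. A quick check on synthetic data (made up here so the snippet does not depend on the Boston data file):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = np.sort(rng.uniform(0, 40, size=200)).reshape(-1, 1)        # synthetic feature
y = 50 * np.exp(-X.ravel() / 10) + rng.normal(0, 2, size=200)   # noisy decreasing target

tree = DecisionTreeRegressor(max_depth=3).fit(X, y)
predictions = tree.predict(X)

print('leaves:', tree.get_n_leaves())
print('distinct predicted values:', len(np.unique(predictions)))  # at most 8
```

Increasing max_depth adds more, narrower steps, which follows the data more closely but also starts to fit the noise.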

I hope you enjoyed this tutorial on decision trees! We discussed the fundamental concepts of decision trees, the algorithms for minimizing impurity, and how to build decision trees for both classification and regression.

In practice, it is important to know how to choose an appropriate tree depth so that the model neither overfits nor underfits the data. It is also useful to know how to combine many decision trees into an ensemble called a random forest, which usually generalizes better than an individual decision tree: the randomness used to build the individual trees, together with averaging their predictions, decreases the model's variance. Random forests are also less sensitive to outliers in the dataset and don't require much parameter tuning.
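As one possible starting point (our suggestion, not the only approach), the sketch below picks max_depth by cross-validated accuracy with scikit-learn's GridSearchCV on the Iris petal features, then scores a RandomForestClassifier the same way for comparison. The candidate depths, 5 folds, and 100 trees are arbitrary illustrative choices.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X = X[:, [2, 3]]                                    # petal length, petal width

# Pick max_depth by 5-fold cross-validated accuracy (None = no depth limit)
search = GridSearchCV(DecisionTreeClassifier(criterion='entropy', random_state=1),
                      param_grid={'max_depth': [1, 2, 3, 4, 5, None]}, cv=5)
search.fit(X, y)
print('best max_depth:', search.best_params_['max_depth'],
      'CV accuracy:', round(search.best_score_, 3))

# An ensemble of randomized trees, scored the same way for comparison
forest = RandomForestClassifier(n_estimators=100, random_state=1)
forest_scores = cross_val_score(forest, X, y, cv=5)
print('random forest CV accuracy:', round(forest_scores.mean(), 3))
```

On a small, well-separated dataset like this the two approaches may score similarly; the variance reduction of a forest tends to matter more on noisier data.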

We cover these techniques in our Python Machine Learning series, as well as diving into other machine learning models such as perceptrons, Adaline, linear and polynomial regression, logistic regression, SVMs, kernel SVMs, k-nearest-neighbors, models for sentiment analysis, k-means clustering, DBSCAN, convolutional neural networks, and recurrent neural networks.

We also look at other topics such as regularization, data processing, feature selection and extraction, dimensionality reduction, model evaluation, ensemble learning techniques, and deploying a machine learning model.

You can get started here!

Posted by Lorraine, Data Scientist @ Next Tech

