DEV Community


Simple DecisionTreeClassifier in Python

Davide Santangelo (daviducolo) ・ 2 min read

Decision Trees (DTs) are a non-parametric supervised learning method used for classification and regression. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features.
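To make "learning decision rules" concrete: at each node, the tree searches for the feature and threshold that best separate the classes. For DecisionTreeClassifier the default criterion is Gini impurity. The sketch below is a simplified illustration, not scikit-learn's implementation. It scans a single feature, tries the midpoints between consecutive distinct values, and keeps the threshold with the lowest weighted impurity. The helper names `gini` and `best_split` and the toy data are hypothetical.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum(p_k ** 2) over the class proportions p_k."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_split(values, labels):
    """Hypothetical helper: best 'value <= threshold' split on one feature.

    Returns (threshold, weighted_impurity). Candidate thresholds are the
    midpoints between consecutive distinct sorted values.
    """
    n = len(labels)
    distinct = sorted(set(values))
    best = (None, float("inf"))
    for lo, hi in zip(distinct, distinct[1:]):
        threshold = (lo + hi) / 2
        left = [y for x, y in zip(values, labels) if x <= threshold]
        right = [y for x, y in zip(values, labels) if x > threshold]
        # impurity of each child, weighted by the fraction of samples it receives
        score = len(left) / n * gini(left) + len(right) / n * gini(right)
        if score < best[1]:
            best = (threshold, score)
    return best


# Hypothetical toy data: one feature (age) and one label per sample
ages = [18, 19, 22, 25, 40, 45]
genres = ["rap", "rap", "rap", "hip hop", "country", "country"]

threshold, impurity = best_split(ages, genres)
print(threshold)  # 23.5
```

scikit-learn repeats a search of this kind over every feature at every node. This is why the thresholds in the exported tree later in this post (20.50, 26.50, 37.00) fall halfway between ages that actually occur in the data.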

For instance, in the example below, a decision tree learns from data to predict a person's preferred music genre from their age and sex. The deeper the tree, the more complex its decision rules and the more closely it fits the training data. Past a certain depth, this closer fit becomes overfitting rather than better predictions.
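The sketch below shows the depth trade-off described above. It refits this post's dataset with scikit-learn's `max_depth` parameter and prints the resulting depth and accuracy. The accuracy is measured on the training data itself, so it shows how closely the tree fits these 18 samples, not how well it generalizes. `random_state=0` is an added assumption to make each run repeatable.

```python
from sklearn.tree import DecisionTreeClassifier

# Same training data as the main example: [age, sex], sex 0 = male, 1 = female
features = [
    [18, 0], [19, 0], [22, 0], [25, 0], [28, 0], [31, 0], [34, 0], [40, 0], [45, 0],
    [18, 1], [19, 1], [22, 1], [25, 1], [28, 1], [31, 1], [34, 1], [40, 1], [45, 1],
]
labels = [
    'rap', 'rap', 'hip hop', 'hip hop',
    'rock', 'rock', 'rock', 'country', 'country',
    'dance', 'dance', 'hip hop', 'hip hop',
    'rap', 'rap', 'rap', 'classical', 'classical',
]

# None lets the tree grow until every leaf is pure
for depth in (1, 2, 3, None):
    clf = DecisionTreeClassifier(max_depth=depth, random_state=0)
    clf.fit(features, labels)
    # score() returns mean accuracy, here computed on the training samples
    print(f"max_depth={depth}: depth={clf.get_depth()}, train accuracy={clf.score(features, labels):.2f}")
```

With no depth limit the tree reaches 100% training accuracy, because no two samples share the same features but different labels. Shallower trees trade some of that fit for simpler rules.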

scikit-learn's DecisionTreeClassifier is a class capable of performing multi-class classification on a dataset.
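Here is a small illustration of multi-class classification on a hypothetical three-class toy dataset, not the data used later in this post. After fitting, the classifier exposes the distinct labels in `classes_`. `predict_proba` returns one probability column per class, in that same order.

```python
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy data: one numeric feature, three string classes
X = [[0], [1], [2], [3], [4], [5]]
y = ['a', 'a', 'b', 'b', 'c', 'c']

clf = DecisionTreeClassifier(random_state=0).fit(X, y)

print(clf.classes_)                    # sorted distinct labels: ['a' 'b' 'c']
print(clf.predict([[0], [2.5], [5]]))  # one predicted label per input row
print(clf.predict_proba([[2.5]]))      # columns follow the order of classes_
```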

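As with other scikit-learn classifiers, DecisionTreeClassifier takes two arrays as input. The first is an array X, sparse or dense, of shape (n_samples, n_features) holding the training samples. The second is an array Y of shape (n_samples,) holding the class labels. The labels can be integers or strings, and this example uses strings. Each sample is a person's age and sex, and each label is that person's preferred music genre: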

from sklearn import tree
import sys

# each sample is [age, sex]
# sex: 0 = male, 1 = female

features = [
    [18, 0], [19, 0], [22, 0], [25, 0], [28, 0], [31, 0], [34, 0], [40, 0], [45, 0],
    [18, 1], [19, 1], [22, 1], [25, 1], [28, 1], [31, 1], [34, 1], [40, 1], [45, 1]
]

# preferred music genre (class label) for each sample, in the same order as features

labels = [
    'rap', 'rap', 'hip hop', 'hip hop',
    'rock', 'rock', 'rock', 'country', 'country',
    'dance', 'dance', 'hip hop', 'hip hop',
    'rap', 'rap', 'rap', 'classical', 'classical'
]

clf = tree.DecisionTreeClassifier()

clf.fit(features, labels)

# read age and sex from the command line; sys.argv holds strings, so convert them to int
age, sex = int(sys.argv[1]), int(sys.argv[2])
prediction = clf.predict([[age, sex]])

print(prediction)



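Try it! Save the script as decision_tree_classifier.py and pass an age and a sex (0 = male, 1 = female) as arguments: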


python3.7 decision_tree_classifier.py 18 1
['dance']
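The script predicts one person at a time from the command line. Because `predict` takes a 2-D array with one row per sample, several people can be classified in a single call. X may also be a SciPy sparse matrix. The sketch below assumes NumPy and SciPy are installed (both are scikit-learn dependencies). It adds `random_state=0` for repeatability, and the query rows are hypothetical.

```python
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.tree import DecisionTreeClassifier

# Training data from this post: each row is [age, sex], sex 0 = male, 1 = female
features = np.array([
    [18, 0], [19, 0], [22, 0], [25, 0], [28, 0], [31, 0], [34, 0], [40, 0], [45, 0],
    [18, 1], [19, 1], [22, 1], [25, 1], [28, 1], [31, 1], [34, 1], [40, 1], [45, 1],
])
labels = np.array([
    'rap', 'rap', 'hip hop', 'hip hop',
    'rock', 'rock', 'rock', 'country', 'country',
    'dance', 'dance', 'hip hop', 'hip hop',
    'rap', 'rap', 'rap', 'classical', 'classical',
])

clf = DecisionTreeClassifier(random_state=0).fit(features, labels)

# Several hypothetical people at once, one [age, sex] row each
queries = np.array([[18, 1], [30, 0], [50, 1]])

dense_pred = clf.predict(queries)               # dense NumPy input
sparse_pred = clf.predict(csr_matrix(queries))  # the same rows as a CSR sparse matrix

print(dense_pred)
print(sparse_pred)
```

Both calls go through the same fitted tree, so the dense and sparse predictions are identical.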


The tree can also be exported in text format with the export_text function. The code below continues the script above, where clf is already fitted:


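# import from sklearn.tree: the old sklearn.tree.export module was deprecated in scikit-learn 0.22 and later removed
from sklearn.tree import export_text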

decision_tree_text = export_text(clf, feature_names=['age', 'sex'])
print(decision_tree_text)


|--- age <= 37.00
|   |--- age <= 26.50
|   |   |--- age <= 20.50
|   |   |   |--- sex <= 0.50
|   |   |   |   |--- class: rap
|   |   |   |--- sex >  0.50
|   |   |   |   |--- class: dance
|   |   |--- age >  20.50
|   |   |   |--- class: hip hop
|   |--- age >  26.50
|   |   |--- sex <= 0.50
|   |   |   |--- class: rock
|   |   |--- sex >  0.50
|   |   |   |--- class: rap
|--- age >  37.00
|   |--- sex <= 0.50
|   |   |--- class: country
|   |--- sex >  0.50
|   |   |--- class: classical
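To see where these rules come from, note that a fitted classifier exposes its low-level tree structure as `clf.tree_`. The parallel arrays `children_left`, `children_right`, `feature`, `threshold` and `value` describe every node, and a leaf is a node whose `children_left` is -1. The sketch below is a hypothetical helper, `tree_rules`, not part of scikit-learn. It walks those arrays recursively and produces rules in a similar layout. `random_state=0` is an added assumption. DecisionTreeClassifier breaks ties between equally good splits randomly, so without a fixed seed the exact splits can differ between runs.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Training data from this post: [age, sex], sex 0 = male, 1 = female
features = [
    [18, 0], [19, 0], [22, 0], [25, 0], [28, 0], [31, 0], [34, 0], [40, 0], [45, 0],
    [18, 1], [19, 1], [22, 1], [25, 1], [28, 1], [31, 1], [34, 1], [40, 1], [45, 1],
]
labels = [
    'rap', 'rap', 'hip hop', 'hip hop',
    'rock', 'rock', 'rock', 'country', 'country',
    'dance', 'dance', 'hip hop', 'hip hop',
    'rap', 'rap', 'rap', 'classical', 'classical',
]
clf = DecisionTreeClassifier(random_state=0).fit(features, labels)


def tree_rules(clf, feature_names):
    """Hypothetical helper: return the split rules of a fitted tree as text lines."""
    t = clf.tree_
    lines = []

    def walk(node, depth):
        prefix = "|   " * depth + "|---"
        if t.children_left[node] == -1:  # -1 means the node has no children: a leaf
            # value[node][0] holds one entry per class; argmax picks the majority class
            majority = clf.classes_[np.argmax(t.value[node][0])]
            lines.append(f"{prefix} class: {majority}")
            return
        name = feature_names[t.feature[node]]
        thr = t.threshold[node]
        lines.append(f"{prefix} {name} <= {thr:.2f}")
        walk(t.children_left[node], depth + 1)   # samples with feature <= threshold
        lines.append(f"{prefix} {name} >  {thr:.2f}")
        walk(t.children_right[node], depth + 1)  # samples with feature > threshold

    walk(0, 0)  # node 0 is the root
    return lines


print("\n".join(tree_rules(clf, ["age", "sex"])))
```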


