A decision tree is a general-purpose prediction and classification mechanism that has evolved over time into a highly cross-disciplinary family of computationally intensive methods. These methods are used for prediction and classification, artificial intelligence, machine learning, knowledge discovery, and inductive rule building.
As the name suggests, a decision tree has a tree-like structure, drawn upside down with the root at the top. Each internal node represents a question on an attribute, in its simplest form a yes-or-no question. Each branch represents an outcome of that question, and each leaf node represents the answer: a class label, that is, the decision taken after evaluating the attributes along the path. Decision tree learning is therefore the process of building this structure by choosing, at each node, the split that is best according to a selected metric, so that every path from the root to a leaf represents a classification rule.
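To make the idea of a selection metric concrete, here is a minimal sketch of Gini impurity, the default criterion of scikit-learn's DecisionTreeClassifier. The helper names gini and split_score are hypothetical and chosen only for this illustration. A node's impurity is 1 minus the sum of its squared class proportions. A candidate split is scored by the size-weighted impurity of its two children, and a lower score is better.

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions (0 means a pure node)."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())

def split_score(left, right):
    """Size-weighted Gini impurity of the two children produced by a split."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)

# A node holding three classes in equal proportion, like the full Iris dataset
parent = [0] * 50 + [1] * 50 + [2] * 50
print(round(gini(parent), 4))  # 0.6667

# A split that isolates class 0 leaves a pure left child and lowers the impurity
left, right = [0] * 50, [1] * 50 + [2] * 50
print(round(split_score(left, right), 4))  # 0.3333
```

A learner using this metric would compare such scores across candidate attributes and thresholds and keep the lowest.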
Python and Decision Tree
One of the best ways to create and visualize decision trees in Python is the scikit-learn library, one of the most important libraries in the field of machine learning. It implements many machine learning algorithms, such as Random Forests, along with many classes for data processing. But let’s move on to our problem: visualizing decision trees.
The dataset and the model
The dataset we are going to use is the classic Iris dataset; here is the code to load it.
import sklearn.datasets as datasets
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt
import pandas as pd

iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
The code is quite easy to understand. The Iris dataset is loaded into the variable iris as a dictionary-like structure (a scikit-learn Bunch). Then a pandas DataFrame is created from its data, using its feature names as column names.
Note: remember that the goal here is to visualize our decision trees. For that reason we do not split the dataset into training and test sets, and we do not apply any other training strategy.
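To see what the Bunch actually holds, the short sketch below prints its main fields. The attribute names (data, target, feature_names, target_names) are those returned by load_iris. The Bunch also allows dictionary-style access, so iris["data"] and iris.data refer to the same array.

```python
import sklearn.datasets as datasets
import pandas as pd

iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)

# The Bunch exposes its fields both as keys and as attributes
print(iris.feature_names)       # the four column names used for df
print(list(iris.target_names))  # class labels corresponding to targets 0, 1, 2
print(df.shape)                 # (150, 4): 150 flowers, 4 measurements each
print(df.head())
```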
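Such a split is not part of the visualization workflow. Readers who do want an estimate of how well a tree generalizes can use a hold-out evaluation, and a minimal sketch with scikit-learn's train_test_split follows. The 30% test size and random_state=42 are arbitrary choices made for this illustration.

```python
import sklearn.datasets as datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()

# Hold out 30% of the rows; stratify keeps the class proportions equal in both parts
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, stratify=iris.target, random_state=42
)

model = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Accuracy measured on rows the tree never saw during training
score = model.score(X_test, y_test)
print("Test accuracy:", score)
```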
Now that we have the data, we can create and train our Decision tree model.
from sklearn.tree import DecisionTreeClassifier

X = df  # data without target class
y = iris.target  # target class

model = DecisionTreeClassifier()
model = model.fit(X, y)
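Before visualizing, it can help to check the size of the fitted tree and try a prediction. The sketch below uses the get_depth and get_n_leaves methods of DecisionTreeClassifier. The random_state=42 argument is an addition for reproducibility and is not in the code above.

```python
import sklearn.datasets as datasets
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = iris.target

model = DecisionTreeClassifier(random_state=42).fit(X, y)

# Size of the learned tree
print("Depth:", model.get_depth())
print("Leaves:", model.get_n_leaves())

# Predict the first flower; iloc[[0]] keeps a one-row DataFrame with the fit-time column names
first = X.iloc[[0]]
pred = model.predict(first)[0]
print(iris.target_names[pred])
```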
Let’s visualize it: Text Representation
There are different ways to print out a decision tree. One of them is the traditional text representation, which is useful when we are working without a user interface or when we want to log some information about the model. To print the tree as text, import the tree module from scikit-learn and call its export_text function (see the scikit-learn documentation).
tree_text = tree.export_text(model)
print(tree_text)
|--- feature_3 <= 0.80
|   |--- class: 0
|--- feature_3 > 0.80
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 > 1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 > 4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 > 1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 > 6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 > 1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_0 <= 5.95
|   |   |   |   |--- class: 1
|   |   |   |--- feature_0 > 5.95
|   |   |   |   |--- class: 2
|   |   |--- feature_2 > 4.85
|   |   |   |--- class: 2
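The printed tree uses generic names such as feature_3 because no names were passed to export_text. The sketch below passes the optional feature_names parameter so that the real column names are printed. It also passes max_depth, which limits how many levels are printed. The depth of 2 and random_state=42 are arbitrary illustration choices.

```python
import sklearn.datasets as datasets
import pandas as pd
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
model = DecisionTreeClassifier(random_state=42).fit(X, iris.target)

# feature_names replaces feature_0..feature_3 with readable column names;
# max_depth limits the printout, and deeper branches are collapsed into a truncation marker
tree_text = tree.export_text(model, feature_names=list(iris.feature_names), max_depth=2)
print(tree_text)
```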
To read the output, follow it as nested if/else logic. For example, if feature_3 is less than or equal to 0.80, the decision is class 0. Otherwise (feature_3 greater than 0.80), check whether feature_3 is less than or equal to 1.75. If so, check whether feature_2 is less than or equal to 4.95. If the answer is again yes, check whether feature_3 is less than or equal to 1.65: if it is, the label is class 1, otherwise class 2. This is, in a few words, how to interpret the structure of a decision tree. Here feature_i refers to column i of the data, so feature_2 is the petal length and feature_3 the petal width (both in cm).
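To make the if/else reading explicit, the printed tree can be transcribed by hand into a plain Python function. The function name classify is hypothetical, and it mirrors only that particular output: its thresholds are copied from the printout, with feature_i mapped to column i of iris.feature_names. A tree trained in another run may differ, so this is a reading exercise and not a replacement for model.predict.

```python
def classify(sepal_length, sepal_width, petal_length, petal_width):
    """Hand transcription of the printed tree (feature_0..feature_3 in Iris column order)."""
    # sepal_width (feature_1) is never tested by this tree; it is kept for a uniform signature
    if petal_width <= 0.80:          # feature_3 <= 0.80
        return 0
    if petal_width <= 1.75:          # feature_3 <= 1.75
        if petal_length <= 4.95:     # feature_2 <= 4.95
            return 1 if petal_width <= 1.65 else 2
        if petal_width <= 1.55:      # feature_3 <= 1.55
            return 2
        return 1 if sepal_length <= 6.95 else 2   # feature_0 <= 6.95
    if petal_length <= 4.85:         # feature_2 <= 4.85
        return 1 if sepal_length <= 5.95 else 2   # feature_0 <= 5.95
    return 2

print(classify(5.1, 3.5, 1.4, 0.2))  # 0
print(classify(6.4, 3.2, 4.5, 1.5))  # 1
print(classify(6.3, 3.3, 6.0, 2.5))  # 2
```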
Let’s visualize it: Graphic Representation
Naturally, if we want to show our beautiful tree to other people, text may not be the right format. For this reason, the tree module also provides plot_tree, a function that draws the tree as a figure using matplotlib.
plt.figure(figsize=(20, 20))
tree.plot_tree(model,
               feature_names=iris.feature_names,
               class_names=iris.target_names,
               fontsize=10,
               filled=True)
plt.show()
The chart is, of course, easier to understand than the text representation. It shows the features’ names, and with the parameter filled set to True, each node is coloured according to its majority class.
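To share the figure, it can be saved to a file instead of being shown in a window. The sketch below uses matplotlib's non-interactive Agg backend, which also works on servers without a display. It passes plot_tree an explicit ax and a max_depth so that only the top levels are drawn. The file name iris_tree.png and the depth of 2 are illustrative choices.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to files, no window needed
import matplotlib.pyplot as plt
import sklearn.datasets as datasets
import pandas as pd
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
model = DecisionTreeClassifier(random_state=42).fit(X, iris.target)

fig, ax = plt.subplots(figsize=(12, 8))
# max_depth limits how many levels are drawn, keeping a large tree legible
tree.plot_tree(model, feature_names=iris.feature_names, class_names=iris.target_names,
               filled=True, max_depth=2, fontsize=10, ax=ax)
fig.savefig("iris_tree.png", dpi=100, bbox_inches="tight")
plt.close(fig)  # free the figure once it is written to disk
```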
Visualize the decision tree within our Random Forest
A Random Forest is a collection of decision trees that differ from each other. To build a Random Forest in Python with scikit-learn, we set the number of trees in the forest, called estimators, through the n_estimators parameter. Each tree is built independently of the others, and randomness is introduced so that the trees differ. In scikit-learn, each tree is trained by default on a bootstrap sample of the data, and only a random subset of the features is considered at each split. Let’s construct and visualize the trees in our forest.
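One quick way to see that the trees really differ is to compare their root splits and sizes through each tree's low-level tree_ structure. The sketch below reads tree_.feature[0] and tree_.threshold[0], which describe the test at the root node. The forest size of 5 and random_state=42 are arbitrary choices here, and the exact splits may vary with the scikit-learn version.

```python
import sklearn.datasets as datasets
from sklearn.ensemble import RandomForestClassifier

iris = datasets.load_iris()

# bootstrap=True (the default) gives every tree its own resampled copy of the rows,
# and max_features limits the features each split may consider
model = RandomForestClassifier(n_estimators=5, random_state=42).fit(iris.data, iris.target)

for i, est in enumerate(model.estimators_):
    root_feature = est.tree_.feature[0]      # index of the feature tested at the root
    root_threshold = est.tree_.threshold[0]  # threshold used by that test
    print(f"tree {i}: root tests {iris.feature_names[root_feature]} <= {root_threshold:.2f}, "
          f"depth {est.get_depth()}, leaves {est.get_n_leaves()}")
```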
import sklearn.datasets as datasets
import pandas as pd
from sklearn import tree
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier

iris = datasets.load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
X = df  # data without target class
y = iris.target  # target class

# our forest is composed of 20 trees (estimators)
model = RandomForestClassifier(n_estimators=20, random_state=42)
model = model.fit(X, y)

# the fitted model exposes the attribute .estimators_ (a list of DecisionTreeClassifier)
print("Trees:", len(model.estimators_))

# to print or plot a tree we have to select one from the estimators list: here, the first
tree_text = tree.export_text(model.estimators_[0])
print(tree_text)

plt.figure(figsize=(30, 25))
tree.plot_tree(model.estimators_[0],  # we are plotting the first tree
               feature_names=iris.feature_names,
               class_names=iris.target_names,
               fontsize=10,
               filled=True)
plt.show()
|--- feature_3 <= 0.80
|   |--- class: 0.0
|--- feature_3 > 0.80
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 5.40
|   |   |   |--- feature_3 <= 1.45
|   |   |   |   |--- class: 1.0
|   |   |   |--- feature_3 > 1.45
|   |   |   |   |--- feature_2 <= 4.95
|   |   |   |   |   |--- class: 1.0
|   |   |   |   |--- feature_2 > 4.95
|   |   |   |   |   |--- feature_1 <= 2.60
|   |   |   |   |   |   |--- class: 2.0
|   |   |   |   |   |--- feature_1 > 2.60
|   |   |   |   |   |   |--- class: 1.0
|   |   |--- feature_2 > 5.40
|   |   |   |--- class: 2.0
|   |--- feature_3 > 1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2.0
|   |   |   |--- feature_1 > 3.10
|   |   |   |   |--- class: 1.0
|   |   |--- feature_2 > 4.85
|   |   |   |--- class: 2.0
The important difference from the previous code is that for a Random Forest we set the number of estimators (our trees). The fitted random forest model has an attribute called estimators_, which is a list. From this list it is possible to access each of the decision trees that make up our forest, for example model.estimators_[0] for the first one.
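Once the individual trees are available through estimators_, we can also check how the forest combines them. According to the scikit-learn documentation, RandomForestClassifier predicts class probabilities as the mean of the probabilities predicted by its trees. The sketch below recomputes that mean by hand and compares it with predict_proba. It passes NumPy arrays instead of the DataFrame, an assumption made so that the forest and the individual trees see the same unnamed inputs.

```python
import numpy as np
import sklearn.datasets as datasets
from sklearn.ensemble import RandomForestClassifier

iris = datasets.load_iris()
X, y = iris.data, iris.target

model = RandomForestClassifier(n_estimators=20, random_state=42).fit(X, y)

first_tree = model.estimators_[0]  # a single DecisionTreeClassifier taken from the list
print(type(first_tree).__name__)

# Average the trees' class probabilities and compare with the forest's own output
per_tree = np.stack([est.predict_proba(X) for est in model.estimators_])
manual = per_tree.mean(axis=0)
print(np.allclose(manual, model.predict_proba(X)))  # True
```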
In conclusion, visualizing how a machine learning model behaves can be very helpful: it is another way to understand whether, and how, the model is doing well or badly. Visualizing a single decision tree within a forest gives an idea of how the entire random forest makes predictions. Although randomness is used to build the trees, each prediction is not random: it follows an ordered, logical sequence of steps. Moreover, plots are much easier to understand for people who do not work with machine learning, and they can help convey the model’s logic to stakeholders.