Decision trees and random forests

    A decision tree is a type of machine learning algorithm whose model takes the form of a tree-like graph. Each internal node represents a decision (a split on a feature) chosen to maximize information gain. While decision trees are easy to understand and computationally inexpensive, they are prone to high variance. Random forests are a powerful ensemble technique that can be used to address these shortcomings. As the name implies, a random forest consists of multiple decision trees and uses the mode (for classification) or the mean (for regression) of the individual predictions as its output. Here, we review how decision trees work, introduce random forests and discuss their benefits and limitations, and show how a random forest can be used to evaluate feature importance. We focus on classification (assuming each sample can only be labeled A or ~A) and assume basic knowledge of the vocabulary of graph theory, which can be reviewed at

Decision trees

    A decision tree is a classification or regression model represented by a tree-like graph. This graph consists of a root node (often placed at the top of visual representations) and directed edges that lead through branch nodes and eventually end in leaf nodes. Extending the analogy with trees found in nature, the root node has no incoming edges, only outgoing edges (representing the decision made at that node). Each branch node has exactly one incoming edge and two or more outgoing edges, while each leaf node has one incoming edge and no outgoing edges; the leaves hold the model's predictions.
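To make this structure concrete, here is a minimal sketch of a binary decision tree as a data structure. The `Node` class and `predict` function are hypothetical names chosen for illustration, and the tree is built by hand rather than learned; the sketch only shows how a prediction follows edges from the root, through branch nodes, to a leaf.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Node:
    """One node of a (hypothetical) binary decision tree.

    Root and branch nodes store a split (feature index and threshold) and two
    children; leaf nodes store only a predicted label.
    """
    feature: Optional[int] = None      # index of the feature tested at this node
    threshold: Optional[float] = None  # go left if x[feature] <= threshold
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    label: Optional[str] = None        # set only on leaf nodes

    def is_leaf(self) -> bool:
        # A leaf has no outgoing edges, i.e. no children.
        return self.left is None and self.right is None


def predict(node: Node, x: Sequence[float]) -> str:
    """Follow edges from the root until a leaf is reached, then return its label."""
    while not node.is_leaf():
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.label


# A hand-built tree: the root tests feature 0, and one branch node tests feature 1.
tree = Node(
    feature=0, threshold=2.5,
    left=Node(label="A"),
    right=Node(feature=1, threshold=1.0,
               left=Node(label="~A"), right=Node(label="A")),
)

print(predict(tree, [1.0, 5.0]))  # A
print(predict(tree, [3.0, 0.5]))  # ~A
print(predict(tree, [3.0, 2.0]))  # A
```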


To determine whether a node is a leaf node, we can check the entropy I(S) of the node on (a subset of) the training examples, S:

$$I(S) = -p(+)\log_2 p(+) - p(-)\log_2 p(-)$$

Here p(+) and p(-) are the proportions of positive and negative samples in S (by convention, 0 log2 0 = 0). We can interpret I(S) as answering the question: if X ∈ S, how many bits are needed, on average, to determine the class of X? (These are information-theoretic bits, not computer bits; for two classes the value ranges continuously between zero and one.) For example, if S consists entirely of positive elements, you need zero bits to determine whether X is positive, and the node is a leaf node. Conversely, if exactly half of the elements of S are positive, you need an entire bit to determine the class of X. One benefit of the entropy metric is symmetry: if we were to swap the classes, I(S) would still be 0 for a pure node and still be 1 for a maximally impure (50/50) node.
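The entropy formula can be checked numerically. The sketch below assumes a helper named `entropy` (a hypothetical name) that takes the counts of positive and negative samples in a node and applies the 0 log2 0 = 0 convention.

```python
import math


def entropy(n_pos: int, n_neg: int) -> float:
    """Entropy I(S), in bits, of a node with n_pos positive and n_neg negative samples."""
    total = n_pos + n_neg
    result = 0.0
    for count in (n_pos, n_neg):
        p = count / total  # the proportion p(+) or p(-)
        if p > 0:          # convention: 0 * log2(0) = 0
            result -= p * math.log2(p)
    return result


print(entropy(10, 0))  # 0.0   pure node
print(entropy(5, 5))   # 1.0   50/50 node
print(entropy(8, 2))   # ~0.722
print(entropy(2, 8))   # ~0.722 (swapping the classes gives the same value)
```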

    Alternative impurity metrics include Gini impurity, $1 - p(+)^2 - p(-)^2$, and misclassification error, $1 - \max(p(+), p(-))$; readers are encouraged to determine whether these metrics are also symmetric.
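As a quick numerical check of that exercise, the sketch below evaluates both alternative metrics for a few values of p(+), using the fact that p(-) = 1 - p(+) when there are only two classes. The function names `gini` and `misclassification` are illustrative. Comparing each metric at p(+) with its value at 1 - p(+) shows whether swapping the classes changes the result.

```python
def gini(p_pos: float) -> float:
    """Gini impurity 1 - p(+)^2 - p(-)^2 for a two-class node."""
    p_neg = 1.0 - p_pos
    return 1.0 - p_pos ** 2 - p_neg ** 2


def misclassification(p_pos: float) -> float:
    """Misclassification error 1 - max(p(+), p(-)) for a two-class node."""
    return 1.0 - max(p_pos, 1.0 - p_pos)


for p in (0.0, 0.2, 0.5, 0.8, 1.0):
    # Compare each metric at p(+) = p with the class-swapped value p(+) = 1 - p.
    print(f"p(+)={p:.1f}  gini={gini(p):.3f} vs {gini(1 - p):.3f}  "
          f"misclassification={misclassification(p):.3f} vs {misclassification(1 - p):.3f}")
```

Both metrics give the same value for p(+) and 1 - p(+) (up to floating-point rounding), are 0 for pure nodes, and peak at 0.5 for a 50/50 node.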

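    At the root and branch nodes, the model decides which feature (and, for numeric features, which threshold) to split on by maximizing the information gain:

$$IG = I(D_p) - \frac{N_{left}}{N_p} I(D_{left}) - \frac{N_{right}}{N_p} I(D_{right})$$
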
Here I(·) is one of the impurity metrics described above; N_p, N_left, and N_right are the numbers of samples in the parent, left-child, and right-child nodes; and D_p, D_left, and D_right are the corresponding datasets. The formula can be interpreted as subtracting the weighted average of the child impurities from the parent impurity. An optimal split maximizes information gain, which rewards splits that produce pure children, weighted by how many samples each child receives. This weighting also discourages (in models with 3+ classes) splits that carve the data into very small subsets of high purity, which would result in deeper, more memory-intensive trees.
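The information gain formula translates directly into code. In this sketch, a node is described by a hypothetical `(positives, negatives)` count pair, and the impurity function defaults to entropy; any of the other metrics could be passed instead. The third example illustrates the effect of the weighting: isolating a single pure sample earns far less gain than a balanced split whose children are mostly pure.

```python
import math
from typing import Callable, Tuple

Counts = Tuple[int, int]  # (positive samples, negative samples) in a node


def entropy(counts: Counts) -> float:
    """Entropy in bits of a node, treating 0 * log2(0) as 0."""
    total = sum(counts)
    result = 0.0
    for c in counts:
        if c > 0:
            p = c / total
            result -= p * math.log2(p)
    return result


def information_gain(parent: Counts, left: Counts, right: Counts,
                     impurity: Callable[[Counts], float] = entropy) -> float:
    """IG = I(D_p) - N_left/N_p * I(D_left) - N_right/N_p * I(D_right)."""
    n_p, n_left, n_right = sum(parent), sum(left), sum(right)
    return (impurity(parent)
            - n_left / n_p * impurity(left)
            - n_right / n_p * impurity(right))


# Parent node with 5 positive and 5 negative samples (entropy = 1 bit).
print(information_gain((5, 5), (5, 0), (0, 5)))  # 1.0: perfect split
print(information_gain((5, 5), (4, 1), (1, 4)))  # ~0.278: balanced, mostly pure split
print(information_gain((5, 5), (1, 0), (4, 5)))  # ~0.108: isolates one pure sample
```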

Lastly, once we have chosen an impurity metric and defined information gain IG, we need to choose how to split the data at any given parent (root or branch) node. A naive approach is a brute-force search that evaluates IG for every candidate split (each feature and threshold) and chooses the split that maximizes IG; the same procedure is then applied recursively to each child node.
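A minimal sketch of this naive brute-force search, assuming numeric features stored as a list of lists and integer class labels. The hypothetical `best_split` tries every feature with every observed value as a threshold (samples with `x[feature] <= threshold` go left), and the hypothetical `build_tree` applies it recursively, returning nested dictionaries. It illustrates the idea only and is not efficient, since every candidate split recomputes impurity from scratch.

```python
import math
from typing import List, Optional, Tuple


def entropy(labels: List[int]) -> float:
    """Entropy in bits of a list of class labels."""
    n = len(labels)
    result = 0.0
    for cls in set(labels):
        p = labels.count(cls) / n
        result -= p * math.log2(p)
    return result


def best_split(X: List[List[float]], y: List[int]) -> Tuple[Optional[int], Optional[float], float]:
    """Brute force: try every (feature, threshold) pair; return the one with the largest IG."""
    parent = entropy(y)
    best: Tuple[Optional[int], Optional[float], float] = (None, None, 0.0)
    for feature in range(len(X[0])):
        for threshold in sorted({row[feature] for row in X}):
            left = [label for row, label in zip(X, y) if row[feature] <= threshold]
            right = [label for row, label in zip(X, y) if row[feature] > threshold]
            if not left or not right:
                continue  # everything went to one side: not a real split
            gain = (parent
                    - len(left) / len(y) * entropy(left)
                    - len(right) / len(y) * entropy(right))
            if gain > best[2]:
                best = (feature, threshold, gain)
    return best


def build_tree(X: List[List[float]], y: List[int], depth: int = 0, max_depth: int = 3) -> dict:
    """Recursively split until no split has positive gain or max_depth is reached."""
    feature, threshold, gain = best_split(X, y)
    if feature is None or depth == max_depth:
        return {"label": max(set(y), key=y.count)}  # leaf: majority class
    go_left = [row[feature] <= threshold for row in X]
    left_X = [row for row, g in zip(X, go_left) if g]
    left_y = [label for label, g in zip(y, go_left) if g]
    right_X = [row for row, g in zip(X, go_left) if not g]
    right_y = [label for label, g in zip(y, go_left) if not g]
    return {
        "feature": feature,
        "threshold": threshold,
        "left": build_tree(left_X, left_y, depth + 1, max_depth),
        "right": build_tree(right_X, right_y, depth + 1, max_depth),
    }


# Toy data: feature 0 does not separate the classes cleanly, feature 1 does.
X = [[1.0, 0.2], [3.0, 0.1], [2.0, 0.9], [4.0, 0.8]]
y = [0, 0, 1, 1]
print(best_split(X, y))  # (1, 0.2, 1.0)
print(build_tree(X, y))
```

On the toy data, the search picks feature 1 at threshold 0.2 with a gain of 1 bit, and both children are pure, so the recursion stops at depth one.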

We have just discussed how to build a decision tree by calculating the impurity of a given node and maximizing information gain at each subsequent node. Decision trees are intuitively understood by non-technical people, can capture nonlinear relationships that simple linear models miss, implicitly perform feature selection, and make predictions at a cost that is logarithmic in the number of training examples. However, they are sensitive to small perturbations in the data, prone to overfitting, struggle to learn certain concepts (such as XOR), and prone to bias on unbalanced datasets. Some of these drawbacks can be addressed with random forests, which are collections of decision trees.
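The XOR difficulty shows up directly in the information gain formula. In the sketch below (the helper names are illustrative), neither feature produces any information gain at the root, even though a depth-two tree separates the classes perfectly. A greedy learner therefore gets no signal about which feature to split on first, and one that stops when no split improves impurity would stop at the root.

```python
import math
from itertools import product
from typing import List


def entropy(labels: List[int]) -> float:
    """Entropy in bits of a list of class labels."""
    n = len(labels)
    result = 0.0
    for cls in set(labels):
        p = labels.count(cls) / n
        result -= p * math.log2(p)
    return result


def split_gain(X: List[List[int]], y: List[int], feature: int) -> float:
    """Information gain of splitting the binary feature `feature` into ==0 / ==1."""
    left = [label for row, label in zip(X, y) if row[feature] == 0]
    right = [label for row, label in zip(X, y) if row[feature] == 1]
    return (entropy(y)
            - len(left) / len(y) * entropy(left)
            - len(right) / len(y) * entropy(right))


# All four XOR cases: the label is x0 XOR x1.
X = [list(bits) for bits in product([0, 1], repeat=2)]
y = [x0 ^ x1 for x0, x1 in X]

# At the root, neither feature helps on its own.
print(split_gain(X, y, 0), split_gain(X, y, 1))  # 0.0 0.0

# After splitting on x0, splitting on x1 separates the classes perfectly.
X_left = [row for row in X if row[0] == 0]
y_left = [label for row, label in zip(X, y) if row[0] == 0]
print(split_gain(X_left, y_left, 1))  # 1.0
```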

Random forests

While decision trees are intuitive to present graphically, they have many shortcomings. Random forests address these shortcomings by combining the results of many decision trees, each trained on a random subset of the training data. For classification, each tree casts a vote (sometimes weighted), and the mode of the votes is selected as the output class. Here, we discuss the benefits and drawbacks of random forests, how bagging works in random forests, and a technique for estimating feature importance.
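A sketch of the voting step, assuming each tree has already produced a class label. The `forest_vote` function is hypothetical: with no weights it returns the plain mode of the votes, and with weights it returns the class with the largest total weight. Note that scikit-learn's `RandomForestClassifier` instead averages the trees' predicted class probabilities rather than counting hard votes.

```python
from collections import Counter
from typing import Optional, Sequence


def forest_vote(tree_predictions: Sequence[str],
                weights: Optional[Sequence[float]] = None) -> str:
    """Return the (optionally weighted) mode of the individual tree predictions."""
    if weights is None:
        weights = [1.0] * len(tree_predictions)
    totals: Counter = Counter()
    for label, w in zip(tree_predictions, weights):
        totals[label] += w  # accumulate each tree's vote for its predicted class
    return totals.most_common(1)[0][0]


votes = ["A", "~A", "A", "A", "~A"]                    # predictions from five hypothetical trees
print(forest_vote(votes))                              # A  (3 votes to 2)
print(forest_vote(votes, [0.1, 1.0, 0.1, 0.1, 1.0]))   # ~A (weight 2.0 vs 0.3)
```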

A random forest is a machine learning ensemble technique that uses multiple decision trees to classify (or regress) data. While individual decision trees are prone to high variance, averaging the predictions of many trees reduces that variance. Random forest classifiers can be memory intensive; however, since each decision tree is trained independently of the others, training can be parallelized, and distributed computing can address issues related to memory and speed. An additional benefit of random forests is that they allow for estimation of feature importance.
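Why does averaging help? The following idealized simulation treats each tree's prediction as the true value plus independent noise. That independence is an assumption: real trees in a forest are correlated, which limits how much the variance shrinks. Under the assumption, averaging 25 predictions should reduce the variance by roughly a factor of 25. The constants are arbitrary choices for illustration.

```python
import random
import statistics

random.seed(0)
TRUE_VALUE = 1.0  # what a perfect model would predict (arbitrary)
NOISE_SD = 0.5    # spread of a single tree's prediction (arbitrary)
N_TREES = 25
N_TRIALS = 2000


def noisy_tree_prediction() -> float:
    """Idealized stand-in for one high-variance tree: truth plus independent noise."""
    return random.gauss(TRUE_VALUE, NOISE_SD)


single = [noisy_tree_prediction() for _ in range(N_TRIALS)]
averaged = [statistics.fmean(noisy_tree_prediction() for _ in range(N_TREES))
            for _ in range(N_TRIALS)]

print(f"variance of one tree:          {statistics.pvariance(single):.4f}")    # near 0.25
print(f"variance of a {N_TREES}-tree average: {statistics.pvariance(averaged):.4f}")  # near 0.01
```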

Bagging in a random forest differs from traditional bagging. Not only are the training data randomly sampled (with replacement) for each tree, but only a random subset of the features is considered at each split. This improves training time and makes the trees less correlated with one another. Because many trees are used, ignoring features at any single split is not a major concern: features dropped at one split are likely to be picked up by subsequent nodes and by other trees.
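A minimal sketch of these two sources of randomness, using only the standard library. `bootstrap_indices` and `feature_subset` are hypothetical helpers. The subset size of sqrt(number of features) is an assumption made here; it is a common choice for classification and matches scikit-learn's default `max_features="sqrt"` for `RandomForestClassifier`.

```python
import math
import random
from typing import List


def bootstrap_indices(n_samples: int, rng: random.Random) -> List[int]:
    """Row indices for one tree, drawn uniformly *with replacement*."""
    return [rng.randrange(n_samples) for _ in range(n_samples)]


def feature_subset(n_features: int, rng: random.Random) -> List[int]:
    """Features one split may consider: sqrt(n_features) of them, drawn without replacement."""
    k = max(1, int(math.sqrt(n_features)))
    return sorted(rng.sample(range(n_features), k))


rng = random.Random(42)
n_samples, n_features = 10, 9

rows = bootstrap_indices(n_samples, rng)
print("rows for this tree:", sorted(rows))  # duplicates are expected
print("distinct rows used:", len(set(rows)), "of", n_samples)

# A fresh feature subset is drawn for every split in the tree.
for split in range(3):
    print(f"split {split}: features considered {feature_subset(n_features, rng)}")
```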

One final benefit of the random forest algorithm that we will discuss is its ability to help estimate feature importance. One way to do this uses a forest of many shallow trees. Because each split greedily maximizes information gain, the first (root) split of each tree is made on the feature that yields the largest gain among the features considered. Repeating this first split across many trees yields a list of decisions that maximize information gain. Since these decisions are based on features in the data, we can extract a list of the features that drive the initial split. The code to accomplish this is shown below.

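import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# 100 trees with max_depth=1: each tree makes only its root split
clf = RandomForestClassifier(n_estimators=100, max_depth=1)
iris = load_iris()
clf.fit(iris.data, iris.target)

# Importance of each feature (normalized to sum to 1)
print(clf.feature_importances_)
# Name of the most important feature
print(iris.feature_names[np.argmax(clf.feature_importances_)])

# Note: due to randomness, your results may be different!
# >>[ 0.19  0.    0.37  0.44]
# >>petal width (cm)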
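The same idea can be inspected directly by looking at which feature each stump uses for its root split. This sketch relies on scikit-learn's fitted `estimators_` list and the low-level `tree_.feature` array (index 0 is the root node); `random_state=0` is added only to make the run repeatable. For depth-1 trees, the fraction of stumps choosing each feature should match `feature_importances_`, because each stump places all of its normalized importance on its single split feature and the forest averages over trees.

```python
from collections import Counter

import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
clf = RandomForestClassifier(n_estimators=100, max_depth=1, random_state=0)
clf.fit(iris.data, iris.target)

# tree_.feature[0] is the feature index tested at each stump's root node.
# Trees that never split (a single leaf) are skipped, since they have no root split.
root_features = [int(est.tree_.feature[0]) for est in clf.estimators_
                 if est.tree_.node_count > 1]

counts = Counter(iris.feature_names[i] for i in root_features)
for name, n in counts.most_common():
    print(f"{name}: root split in {n} of {len(root_features)} trees")

# For depth-1 trees, these fractions should match feature_importances_.
n_features = iris.data.shape[1]
fractions = np.array([root_features.count(i) for i in range(n_features)]) / len(root_features)
print(np.allclose(fractions, clf.feature_importances_))
```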