
Understanding Decision Trees: A Fundamental Tool in Machine Learning

Decision trees are a fundamental and versatile tool in machine learning, widely used for both classification and regression tasks. They are powerful models that can be easily understood, interpreted, and applied across a variety of domains. In this blog, we will explore what decision trees are, how they work, their advantages and disadvantages, and some practical applications.

The Jupyter Notebook implementation can be found here.

Table of Contents:

  1. Introduction to Decision Trees
  2. Training and Visualizing a Decision Tree
  3. Making Predictions
  4. Estimating Class Probabilities
  5. The CART Training Algorithm
  6. Gini Impurity or Entropy?
  7. Regularization Hyperparameters
  8. Regression
  9. Instability
  10. Conclusion


1. Introduction to Decision Trees

What is a Decision Tree?

At its core, a decision tree is a tree-like model that helps us make decisions or predictions based on a set of rules. Each internal node of the tree represents a feature (attribute), each branch represents a decision based on that feature, and each leaf node represents the outcome or class label. Decision trees are a supervised learning technique, which means they require labeled training data to learn from.

One of the many qualities of Decision Trees is that they require very little data preparation. They don't require feature scaling or centering at all.

How do Decision Trees work?

The construction of a decision tree follows a top-down, recursive process called "recursive partitioning." The algorithm starts with a root node that contains all the training data and selects the best feature to split the data based on some criteria, typically aiming to minimize impurity or maximize information gain.

Information gain is a concept used in decision tree algorithms to measure the reduction in uncertainty or impurity when a dataset is split based on a particular feature. A high information gain suggests that the feature is valuable for splitting the data, making it a crucial factor in deciding the structure of the decision tree.

The splitting process continues recursively until a stopping criterion is met, such as a predefined depth limit or when further splitting would not provide a significant improvement. Once the tree is constructed, it can be used to make predictions for new data points by traversing the tree from the root node to a leaf node, following the path dictated by the feature values.
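To make information gain concrete, here is a minimal, illustrative sketch that computes the entropy-based information gain of a candidate split. The function names (entropy, information_gain) are hypothetical helpers for illustration, not part of any library:

import numpy as np

def entropy(labels):
    # Entropy of a label array: -sum(p * log2(p)) over the classes present
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def information_gain(parent, left, right):
    # Reduction in entropy achieved by splitting `parent` into `left` and `right`
    n = len(parent)
    weighted_child_entropy = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - weighted_child_entropy

# Toy example: a perfectly separating split removes all uncertainty
parent = np.array([0, 0, 0, 1, 1, 1])
left, right = np.array([0, 0, 0]), np.array([1, 1, 1])
print(information_gain(parent, left, right))  # 1.0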

Advantages of Decision Trees

  1. Interpretability - You can easily visualize the decision-making process in decision trees, making it clear why a particular split was made.
  2. Non-parametric - Decision trees do not assume any specific distribution for the data. This makes them flexible and suitable for a wide range of data types.
  3. Handling Mixed Data Types - Decision trees can handle both categorical and numerical features without requiring extensive preprocessing.
  4. Feature Importance - Decision trees can help identify the most important features in a dataset, which is useful for feature selection and understanding the underlying data.
  5. Scalability - Decision trees can work well with large datasets and can be easily parallelized.

Disadvantages of Decision Trees

  1. Overfitting - Decision trees can be prone to overfitting, especially when they become deep and complex. Techniques like pruning and setting maximum tree depth are used to mitigate this issue.
  2. High Variance - Decision trees are sensitive to small changes in the training data, which can lead to different tree structures. This high variance can be reduced by using ensemble methods like Random Forest.
  3. Bias Towards Dominant Classes - If one class dominates the dataset, the decision tree may become biased towards predicting that class and may not perform well on the minority class.

Practical Applications

Decision trees find application in various domains:
  1. Medicine - Used for medical diagnosis and determining treatment plans based on patient information.
  2. Finance - Used for credit scoring, fraud detection, and investment decisions.
  3. Marketing - Helps in customer segmentation, product recommendation, and market analysis.
  4. Environmental Sciences - Used for species classification, predicting deforestation, and climate change modeling.
  5. Manufacturing - Used for quality control, fraud detection, and supply chain optimization.

2. Training and Visualizing a Decision Tree

To understand Decision Trees, let's build one and take a look at how it makes predictions.

The following code trains a DecisionTreeClassifier on the iris dataset:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris(as_frame=True)
X_iris = iris.data[["petal length (cm)", "petal width (cm)"]].values
y_iris = iris.target

tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X_iris, y_iris)

You can visualize the trained Decision Tree by first using the export_graphviz() function to output a graph definition file called iris_tree.dot:
from pathlib import Path
from sklearn.tree import export_graphviz

IMAGES_PATH = Path("images")  # output directory for the .dot file (adjust to your setup)
IMAGES_PATH.mkdir(parents=True, exist_ok=True)

export_graphviz(
    tree_clf,
    out_file=str(IMAGES_PATH / "iris_tree.dot"),
    feature_names=["petal length (cm)", "petal width (cm)"],
    class_names=iris.target_names,
    rounded=True,
    filled=True
)
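One way to render the resulting iris_tree.dot file is with the graphviz Python package (assuming both the package and the Graphviz binaries are installed); alternatively, sklearn.tree.plot_tree() can draw the tree directly with Matplotlib:

from graphviz import Source

# Render the .dot file; in a Jupyter notebook this displays the tree inline
Source.from_file(IMAGES_PATH / "iris_tree.dot")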

The Decision Tree looks like Figure 1:
Figure 1. Iris Decision Tree

Let's see how the tree represented in Figure 1 makes predictions. Suppose you find an iris flower and you want to classify it. You start at the root node (depth 0, at the top): this node asks whether the flower's petal length is smaller than 2.45 cm. If it is, then you move down to the root's left child node (depth 1, left). In this case, it is a leaf node (i.e., it does not have any child nodes), so it does not ask any questions: simply look at the predicted class for that node, and the Decision Tree predicts that your flower is an Iris setosa (class=setosa).

Now suppose you find another flower, and this time the petal length is greater than 2.45 cm. You must move down to the root's right child node (depth 1, right), which is not a leaf node, so the node asks another question: is the petal width smaller than 1.75 cm? If it is, then your flower is most likely an Iris versicolor (depth 2, left). If not, it is likely an Iris virginica (depth 2, right). It's really that simple.

3. Making Predictions

Figure 2. Decision Tree decision boundaries

Figure 2 shows this Decision Tree's decision boundaries. The thick vertical line represents the decision boundary of the root node (depth 0): petal length = 2.45 cm. Since the lefthand area is pure (only Iris setosa), it cannot be split any further. However, the righthand area is impure, so the depth-1 right node splits it at petal width = 1.75 cm (represented by the dashed line). Since max_depth was set to 2, the Decision Tree stops right there. If you set max_depth to 3, then the two depth-2 nodes would each add another decision boundary (represented by the dotted lines).

4. Estimating Class Probabilities

A Decision Tree can also estimate the probability that an instance belongs to a particular class k. First it traverses the tree to find the leaf node for this instance, and then it returns the ratio of training instances of class k in this node. 

For example, suppose you have found a flower whose petals are 5 cm long and 1.5 cm wide. The corresponding leaf node is the depth-2 left node, so the Decision Tree should output the following probabilities: 0% for Iris setosa (0/54), 90.7% for Iris versicolor (49/54), and 9.3% for Iris virginica (5/54). And if you ask it to predict the class, it should output Iris versicolor (class 1) because it has the highest probability. Let's check this:
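Using the classifier trained in section 2, the check looks roughly like this (the exact decimals depend on the training run):

# Estimated class probabilities for a flower with 5 cm long, 1.5 cm wide petals
print(tree_clf.predict_proba([[5, 1.5]]))  # approximately [[0.    0.907 0.093]]
print(tree_clf.predict([[5, 1.5]]))        # [1] -> Iris versicolor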

5. The CART Training Algorithm

Scikit-Learn uses the Classification and Regression Tree (CART) algorithm to train Decision Trees (also called "growing" trees). 

Once the CART algorithm has successfully split the training set in two, it splits the subsets using the same logic, then the sub-subsets, and so on, recursively. It stops recursing once it reaches the maximum depth (defined by the max_depth hyperparameter), or if it cannot find a split that will reduce impurity. 

As you can see, the CART algorithm is a greedy algorithm: it greedily searches for an optimum split at the top level and then repeats the process at each subsequent level. It does not check whether or not the split will lead to the lowest possible impurity several levels down. A greedy algorithm often produces a solution that's reasonably good but not guaranteed to be optimal.
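To make the greedy search concrete, here is a minimal, illustrative sketch of what CART does at a single node: it tries every feature and every observed threshold and keeps the split with the lowest weighted Gini impurity. The function names are hypothetical and this is not Scikit-Learn's actual implementation (which, among other optimizations, splits at midpoints between sorted values); it reuses X_iris and y_iris from section 2:

import numpy as np

def gini(y):
    # Gini impurity: 1 - sum of squared class proportions
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    # Greedy search over every feature and observed threshold for the split
    # that minimizes the weighted Gini impurity of the two children
    best_feature, best_threshold, best_cost = None, None, float("inf")
    m = len(y)
    for k in range(X.shape[1]):
        for t in np.unique(X[:, k]):
            left, right = y[X[:, k] <= t], y[X[:, k] > t]
            if len(left) == 0 or len(right) == 0:
                continue
            cost = (len(left) / m) * gini(left) + (len(right) / m) * gini(right)
            if cost < best_cost:
                best_feature, best_threshold, best_cost = k, t, cost
    return best_feature, best_threshold, best_cost

# Picks feature 0 (petal length) with a threshold that isolates Iris setosa,
# matching the root split in Figure 1 (Scikit-Learn reports the midpoint, 2.45 cm)
print(best_split(X_iris, y_iris))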

6. Gini Impurity or Entropy?

By default, the Gini impurity measure is used, but you can select the entropy impurity measure instead by setting the criterion hyperparameter to "entropy". The concept of entropy originated in thermodynamics as a measure of molecular disorder: entropy approaches zero when molecules are still and well-ordered. 

In Machine Learning, entropy is frequently used as an impurity measure: a set's entropy is zero when it contains instances of only one class.

So, the question is: should you use Gini impurity or entropy? The truth is, most of the time it does not make a big difference: they lead to similar trees. Gini impurity is slightly faster to compute, so it is a good default.
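To make the comparison concrete, here is a small sketch computing both measures for the depth-2 left node from Figure 1 (class counts 0, 49, and 5):

import numpy as np

counts = np.array([0, 49, 5])              # class counts in the depth-2 left node
p = counts / counts.sum()
p_nonzero = p[p > 0]                       # treat 0 * log2(0) as 0

gini = 1 - np.sum(p ** 2)                  # ~0.168
entropy = -np.sum(p_nonzero * np.log2(p_nonzero))  # ~0.445
print(gini, entropy)

Both measures are zero for a pure node and grow as the class mix becomes more even, which is why they usually lead to similar trees.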

7. Regularization Hyperparameters

Decision Trees make very few assumptions about the training data (as opposed to linear models, which assume that the data is linear, for example). If left unconstrained, the tree structure will adapt itself to the training data, fitting it very closely - indeed, most likely overfitting it. Such a model is often called a nonparametric model, not because it does not have any parameters but because the number of parameters is not determined prior to training, so the model structure is free to stick closely to the data.

In contrast, a parametric model, such as a linear model, has a predetermined number of parameters, so its degrees of freedom are limited, reducing the risk of overfitting (but increasing the risk of underfitting). Restricting a Decision Tree's freedom during training in a similar way is called regularization.

Pruning a decision tree is a crucial technique used to prevent overfitting and improve the model's generalization capabilities. It involves removing branches or sub-trees that do not contribute significantly to the tree's predictive accuracy. Pruning simplifies the tree by limiting its depth or its number of leaf nodes, or by requiring a minimum number of samples per leaf. By doing this, we strike a balance between model complexity and performance, ensuring that the tree is not overly complex and making it more suitable for real-world applications.
Pruning a Decision Tree

Figure 3 shows two Decision Trees trained on the moons dataset. On the left, the Decision Tree is trained with the default hyperparameters (i.e., no restrictions), and on the right it's trained with min_samples_leaf=4. It is quite obvious that the model on the left is overfitting, and the model on the right will probably generalize better.
Figure 3. Regularization using min_samples_leaf
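A rough sketch of how this comparison could be reproduced is shown below; the exact noise level and sample size used for Figure 3 are assumptions:

from sklearn.datasets import make_moons
from sklearn.tree import DecisionTreeClassifier

X_moons, y_moons = make_moons(n_samples=150, noise=0.2, random_state=42)

# Unrestricted tree (left of Figure 3) vs. a regularized tree (right)
tree_clf_unrestricted = DecisionTreeClassifier(random_state=42)
tree_clf_regularized = DecisionTreeClassifier(min_samples_leaf=4, random_state=42)
tree_clf_unrestricted.fit(X_moons, y_moons)
tree_clf_regularized.fit(X_moons, y_moons)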

8. Regression

Decision Trees are also capable of performing regression tasks. Let's build a regression tree using Scikit-Learn's DecisionTreeRegressor class, training it on a noisy quadratic dataset with max_depth=2:
import numpy as np
from sklearn.tree import DecisionTreeRegressor

np.random.seed(42)
X_quad = np.random.rand(200, 1) - 0.5
y_quad = X_quad ** 2 + 0.025 * np.random.randn(200, 1)

tree_reg = DecisionTreeRegressor(max_depth=2, random_state=42)
tree_reg.fit(X_quad, y_quad)

The resulting tree is represented in Figure 4.
Figure 4. A Decision Tree for regression

This tree looks very similar to the classification tree you built earlier. The main difference is that instead of predicting a class in each node, it predicts a value. For example, suppose you want to make a prediction for a new instance with x1 = 0.6. You traverse the tree starting at the root, and you eventually reach the leaf node that predicts value=0.154. This prediction is the average target value of the 46 training instances associated with this leaf node, and it results in a mean squared error equal to 0.002 over these 46 instances.
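You can verify this by asking the trained regressor for a prediction at x1 = 0.6:

# Prediction for x1 = 0.6: the mean target value of the corresponding leaf
print(tree_reg.predict([[0.6]]))  # should match the leaf value shown in Figure 4 (~0.154)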

The CART algorithm works mostly the same way as earlier, except that instead of trying to split the training set in a way that minimizes impurity, it now tries to split the training set in a way that minimizes the MSE.

Figure 6. Regularizing a Decision Tree regressor

Just like for classification tasks, Decision Trees are prone to overfitting when dealing with regression tasks. Without any regularization (i.e., using the default hyperparameters), you get the predictions on the left in Figure 6. These predictions are obviously overfitting the training set very badly. Just setting min_samples_leaf=10 results in a much more reasonable model, represented on the right.
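A minimal sketch of that regularized regressor, reusing the quadratic dataset from above (the variable name is illustrative):

# Regularized regression tree (right of Figure 6)
tree_reg_regularized = DecisionTreeRegressor(min_samples_leaf=10, random_state=42)
tree_reg_regularized.fit(X_quad, y_quad)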

9. Instability 

As you have seen, Decision Trees are simple to understand and interpret, easy to use, powerful, and versatile. However, they do have a few limitations.

First, Decision Trees have orthogonal decision boundaries (all splits are perpendicular to an axis), which makes them sensitive to training set rotation.

For example, Figure 7 shows a simple linearly separable dataset: on the left, a Decision Tree can split it easily, while on the right, after the dataset is rotated by 45 degrees, the decision boundary looks unnecessarily convoluted. Although both Decision Trees fit the training set perfectly, it is very likely that the model on the right will not generalize well.

Figure 7. Sensitivity to training set rotation
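Here is an illustrative sketch of the kind of experiment behind Figure 7: a simple linearly separable dataset is fit once as-is and once after a 45-degree rotation (the dataset itself is an assumption, generated for illustration):

import numpy as np
from sklearn.tree import DecisionTreeClassifier

np.random.seed(42)
X_square = np.random.rand(100, 2) - 0.5              # simple linearly separable data
y_square = (X_square[:, 0] > 0).astype(np.int64)

angle = np.pi / 4                                     # rotate the dataset by 45 degrees
rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)],
                            [np.sin(angle),  np.cos(angle)]])
X_rotated = X_square @ rotation_matrix

tree_clf_square = DecisionTreeClassifier(random_state=42).fit(X_square, y_square)
tree_clf_rotated = DecisionTreeClassifier(random_state=42).fit(X_rotated, y_square)
# Both trees fit the training set perfectly, but the rotated one needs many
# axis-aligned splits to approximate a diagonal boundary, so it is less likely
# to generalize well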

Decision Trees are very sensitive to small variations in the training data. For example, if you just remove the widest Iris versicolor from the iris training set (the one with petals 4.8 cm long and 1.8 cm wide) and train a new Decision Tree, you may get the model represented in Figure 8.
Figure 8. Sensitivity to training set details

Random Forests can limit this instability by averaging predictions over many trees.
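For example, a Random Forest averages the predictions of many randomized trees; a minimal sketch using the iris features from section 2:

from sklearn.ensemble import RandomForestClassifier

# An ensemble of randomized trees averages out the instability of any single tree
forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)
forest_clf.fit(X_iris, y_iris)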

10. Conclusion

Decision trees are a powerful and interpretable machine learning tool that can be used for a wide range of classification and regression tasks. While they have their advantages and disadvantages, understanding how to use and interpret decision trees is a valuable skill for data scientists and machine learning practitioners. When used appropriately and in conjunction with techniques to overcome their limitations, decision trees can be a valuable asset in predictive modeling and decision-making.

Stay tuned for more topics on machine learning!




