The Decision Tree Classifier is a simple Machine Learning model used for classification problems. It is one of the simplest classification models, yet when done properly and trained on good data, it can be incredibly effective for some tasks. As we saw in a previous article, sometimes the simplest models are the best for certain tasks. So in this article, we are going to look at the logic and the maths behind Decision Tree Classifiers and analyze them using a simple dataset.
Interested in more stories like this? Follow me on Twitter at @b_dmarius and I'll post there every new article.
This article is part of a two-article mini-series about Decision Tree Classifiers. This one covers only the theory and maths behind this type of classifier. If you are more interested in the implementation using Python and Scikit-Learn, please read the other article, Decision Tree Classifier Tutorial in Python and Scikit-Learn.
Introduction to Decision Tree Classifiers
In a previous article, we defined what we mean by classification tasks in Machine Learning. If you want to check that out, please read the "Classification tasks in Machine Learning" section of that article. If you've already seen it or you're familiar with classification tasks, let's revisit the simple dataset we'll use to better understand decision trees.
This is a simple dataset I made so that I can explain classification problems in simple terms. The goal of our classification task is to determine whether there's a traffic jam given the weather outside, the type of day (workday or weekend) and the time of day (morning, lunch or evening).
So the idea is to build a classifier that, given an input X such as
X = [Clear, Workday, Morning]
returns an output Y of the form
Y = [y1, y2]
where y1 is the probability that there is a traffic jam and y2 is the probability that there is not. We then choose the highest yi, and that class is the result of our classification.
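In code, that last step is just an argmax. A minimal sketch, assuming the classifier has already returned the two probabilities (the values below are made up for illustration, and `pick_class` is a hypothetical helper name):

```python
# Hypothetical output Y = [y1, y2] for X = [Clear, Workday, Morning];
# the probabilities are illustrative, not produced by a real model.
labels = ["Traffic jam", "No traffic jam"]
y = [0.7, 0.3]


def pick_class(probabilities, class_labels):
    """Return the label whose probability is highest (an argmax)."""
    best_index = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return class_labels[best_index]


print(pick_class(y, labels))  # Traffic jam
```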
How Do Decision Trees Work?
Decision Tree Classifiers are a type of Supervised Machine Learning, meaning we build a model, feed it training data paired with the correct outputs and let the model learn the patterns in that data. Then we give our model new data it hasn't seen before so that we can see how it performs. Since we need to know what exactly gets trained in a Decision Tree, let's first see what a decision tree is.
A decision tree consists of 3 types of components:
- Nodes - a decision on the value of a certain attribute ("is age over 50?", "is salary higher than $2000?").
- Edges - an edge is one of the possible answers from a node ("yes", "no") and builds the connection to the next node.
- Leaf nodes - exit points for the outcome of the decision tree. For example, in our case we can have multiple "Yes" and "No" leaf nodes, meaning there are multiple ways to exit the decision tree with the information that there will or will not be a traffic jam.
So predicting a value with a decision tree means starting from the top (the root node) and asking the question specific to each node. Then, depending on the node's answer, we follow the corresponding branch and continue until we arrive at a leaf node, which gives us the decision.
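The three components and the prediction walk can be sketched in a few lines of Python. This is a minimal illustration, not any library's API: `Node`, `Leaf` and `predict` are hypothetical names, and the tree itself is hand-built for the traffic example (the "Rainy" weather value is an assumption), not learned from data.

```python
from dataclasses import dataclass, field
from typing import Dict, Union


@dataclass
class Leaf:
    """Exit point of the tree: holds the final decision ("Yes"/"No")."""
    label: str


@dataclass
class Node:
    """Internal node: asks about one attribute; each edge is one possible answer."""
    attribute: str
    edges: Dict[str, "Tree"] = field(default_factory=dict)


Tree = Union[Node, Leaf]


def predict(tree: Tree, sample: Dict[str, str]) -> str:
    """Start at the root and follow the edge matching each answer until a leaf."""
    while isinstance(tree, Node):
        tree = tree.edges[sample[tree.attribute]]
    return tree.label


# A hypothetical, hand-built tree for the traffic example (not learned from data).
# Note that there are several "Yes" and several "No" leaves.
tree = Node("Day", {
    "Weekend": Leaf("No"),
    "Workday": Node("Time", {
        "Morning": Leaf("Yes"),
        "Lunch": Leaf("No"),
        "Evening": Node("Weather", {"Clear": Leaf("No"), "Rainy": Leaf("Yes")}),
    }),
})

print(predict(tree, {"Weather": "Clear", "Day": "Workday", "Time": "Morning"}))  # Yes
```

This sketch simply raises a `KeyError` if a sample contains an answer that has no matching edge.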
This looks like a bunch of if-else statements to me. Where's the Machine Learning?
If you've arrived at this observation by now, congrats: this is good intuition. The Machine Learning part, the juice of this algorithm, is in finding the best way to build this decision tree so that when it sees new data in real life, it will know how to arrive at the correct decision (the correct leaf node).
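To make the observation concrete, here is a hypothetical traffic tree written out by hand as plain if-else statements. The questions and their order were picked by hand for illustration; in a real Decision Tree Classifier, the training algorithm is what chooses them from the data.

```python
def is_traffic_jam(weather: str, day: str, time: str) -> str:
    """Hypothetical hand-written rules for the traffic example.

    Each if-statement plays the role of a node; each return is a leaf.
    A learned tree would pick these questions (and their order) from data.
    """
    if day == "Weekend":
        return "No"
    if time == "Morning":
        return "Yes"
    if time == "Lunch":
        return "No"
    # Remaining case in this sketch: a workday evening, decided by the weather.
    if weather == "Rainy":
        return "Yes"
    return "No"


print(is_traffic_jam("Clear", "Workday", "Morning"))  # Yes
```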
The first step is choosing a root node, meaning choosing the feature that gives us the greatest information gain, or the greatest leap we can take towards our answer. This is done by taking every feature in our dataset, splitting the dataset by the values of that feature and measuring how well each split classifies the data. Here we use accuracy as the intuitive measure; in practice, implementations usually score splits with information gain (the reduction in entropy) or Gini impurity. We then choose the feature with the best score and set it as our tree root.
This is like asking our dataset: if I had only one feature available, which feature would give me the highest accuracy compared with all the others?
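A small sketch of this search, assuming a made-up eight-row version of the traffic dataset (not the article's actual data). For each feature it computes the accuracy of a one-feature tree that predicts the majority label for each value, as described above, alongside the information gain (entropy before the split minus the size-weighted entropy of each branch). The helper names are hypothetical.

```python
from collections import Counter, defaultdict
from math import log2

# Hypothetical toy rows, made up for illustration (not the article's dataset).
ROWS = [
    ({"Weather": "Clear", "Day": "Workday", "Time": "Morning"}, "Yes"),
    ({"Weather": "Clear", "Day": "Workday", "Time": "Lunch"}, "No"),
    ({"Weather": "Rainy", "Day": "Workday", "Time": "Evening"}, "Yes"),
    ({"Weather": "Clear", "Day": "Weekend", "Time": "Morning"}, "No"),
    ({"Weather": "Rainy", "Day": "Weekend", "Time": "Evening"}, "No"),
    ({"Weather": "Clear", "Day": "Workday", "Time": "Evening"}, "No"),
    ({"Weather": "Rainy", "Day": "Workday", "Time": "Morning"}, "Yes"),
    ({"Weather": "Clear", "Day": "Weekend", "Time": "Lunch"}, "No"),
]
FEATURES = ["Weather", "Day", "Time"]


def labels_by_value(rows, feature):
    """Split the dataset on `feature`: map each value to the labels of its rows."""
    groups = defaultdict(list)
    for features, label in rows:
        groups[features[feature]].append(label)
    return groups


def stump_accuracy(rows, feature):
    """Accuracy using only `feature`, predicting the majority label for each value."""
    correct = sum(Counter(labels).most_common(1)[0][1]
                  for labels in labels_by_value(rows, feature).values())
    return correct / len(rows)


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    return -sum((n / total) * log2(n / total) for n in Counter(labels).values())


def information_gain(rows, feature):
    """Entropy before the split minus the size-weighted entropy of each branch."""
    before = entropy([label for _, label in rows])
    after = sum(len(labels) / len(rows) * entropy(labels)
                for labels in labels_by_value(rows, feature).values())
    return before - after


for feature in FEATURES:
    print(feature, stump_accuracy(ROWS, feature), round(information_gain(ROWS, feature), 3))
# Weather 0.75 0.159
# Day 0.75 0.348
# Time 0.75 0.266

root = max(FEATURES, key=lambda f: information_gain(ROWS, f))
print("Root:", root)  # Root: Day
```

On these toy rows all three features tie at 75% accuracy, so accuracy alone cannot pick a root, while information gain clearly prefers Day. This kind of tie is one reason tree implementations usually score splits with entropy or Gini impurity rather than raw accuracy.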
Then we only have to apply this process recursively to every branch of the root node, then to every branch of the resulting nodes, and so on. Basically, it is like treating every subtree as a brand-new tree and doing our best to get as much accuracy as possible out of it.
Usually there are 2 criteria for stopping this recursive process:
- A node is pure, meaning all the rows in our dataset that reach this node have the same value for the target variable. There's no point in dividing this node further, because every dataset entry here arrives at the same conclusion. Therefore, we mark this node as a leaf node.
- Our tree becomes too complicated, with too many levels. It's up to you to set the maximum depth of a tree, and this decision should be based on experimentation.
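Putting the recursion and both stopping criteria together gives a minimal ID3-style sketch. Assumptions: it scores splits with information gain rather than the accuracy described above, represents the tree as nested dicts with plain strings as leaves, and uses a made-up eight-row traffic dataset (not the article's data). `build_tree` and its `max_depth` parameter are illustrative names, not a library API.

```python
from collections import Counter, defaultdict
from math import log2

# Hypothetical toy rows, made up for illustration (not the article's dataset).
ROWS = [
    ({"Weather": "Clear", "Day": "Workday", "Time": "Morning"}, "Yes"),
    ({"Weather": "Clear", "Day": "Workday", "Time": "Lunch"}, "No"),
    ({"Weather": "Rainy", "Day": "Workday", "Time": "Evening"}, "Yes"),
    ({"Weather": "Clear", "Day": "Weekend", "Time": "Morning"}, "No"),
    ({"Weather": "Rainy", "Day": "Weekend", "Time": "Evening"}, "No"),
    ({"Weather": "Clear", "Day": "Workday", "Time": "Evening"}, "No"),
    ({"Weather": "Rainy", "Day": "Workday", "Time": "Morning"}, "Yes"),
    ({"Weather": "Clear", "Day": "Weekend", "Time": "Lunch"}, "No"),
]


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    return -sum((n / total) * log2(n / total) for n in Counter(labels).values())


def partition(rows, feature):
    """Split the rows into one group per value of `feature`."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[0][feature]].append(row)
    return groups


def information_gain(rows, feature):
    """Entropy before the split minus the size-weighted entropy of each branch."""
    after = sum(len(part) / len(rows) * entropy([label for _, label in part])
                for part in partition(rows, feature).values())
    return entropy([label for _, label in rows]) - after


def build_tree(rows, features, depth=0, max_depth=3):
    """Recursively grow a tree: nested dicts for nodes, plain strings for leaves."""
    labels = [label for _, label in rows]
    majority = Counter(labels).most_common(1)[0][0]
    # Stop if the node is pure, the tree is already max_depth levels deep,
    # or there is no feature left to split on; the node becomes a leaf.
    if len(set(labels)) == 1 or depth >= max_depth or not features:
        return majority
    best = max(features, key=lambda f: information_gain(rows, f))
    remaining = [f for f in features if f != best]
    # Each branch is treated as a brand-new, smaller tree.
    return {best: {value: build_tree(part, remaining, depth + 1, max_depth)
                   for value, part in partition(rows, best).items()}}


print(build_tree(ROWS, ["Weather", "Day", "Time"]))
print(build_tree(ROWS, ["Weather", "Day", "Time"], max_depth=1))
# {'Day': {'Workday': 'Yes', 'Weekend': 'No'}}
```

With the default depth limit, the sketch splits on Day at the root, then on Time for workdays, and asks about Weather only for workday evenings. With `max_depth=1` the workday branch is cut off early and becomes a leaf holding its majority label.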
Decision Tree Classifiers - Applications and use-cases
We've already discussed that there are certain use cases where simple models perform better than more complicated ones. Sometimes it's better to sacrifice a little accuracy to gain a lot in terms of performance, ease of use and speed of implementation.
Decision Tree classifiers are used in classification tasks where the dataset is not huge and can be modelled by a simpler model. We can also use this classifier when we have only a few features available or when we need a model that can be visualised and explained in simple terms.
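Usually, for datasets with lots of features, Decision Trees tend to overfit, so it's better to reduce the dimensionality of the dataset first. One option is Principal Component Analysis (PCA), which projects the features onto a smaller set of components that capture most of the variance; if you'd rather keep only the original features that really bring value to the classification, a feature-selection method is the more direct tool.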
Seeing Decision Tree Classifier in action
That's it for the first article of our two-part mini-series on Decision Tree Classifiers. If you want to see a Decision Tree Classifier implementation in Python and Scikit-Learn, along with some fun visualisations, please read Decision Tree Classifier Tutorial in Python and Scikit-Learn.
Thank you so much for reading this!