Implementing a decision tree from scratch in Python

Reading time ~11 minutes

Introduction

Why would you do this?

After all, scikit-learn already has DecisionTreeClassifier, which works really well and is highly optimized!

Well, I can see four reasons to implement it anyway!

  • It is a good exercise if you want to learn the inner details of decision trees
  • The DecisionTreeClassifier only supports two criteria:
    criterion : {"gini", "entropy"}, default="gini"
    

    However, I may want to play with other criteria when the metric I am working with is not a standard one.

  • With pure Python code that does not require any compilation, pyx files and whatnot, you can experiment extensively with the logic of the tree training (and, depending on the problem, obtain better accuracy)
  • It is fun!

Starting point

So, we will use numpy and implement the decision tree without hard-coding any penalty function: the criterion will be provided by the user.

We will also follow the fit and predict interface, as we want to be able to reuse this class without a lot of effort.
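The intended usage is along these lines (a sketch; accuracy_score is the criterion used in the test section at the end, and X, y stand for any numpy dataset and labels):

from sklearn.metrics import accuracy_score

cdt = CustomDecisionTree(accuracy_score, max_depth=5)
cdt.fit(X, y)
y_hat = cdt.predict(X)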

The algorithm

Quoting Wikipedia:

A tree is built by splitting the source set, constituting the root node of the tree, into subsets—which constitute the successor children. The splitting is based on a set of splitting rules based on classification features. This process is repeated on each derived subset in a recursive manner called recursive partitioning.

Put another way: given a dataset and its labels, find a column and a threshold so that the data is partitioned into two datasets. Repeat this until the whole dataset has been split into small datasets whose size is lower than the minimum sample size given to the algorithm. Each split must be chosen so that it achieves the highest improvement in terms of the chosen criterion.

Parameters can be added: the maximum depth of the tree, the minimum number of elements in a leaf, the minimum gain required to decide whether to split the data further…
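In Python-like pseudocode, the training procedure we are about to implement looks roughly like this (a sketch only: stopping_criterion, best_split and partition are placeholder names, not the actual methods defined below):

def train(node, subset):
    if stopping_criterion(subset):               # e.g. max depth reached
        node.data = subset                       # leaf: store the samples
    else:
        column, threshold = best_split(subset)   # best w.r.t. the criterion
        left, right = partition(subset, column, threshold)
        node.left, node.right = Tree(), Tree()
        train(node.left, left)
        train(node.right, right)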

Implementation

Imports

from sklearn.utils.validation import check_X_y
import datetime
import numpy as np


class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    CYAN = '\033[36m'

Ok, I imported check_X_y from scikit-learn. It would be super easy to remove later, but using it saves a lot of debugging, so I will leave it here for now.

bcolors is just a convenience class storing the color escape codes used when printing to the terminal.

The Tree class

class Tree:

    def __init__(self):
        self.left = None
        self.right = None
        self.data = None

A tree is just a recursive data structure: it can hold data in a node and has two children, a left one and a right one.

There are plenty of things to know about trees in computer science, but we only need one to store data, so this class will be enough for our purposes!
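As a tiny illustration, a depth-one tree can be assembled by hand:

root = Tree()
root.data = "anything"  # data holds whatever the algorithm needs
root.left = Tree()      # the children are themselves trees
root.right = Tree()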

The CustomDecisionTree

Let’s decompose the work a little bit more in what follows. Our CustomDecisionTree will expose fit() and predict() and will operate on numpy arrays. Making it work on pandas DataFrames could be done as well, but let’s put it aside, as it requires more work and does not help in understanding the algorithm used to train a decision tree.

class CustomDecisionTree:

    def __init__(self, penalty_function, max_depth=3, min_sample_size=3, max_thresholds=10,
                 verbose=False):
        self._max_depth = max_depth
        self._min_sample_size = min_sample_size
        self._max_thresholds = max_thresholds
        self._penalty_function = penalty_function
        self._verbose = verbose
        self._y = None

The constructor will need:

  • penalty_function (the criterion we will try to optimize)
  • max_depth (the maximum depth of the tree)
  • min_sample_size (the minimum size of a sample to split it)
  • max_thresholds (the maximum number of thresholds proposed per column)
  • verbose (whether to log progress to the terminal)

Storing y could have been done later, but I like to have all the attributes used by my class declared in the constructor.
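Note that the penalty function follows the (y_true, y_pred) signature of the scikit-learn metrics and that the tree maximizes it (so “penalty” is really a score). accuracy_score works out of the box; as a purely hypothetical example of a custom criterion, a negated Gini impurity could look like this:

import numpy as np

def negative_gini(y_true, y_pred):
    # hypothetical criterion: negated Gini impurity of the true labels
    # (y_pred is ignored); higher means purer, so maximizing it pushes
    # the tree towards pure leaves
    _, counts = np.unique(y_true, return_counts=True)
    p = counts / counts.sum()
    return -(1.0 - np.sum(p ** 2))

cdt = CustomDecisionTree(negative_gini)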

Let’s jump to the fit method.

    def fit(self, X, y, indices=None):
        check_X_y(X, y)
        self._y = y
        self._tree = Tree()
        splitters = self._build_splitters(X)

        if indices is None:
            indices = np.arange(X.shape[0])

        if self._verbose:
            self._print("{} splitters proposed".format(len(splitters)))

        self._train(self._tree, indices, 0, splitters, 0, X, y)

        return self

Still not much done here. We make sure that X and y have compatible shapes (the check_X_y function does it for us), we store y and build the splitters.

Let’s get the _print() method out of the way (it is just a habit of mine to distinguish prints from different classes with colors; I find this helpful for debugging and for monitoring the execution of the algorithms).

    def _print(self, input_str):
        time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(bcolors.CYAN + "[CustomDecisionTree | " +
              time + "] " + bcolors.ENDC + str(input_str))

Note that we will work on indices to perform all the splits recursively. It was not strictly necessary to pass indices to fit(), but I plan to implement a random forest later with this class, so we will need them!

The splitters themselves will be at the core of the algorithm. A splitter will simply be an index (the column index) and a threshold.

A splitter just says: the elements of column i whose value is larger than the threshold go to the left child, and the elements which are smaller or equal go to the right child (following the convention used in _split() and _traverse_trained_tree() below).

    def _build_splitters(self, X):
        splitters = []

        for i, column in enumerate(X.T):
            sorted_unique_values = np.sort(np.unique(column))
            thresholds = (
                sorted_unique_values[:-1] + sorted_unique_values[1:]) / 2
            n_thresholds = len(thresholds)

            if len(thresholds) > self._max_thresholds:
                # subsample the candidate thresholds evenly
                thresholds = thresholds[[round(
                    j * n_thresholds / self._max_thresholds) for j in range(self._max_thresholds)]]

            for threshold in thresholds:
                splitters.append((i, threshold))

        return splitters

The splitters are the midpoints between consecutive sorted values of each column, subsampled so that we do not have too many of them (a large number of splitters slows down the algorithm while providing a very limited accuracy improvement).
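As a quick sanity check on a single made-up column:

col = np.array([1.0, 2.0, 2.0, 5.0, 8.0])
# sorted unique values: [1, 2, 5, 8]
# midpoints, i.e. candidate thresholds: [1.5, 3.5, 6.5]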

So, we have the fit() entry point to our interface and we briefly went through the splitter building part. Let’s continue:

    def _split(self, splitter, indices, X):
        index, threshold = splitter
        mask = X[indices, index] > threshold
        return indices[mask], indices[~mask]

As I said, a splitter just splits the data into two subsets (represented by their indices). It should be clear from this method that this is exactly what is performed (with a little help from numpy).
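On a toy example, with a splitter on column 0 and a threshold of 3:

X_toy = np.array([[1.0], [4.0], [2.0], [6.0]])
indices = np.arange(4)
# mask = X_toy[indices, 0] > 3 -> [False, True, False, True]
# rows above the threshold (indices 1 and 3) form the first returned
# array (the future left child), the others (0 and 2) the second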

Now if we remember the algorithm, we need to find the best splitter at each step of the recursive splits. This is where the user defined penalty will come in:

    def _splitter_score(self, splitter, indices, X, y):
        indices_left, indices_right = self._split(splitter, indices, X)
        n_left, n_right = len(indices_left), len(indices_right)

        if n_left < self._min_sample_size:
            # sentinel score: this split would produce a leaf that is too
            # small, so it must never be selected
            return -100000

        if n_right < self._min_sample_size:
            return -100000

        return (n_left * self._penalty(indices_left, y) +
                n_right * self._penalty(indices_right, y)) / \
            (n_left + n_right)

Note that the weighted mean of the penalty for a splitter is returned. If you wanted to modify it, this could be performed here.
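For reference, here is the _penalty() helper used above (it also appears in the full listing at the end of the article): the prediction for a subset would be its majority class, so the user-provided metric is evaluated against that constant prediction.

    def _penalty(self, indices, y):
        # evaluate the user metric as if the leaf predicted its majority class
        predicted = [np.bincount(y[indices]).argmax()] * len(indices)
        return self._penalty_function(y[indices], predicted)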

So we have our splitters and we can score each of them on any subset: we are ready to implement the full train method:

    def _train(self, tree, indices, depth, splitters, current_score, X, y):
        if depth >= self._max_depth:
            tree.data = indices
        else:
            splitter_and_scores = list(
                map(lambda ns: (ns, self._splitter_score(ns, indices, X, y)), splitters))
            scores = list(map(lambda sp: sp[1], splitter_and_scores))
            if len(scores) == 0:
                tree.data = indices
                return
            max_score = max(scores)
            max_index = scores.index(max_score)
            non_trivial_splitters_and_scores = list(
                filter(lambda p: p[1] != -100000, splitter_and_scores))
            non_trivial_splitters = list(
                map(lambda p: p[0], non_trivial_splitters_and_scores))

            best_splitter, best_score = splitter_and_scores[max_index]
            indices_left, indices_right = self._split(
                best_splitter, indices, X)

            if len(indices_left) < self._min_sample_size or \
               len(indices_right) < self._min_sample_size:
                tree.data = indices

            else:
                tree.data = best_splitter

                tree.left = Tree()
                tree.right = Tree()

                self._train(tree.left, indices_left, depth + 1,
                            non_trivial_splitters, best_score, X, y)
                self._train(tree.right, indices_right, depth + 1,
                            non_trivial_splitters, best_score, X, y)

If we reach max_depth, the leaf we are currently in will store the indices remaining for this leaf.

Otherwise, we find the best splitter, split the data into indices_left and indices_right (induced by this best splitter) and call _train() recursively twice: once on each subset. At this step, the node of the tree holds a splitter.

Note that each call to _train() updates the children of the tree. Once all the calls are executed, the _tree attribute of the class contains all the splitters (for intermediate nodes) and the indices for the final nodes (the ones that could not be split any further).
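The Tree class in the full listing at the end of the article also has a custom_print() helper, so a trained tree can be displayed with something like:

cdt = CustomDecisionTree(accuracy_score, verbose=True)
cdt.fit(X, y)
print(cdt._tree.custom_print(str, str))

Internal nodes show splitter tuples, leaves show the index arrays they store.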

Predictions

We have to add the methods that produce predictions once the tree is trained.

    def _find_indices_for_row(self, row):
        return self._traverse_trained_tree(self._tree, row)

    def _predict_one(self, row):
        indices = self._find_indices_for_row(row)
        return np.bincount(self._y[indices]).argmax()

    def _traverse_trained_tree(self, tree, row):
        if tree.left is None:
            return tree.data
        else:
            index, threshold = tree.data
            if row[index] > threshold:
                return self._traverse_trained_tree(tree.left, row)
            else:
                return self._traverse_trained_tree(tree.right, row)

    def predict(self, X):
        return np.array(
            list(map(lambda row: self._predict_one(row), X)), dtype=int)

Note that:

np.bincount(self._y[indices]).argmax()

simply returns the most common element of y at the selected indices. The logic of navigating the tree is performed in _traverse_trained_tree(). For each node, if it holds a splitter, follow the logic of the splitter (left or right according to the comparison with the threshold). If the algorithm reaches a leaf (tree.left is None), return the indices stored in the leaf.
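A quick illustration of the majority vote:

np.bincount(np.array([0, 1, 1, 2])).argmax()  # counts are [1, 2, 1] -> 1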

Testing!

if __name__ == "__main__":

    from sklearn.datasets import make_classification
    from sklearn.metrics import accuracy_score

    X, y = make_classification(n_samples=200, shuffle=False, n_redundant=3)
    for max_depth in [1, 2, 5, 10, 15]:
        cdt = CustomDecisionTree(accuracy_score, min_sample_size=1, max_depth=max_depth)
        cdt.fit(X, y)
        y_hat = cdt.predict(X)
        score = accuracy_score(y, y_hat)
        print("Max depth: ", max_depth, " score: ", score)

And tada! As expected, we reach a perfect accuracy (on the training set, which is all this quick test measures) once the depth is large enough!

Max depth:  1  score:  0.915
Max depth:  2  score:  0.92
Max depth:  5  score:  0.925
Max depth:  10  score:  0.975
Max depth:  15  score:  1.0

More thorough testing would include benchmarks on common datasets and a comparison to other implementations of decision trees. I will do this in a separate article.

Learning more and staying tuned

I hope you liked this reading! Any comments regarding the code or the explanations are welcome! For those who want to stay tuned, I implemented a small form to leave me your email (which won’t be used for ads nor transmitted to any third party). It is in the “Subscribe” section of the navigation menu (small square on the top left).

The next article will go from the CustomDecisionTree to a CustomRandomForest and the following one will be about more detailed tests for these newly implemented classes.

The code

from sklearn.utils.validation import check_X_y
import datetime
import numpy as np

class Tree:

    def __init__(self):
        self.left = None
        self.right = None
        self.data = None

    def __str__(self, level=0):
        ret = "\t"*level+repr(self.data)+"\n"
        for child in [self.left, self.right]:
            if child is not None:
                ret += child.__str__(level+1)
        return ret

    def custom_print(self, f1, f2, level=0):
        if self.left is None:
            ret = "\t"*level+f2(self.data)+"\n"
        else:
            ret = "\t"*level+f1(self.data)+"\n"

        if self.right is not None:
            ret = self.right.custom_print(f1, f2, level+1) + ret
        if self.left is not None:
            ret += self.left.custom_print(f1, f2, level+1)

        return ret


class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    CYAN = '\033[36m'


class CustomDecisionTree:

    def __init__(self, penalty_function, max_depth=3, min_sample_size=3, max_thresholds=10,
                 verbose=False):
        self._max_depth = max_depth
        self._min_sample_size = min_sample_size
        self._max_thresholds = max_thresholds
        self._penalty_function = penalty_function
        self._verbose = verbose
        self._y = None

    def fit(self, X, y, indices=None):
        check_X_y(X, y)
        self._y = y
        self._tree = Tree()
        splitters = self._build_splitters(X)

        if indices is None:
            indices = np.arange(X.shape[0])

        if self._verbose:
            self._print("{} splitters proposed".format(len(splitters)))

        self._train(self._tree, indices, 0, splitters, 0, X, y)

        return self

    def predict(self, X):
        return np.array(
            list(map(lambda row: self._predict_one(row), X)), dtype=int)

    def _penalty(self, indices, y):
        # evaluate the user metric as if the leaf predicted its majority class
        predicted = [np.bincount(y[indices]).argmax()] * len(indices)
        return self._penalty_function(y[indices], predicted)

    def _print(self, input_str):
        time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(bcolors.CYAN + "[CustomDecisionTree | " +
              time + "] " + bcolors.ENDC + str(input_str))

    def _find_indices_for_row(self, row):
        return self._traverse_trained_tree(self._tree, row)

    def _predict_one(self, row):
        indices = self._find_indices_for_row(row)
        return np.bincount(self._y[indices]).argmax()

    def _traverse_trained_tree(self, tree, row):
        if tree.left is None:
            return tree.data
        else:
            index, threshold = tree.data
            if row[index] > threshold:
                return self._traverse_trained_tree(tree.left, row)
            else:
                return self._traverse_trained_tree(tree.right, row)

    def _build_splitters(self, X):
        splitters = []

        for i, column in enumerate(X.T):
            sorted_unique_values = np.sort(np.unique(column))
            thresholds = (
                sorted_unique_values[:-1] + sorted_unique_values[1:]) / 2
            n_thresholds = len(thresholds)

            if len(thresholds) > self._max_thresholds:
                # subsample the candidate thresholds evenly
                thresholds = thresholds[[round(
                    j * n_thresholds / self._max_thresholds) for j in range(self._max_thresholds)]]

            for threshold in thresholds:
                splitters.append((i, threshold))

        return splitters

    def _split(self, splitter, indices, X):
        index, threshold = splitter
        mask = X[indices, index] > threshold
        return indices[mask], indices[~mask]

    def _splitter_score(self, splitter, indices, X, y):
        indices_left, indices_right = self._split(splitter, indices, X)
        n_left, n_right = len(indices_left), len(indices_right)

        if n_left < self._min_sample_size:
            # sentinel score: this split would produce a leaf that is too
            # small, so it must never be selected
            return -100000

        if n_right < self._min_sample_size:
            return -100000

        return (n_left * self._penalty(indices_left, y) +
                n_right * self._penalty(indices_right, y)) / \
            (n_left + n_right)

    def _train(self, tree, indices, depth, splitters, current_score, X, y):
        if depth >= self._max_depth:
            tree.data = indices
        else:
            splitter_and_scores = list(
                map(lambda ns: (ns, self._splitter_score(ns, indices, X, y)), splitters))
            scores = list(map(lambda sp: sp[1], splitter_and_scores))
            if len(scores) == 0:
                tree.data = indices
                return
            max_score = max(scores)
            max_index = scores.index(max_score)
            non_trivial_splitters_and_scores = list(
                filter(lambda p: p[1] != -100000, splitter_and_scores))
            non_trivial_splitters = list(
                map(lambda p: p[0], non_trivial_splitters_and_scores))

            best_splitter, best_score = splitter_and_scores[max_index]
            indices_left, indices_right = self._split(
                best_splitter, indices, X)

            if len(indices_left) < self._min_sample_size or \
               len(indices_right) < self._min_sample_size:
                tree.data = indices

            else:
                tree.data = best_splitter

                tree.left = Tree()
                tree.right = Tree()

                self._train(tree.left, indices_left, depth + 1,
                            non_trivial_splitters, best_score, X, y)
                self._train(tree.right, indices_right, depth + 1,
                            non_trivial_splitters, best_score, X, y)


if __name__ == "__main__":

    from sklearn.datasets import make_classification
    from sklearn.metrics import accuracy_score

    X, y = make_classification(n_samples=20, shuffle=False, n_redundant=3)
    cdt = CustomDecisionTree(accuracy_score, verbose=True)

    cdt.fit(X, y)

    print(cdt._tree.custom_print(str,str))

    X, y = make_classification(n_samples=200, shuffle=False, n_redundant=3)
    for max_depth in [1, 2, 5, 10, 15]:
        cdt = CustomDecisionTree(accuracy_score, min_sample_size=1, max_depth=max_depth)
        cdt.fit(X, y)
        y_hat = cdt.predict(X)
        score = accuracy_score(y, y_hat)
        print("Max depth: ", max_depth, " score: ", score)
