Extract trees from a random forest in Python


You may need to extract trees from a classifier for various reasons. In my case, I found xgboost's ntree_limit parameter quite convenient when cross-validating a gradient boosting model over the number of trees.

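It restricts the prediction to the first ntree_limit trees instead of using all the fitted trees; the default value 0 means no limit. In xgboost's native API it is an argument of Booster.predict (newer xgboost releases deprecate it in favor of iteration_range):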

predict(data, output_margin=False, ntree_limit=0, pred_leaf=False, 
        pred_contribs=False, approx_contribs=False, 
        pred_interactions=False, validate_features=True, training=False)

It is also available as an extra argument of .predict() in the scikit-learn interface. For instance, after training with early stopping, you can predict using only the best number of trees found:

ypred = bst.predict(dtest, ntree_limit=bst.best_ntree_limit)

This is handy when searching for the optimal number of trees: instead of fitting a model with 50 trees and predicting, then fitting another one with 100 trees and predicting again, you can fit a single model with 200 trees once and, by varying ntree_limit, observe its performance for various numbers of trees.
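scikit-learn's own GradientBoostingRegressor has no ntree_limit argument, but it offers a comparable mechanism: staged_predict yields the predictions after each boosting stage, so a single fit gives the validation error for every number of trees. Below is a minimal sketch of that idea; the synthetic make_regression data and the 200-tree setting are illustrative assumptions, not taken from the original setup.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

# Illustrative synthetic data (an assumption, not from the original post)
X, y = make_regression(n_samples=500, n_features=10, noise=10.0, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

# Fit once with the largest number of trees we want to consider
gbr = GradientBoostingRegressor(n_estimators=200, random_state=0)
gbr.fit(X_train, y_train)

# staged_predict yields the prediction using the first 1, 2, ..., 200 trees
val_errors = np.array(
    [mean_squared_error(y_val, y_pred) for y_pred in gbr.staged_predict(X_val)]
)

best_n_trees = int(np.argmin(val_errors)) + 1  # stage k uses k trees
for n in (50, 100, 150, 200):
    print(f"{n:>3} trees: validation MSE = {val_errors[n - 1]:.1f}")
print("best number of trees:", best_n_trees)
```

Because boosting is additive, the prediction after stage k is exactly what a prefix of k trees gives, which is the same property ntree_limit relies on.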

scikit-learn's random forests (RandomForestRegressor, RandomForestClassifier) do not expose this parameter in their .predict() method. However, this is quick to work around: the fitted forest stores its trees in the estimators_ list, and predict() averages over whatever that list contains. You can therefore truncate it (beware, this is a bit hacky and relies on implementation details that may change across scikit-learn versions).

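from sklearn.ensemble import RandomForestRegressor

# x, y: your training data
rf_model = RandomForestRegressor()
rf_model.fit(x, y)

# Keep a reference to the full list of fitted trees
estimators = rf_model.estimators_

def predict(x, i):
    """Predict with only the first i trees of the fitted forest."""
    rf_model.estimators_ = estimators[:i]
    try:
        return rf_model.predict(x)
    finally:
        rf_model.estimators_ = estimators  # restore the full forest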

And that’s it: the predict function now only uses the first i trees ;)
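If you want the performance for every number of trees at once, a less hacky variant avoids modifying estimators_ altogether. Since a RandomForestRegressor prediction is the average of its trees' predictions, you can query each tree once and take a running mean. This is a sketch of my own, not a scikit-learn API: the helper name prefix_predictions is hypothetical, it assumes a single-output regressor, and the synthetic data is illustrative.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split


def prefix_predictions(forest, X):
    """Hypothetical helper: row i-1 is the forest prediction using only the first i trees."""
    # One row per tree: shape (n_trees, n_samples) for a single-output regressor
    per_tree = np.stack([tree.predict(X) for tree in forest.estimators_])
    # Running sum over trees, divided by the number of trees used so far
    counts = np.arange(1, len(forest.estimators_) + 1)[:, None]
    return np.cumsum(per_tree, axis=0) / counts


# Illustrative synthetic data (an assumption, not from the original post)
X, y = make_regression(n_samples=500, n_features=10, noise=10.0, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

rf_model = RandomForestRegressor(n_estimators=200, random_state=0)
rf_model.fit(X_train, y_train)

preds = prefix_predictions(rf_model, X_val)
val_errors = [mean_squared_error(y_val, p) for p in preds]
for n in (50, 100, 150, 200):
    print(f"{n:>3} trees: validation MSE = {val_errors[n - 1]:.1f}")
```

For a RandomForestClassifier, the same running mean applies to the trees' predict_proba outputs, since the forest's predict_proba is also an average over estimators_.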