Extract trees from a random forest in Python

You may need to extract trees from a classifier for various reasons. In my case, I found xgboost's ntree_limit feature quite convenient when cross validating a gradient boosting method over the number of trees.

What it does is use only the first ntree_limit trees to perform the prediction (instead of using all the fitted trees).

predict(data, output_margin=False, ntree_limit=0, pred_leaf=False, 
        pred_contribs=False, approx_contribs=False, 
        pred_interactions=False, validate_features=True, training=False)

And it is also available as an extra argument of .predict() if you use the scikit-learn interface:

ypred = bst.predict(dtest, ntree_limit=bst.best_ntree_limit)

Indeed, by doing so, if you want to find the optimal number of trees for your model, you do not have to fit the model with 50 trees and predict, then fit it with 100 trees and predict again. You can fit the model once and for all with 200 trees and then, playing with ntree_limit, observe the performance of the model for various numbers of trees.
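
As an illustration, here is a minimal sketch of that workflow with the xgboost API. The data splits x_train, y_train, x_valid, y_valid and the hyperparameters are placeholders, and it assumes an xgboost version that still accepts ntree_limit:

import xgboost as xgb
from sklearn.metrics import mean_squared_error

# Hypothetical train/validation splits.
dtrain = xgb.DMatrix(x_train, label=y_train)
dvalid = xgb.DMatrix(x_valid, label=y_valid)

# Fit once with the largest number of trees we want to consider.
bst = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=200)

# Re-use the same fitted model, truncated to its first n trees.
for n in (50, 100, 150, 200):
    ypred = bst.predict(dvalid, ntree_limit=n)
    print(n, mean_squared_error(y_valid, ypred))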

The RandomForest, as implemented in scikit-learn, does not expose this parameter in its .predict() method. However, this is something we can quickly fix. Indeed, the fitted RandomForest exposes its individual trees through estimators_, which you can modify (beware, this is a bit hacky and may not work with other versions of scikit-learn).

from sklearn.ensemble import RandomForestRegressor

rf_model = RandomForestRegressor()
rf_model.fit(x, y)

# Keep a reference to the full list of fitted trees.
estimators = rf_model.estimators_

def predict(data, i):
    # Restrict the forest to its first i trees before predicting.
    rf_model.estimators_ = estimators[:i]
    return rf_model.predict(data)

And that’s it: the predict method now only looks at the first i trees ;)
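
A quick usage sketch of this truncated forest, where x_valid and y_valid are hypothetical validation data:

from sklearn.metrics import mean_squared_error

# Evaluate the same fitted forest for several numbers of trees, without refitting.
for i in (10, 50, len(estimators)):
    ypred = predict(x_valid, i)
    print(i, mean_squared_error(y_valid, ypred))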
