Extract trees from a random forest in Python


You may need to extract trees from a classifier for various reasons. In my case, I found xgboost's ntree_limit parameter quite convenient when cross-validating a gradient boosting model over the number of trees.

It makes the prediction use only the first ntree_limit trees, instead of all the fitted trees.

predict(data, output_margin=False, ntree_limit=0, pred_leaf=False, 
        pred_contribs=False, approx_contribs=False, 
        pred_interactions=False, validate_features=True, training=False)

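It is also available as an extra argument of .predict() if you use the scikit-learn interface. Note that newer xgboost releases deprecate ntree_limit in favor of iteration_range, so check the version you are running.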

ypred = bst.predict(dtest, ntree_limit=bst.best_ntree_limit)

This way, if you want to find the optimal number of trees for your model, you do not have to fit the model with 50 trees and predict, then fit it again with 100 trees and predict, and so on. You can fit the model once with 200 trees and then, by varying ntree_limit, observe the performance of the model for various numbers of trees.
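To make the "fit once, evaluate at every size" idea concrete without depending on a particular xgboost version, here is a minimal sketch using a swapped-in model: scikit-learn's GradientBoostingRegressor, whose staged_predict method yields the prediction after each additional tree. This is not the xgboost API discussed above. The synthetic data, the split, and the names (gbr, val_errors, best_n_trees) are assumptions made for illustration only.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

# Synthetic data, purely for illustration.
X, y = make_regression(n_samples=300, n_features=8, noise=10.0, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

# Fit once, with the largest number of trees we want to consider.
gbr = GradientBoostingRegressor(n_estimators=200, random_state=0)
gbr.fit(X_train, y_train)

# staged_predict yields the prediction obtained with the first 1, 2, ..., 200 trees.
val_errors = [
    mean_squared_error(y_val, y_pred) for y_pred in gbr.staged_predict(X_val)
]

# Number of trees with the lowest validation error (1-based).
best_n_trees = int(np.argmin(val_errors)) + 1
print(best_n_trees, val_errors[best_n_trees - 1])
```

A single fit gives the whole validation curve, which is exactly what makes this kind of parameter convenient.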

The random forest, as implemented in scikit-learn, does not offer such a parameter in its .predict() method. However, this is something we can quickly fix: the fitted model exposes its trees through the estimators_ attribute, and you can modify it (beware, this is a bit hacky and may not work for other versions of scikit-learn).

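from sklearn.ensemble import RandomForestRegressor

# x, y: training features and targets, assumed to be defined already
rf_model = RandomForestRegressor()
rf_model.fit(x, y)

# Keep a reference to the full list of fitted trees
estimators = rf_model.estimators_

def predict(x, i):
    # Use only the first i trees (i >= 1), then restore the full forest
    rf_model.estimators_ = estimators[:i]
    try:
        return rf_model.predict(x)
    finally:
        rf_model.estimators_ = estimators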

And that’s it, the predict method now only looks at the first i trees ;)
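If you would rather not touch estimators_ at all, a possible alternative is to average the individual trees yourself. The sketch below assumes a single-output RandomForestRegressor, whose prediction is the mean of its trees' predictions. The helper name predict_first_trees and the synthetic data are hypothetical, introduced only for this example.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

def predict_first_trees(forest, X, n_trees):
    """Average the predictions of the first n_trees fitted trees (hypothetical helper)."""
    if not 1 <= n_trees <= len(forest.estimators_):
        raise ValueError("n_trees must be between 1 and the number of fitted trees")
    # Each element of estimators_ is a fitted DecisionTreeRegressor.
    per_tree = [tree.predict(X) for tree in forest.estimators_[:n_trees]]
    return np.mean(per_tree, axis=0)

# Synthetic data, purely for illustration.
X, y = make_regression(n_samples=200, n_features=5, noise=5.0, random_state=0)
forest = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)

partial = predict_first_trees(forest, X, 10)  # first 10 trees only
full = predict_first_trees(forest, X, 50)     # all 50 trees

# Sanity check: using every tree should reproduce the forest's own prediction.
print(np.allclose(full, forest.predict(X)))
```

This version leaves the fitted model untouched, so it cannot leave the forest in a truncated state. For a classifier you would average predict_proba over the trees instead of predict.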
