You may need to extract trees from a classifier for various reasons. In my case, I found xgboost's `ntree_limit` feature quite convenient when cross validating a gradient boosting method over the number of trees. What it does is use only the first `ntree_limit` trees to perform the prediction (instead of using all the fitted trees).

```
predict(data, output_margin=False, ntree_limit=0, pred_leaf=False,
        pred_contribs=False, approx_contribs=False,
        pred_interactions=False, validate_features=True, training=False)
```

It is also available as an extra argument of `.predict()` if you use the scikit-learn interface:

```
ypred = bst.predict(dtest, ntree_limit=bst.best_ntree_limit)
```

Indeed, this means that if you want to find the optimal number of trees for your model, you do not have to fit the model for 50 trees and predict, then fit it for 100 trees and predict again. You can fit the model once and for all with 200 trees and then, playing with `ntree_limit`, observe the performance of the model for various numbers of trees.
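As an aside, scikit-learn's own gradient boosting offers the same fit-once workflow through `staged_predict`, which yields the predictions after each boosting stage (this is a scikit-learn analogue I am showing for illustration, not the xgboost API; the data here is synthetic):

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error

# Toy data, just to illustrate the fit-once workflow.
x, y = make_regression(n_samples=200, n_features=5, random_state=0)

# Fit once with the maximum number of trees we want to consider.
gbm = GradientBoostingRegressor(n_estimators=200, random_state=0)
gbm.fit(x, y)

# staged_predict yields predictions using 1, 2, ..., 200 trees,
# so a single fit gives us the error curve over the number of trees.
errors = [mean_squared_error(y, pred) for pred in gbm.staged_predict(x)]
best_n_trees = min(range(len(errors)), key=lambda i: errors[i]) + 1
```

In practice you would compute `errors` on a held-out set rather than on the training data, but the mechanics are the same.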

The RandomForest, as implemented in scikit-learn, does not expose this parameter in its `.predict()` method. However, this is something we can quickly fix, because the RandomForest exposes its fitted trees through the `estimators_` attribute. You can modify it (beware, this is a bit hacky and may not work for other versions of scikit-learn).

```
from sklearn.ensemble import RandomForestRegressor

rf_model = RandomForestRegressor()
rf_model.fit(x, y)

# Keep a reference to the full list of fitted trees.
estimators = rf_model.estimators_

def predict(x, i):
    # Temporarily restrict the forest to its first i+1 trees.
    rf_model.estimators_ = estimators[0:i + 1]
    return rf_model.predict(x)
```

And that's it, the predict method now only looks at the first i+1 trees ;)
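To see the hack in action, here is a self-contained sketch (with synthetic data and my own variable names, so adapt it to your setup) that fits the forest once and then tracks the error as trees are added:

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

x, y = make_regression(n_samples=200, n_features=5, random_state=0)

rf_model = RandomForestRegressor(n_estimators=50, random_state=0)
rf_model.fit(x, y)

# Keep a reference to the full list of fitted trees.
estimators = rf_model.estimators_

def predict(x, i):
    # Restrict the forest to its first i+1 trees before predicting.
    rf_model.estimators_ = estimators[0:i + 1]
    return rf_model.predict(x)

# One fit, fifty evaluations: the error curve over the number of trees.
errors = [mean_squared_error(y, predict(x, i)) for i in range(50)]

# Restore the full forest once we are done with the hack.
rf_model.estimators_ = estimators
```

Since `predict` mutates `rf_model.estimators_` in place, remember to restore the full list afterwards, as done on the last line.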