Add arguments to Python decorators

Python decorators are a convenient way to wrap a function with another one. For example, when timing a function, it is nice to call the timer before and after without having to rewrite the same timer.start() and timer.stop() calls everywhere.

Among the other nice features of decorators, there are two I want to stress here:

  • they apply to classes,
  • they accept arguments.

Here, I want to implement a printer for my classes, showing exactly which class has called some of its methods and when. Ideally, the color should be specified for each class, making it easy to monitor the execution of the code when various objects call different methods.

This is something I often face when dealing with machine learning pipelines: they consist of many objects (blends of models, feature pipelines) and it is hard to know exactly who is running and when, and to identify bottlenecks at a glance.

Class decorators

Applying a decorator to a class allows adding methods to classes directly, a little bit like inheritance. This puzzled me: why would one prefer a class decorator to inheritance? See for example this question. The thing is that inheritance actually has a meaning in terms of your object: a cat is an animal, therefore inheritance makes sense. On the other hand, if you want to add generic helpers that do not really fit into your object definitions, a class decorator is preferable.

import datetime

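def class_printer(cls):

    def print_with_time(self, s):
        # Prefix the message with the calling class name and the current time
        now_str = datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
        print(f"[{type(self).__name__} - {now_str}] {s}")

    # Attach the helper to the decorated class as a regular method
    setattr(cls, 'print_with_time', print_with_time)

    return cls

@class_printer
class A:

    def __init__(self):
        pass

    def run(self):
        self.print_with_time("run method called")

if __name__ == "__main__":
    a = A()
    a.run()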

And it outputs:

[A - 03/08/2022, 11:20:44] run method called

Great! Now, every time we need this specific printer, all we have to do is decorate the class. Time to improve it!

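Passing arguments to a decorator

To pass an argument (here, a color), the decorator is wrapped in an outer function that takes this argument and returns the actual class decorator: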


import datetime

class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

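def color_class_printer(color):
    # Outer function: receives the decorator argument (an ANSI color code)
    def class_printer(cls):
        # Inner function: the actual class decorator

        def print_with_time(self, s):
            now_str = datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
            print(f"[{color}{type(self).__name__} - {now_str}{bcolors.ENDC}] {s}")

        setattr(cls, 'print_with_time', print_with_time)

        return cls
    return class_printer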
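To see the argument in action, here is a minimal, self-contained sketch that restates color_class_printer in compact form (keeping only the colors it needs) and decorates two classes with different colors. Model and FeaturePipeline are hypothetical classes standing in for pipeline components. Writing @color_class_printer(bcolors.OKBLUE) first calls the outer function with the color, and the class_printer it returns is what actually decorates the class.

```python
import datetime


class bcolors:
    # ANSI escape codes (subset of the palette above)
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    ENDC = '\033[0m'


def color_class_printer(color):
    # Outer call: receives the decorator argument and returns the real decorator
    def class_printer(cls):
        # Inner call: receives the class and attaches the method to it
        def print_with_time(self, s):
            now_str = datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
            print(f"[{color}{type(self).__name__} - {now_str}{bcolors.ENDC}] {s}")

        cls.print_with_time = print_with_time
        return cls
    return class_printer


@color_class_printer(bcolors.OKBLUE)
class Model:  # hypothetical pipeline component
    def fit(self):
        self.print_with_time("fit method called")


@color_class_printer(bcolors.OKGREEN)
class FeaturePipeline:  # hypothetical pipeline component
    def transform(self):
        self.print_with_time("transform method called")


if __name__ == "__main__":
    Model().fit()
    FeaturePipeline().transform()
```

Since each call to color_class_printer creates a new closure, every class keeps its own color.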
Unique elements in a list in OCaml

First, note that OCaml provides a function List.sort_uniq. However, it requires a comparison function (of type 'a -> 'a -> int, such as compare for integers or float numbers), which is not always obvious to implement (think about a list of arrays, for example). Besides, it returns the elements sorted rather than in their original order.

Instead, the following snippet will do the job!


let unique l =

  let rec aux l acc =
    match l with
    | [] ->
        List.rev acc
    | h :: t ->
        if List.mem h acc then aux t acc else aux t (h :: acc)
  in
  aux l []

And we can run it in the toplevel:

# let unique l =

  let rec aux l acc =
    match l with
    | [] ->
        List.rev acc
    | h :: t ->
        if List.mem h acc then aux t acc else aux t (h :: acc)
  in
  aux l [] ;;
val unique : 'a list -> 'a list = <fun>

# unique [1;2;3] ;;
- : int list = [1; 2; 3]

# unique [1;2;3;2;4;1];;
- : int list = [1; 2; 3; 4]

Why the List.rev

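The List.rev may look like a useless operation. Since each new element is prepended to the accumulator, acc holds the unique elements in reverse order; List.rev simply makes sure that they are returned in their order of first occurrence.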

To be closer to the standard terminology of the language (as in List.sort vs List.stable_sort), one could implement two functions: unique and stable_unique.

let unique l =

  (* no List.rev: the result comes out in reverse order of first occurrence *)
  let rec aux l acc =
    match l with
    | [] -> acc
    | h :: t ->
        if List.mem h acc then aux t acc else aux t (h :: acc)
  in
  aux l []

let stable_unique l =

  let rec aux l acc =
    match l with
    | [] ->
        List.rev acc
    | h :: t ->
        if List.mem h acc then aux t acc else aux t (h :: acc)
  in
  aux l []

List.mem and List.memq

Note that I used List.mem, but you should be aware of the following difference (from the docs):

val mem : 'a -> 'a list -> bool

mem a set is true if and only if a is equal to an element of set.

val memq : 'a -> 'a list -> bool

Same as List.mem, but uses physical equality instead of structural equality to compare list elements.

List intersection in OCaml

Intersection of two lists in OCaml

List intersection is a simple operation that is often needed. Unfortunately, it is not implemented directly in OCaml (at least, not in 4.12). Note that with large collections, sets are a better container for intersections. However, when you know you are working with a limited number of items and that performance should not be an issue, it is perfectly fine to intersect lists.

The following one liner will do the job:

let intersection l1 l2 = List.filter (fun e -> List.mem e l2) l1 

That’s it. List.mem checks whether an element is a member of a list (hence its name), and we filter according to this condition.

You can test it in the toplevel! I recommend doing so for these simple functions. It does not replace proper unit testing, but if you are working quickly on a prototype, nobody will blame you ;)

#  let intersection l1 l2 = List.filter (fun e -> List.mem e l2) l1  ;;
val intersection : 'a list -> 'a list -> 'a list = <fun>
# intersection [1;2;3] [4] ;;
- : int list = []
# intersection [1;2;3] [1] ;;
- : int list = [1]

Intersection of a list of lists

Let’s leverage the List.fold_left function! It makes sense to define the intersection of a single list as the list itself.

let lists_intersection lists = match lists with
  | [] -> []
  | [h] -> h
  | h::t -> List.fold_left intersection h t

And we can test it quickly as well:

# let lists_intersection lists = match lists with                     
  | [] -> []
  | [h] -> h
  | h::t -> List.fold_left intersection h t      ;;
val lists_intersection : 'a list list -> 'a list = <fun>
# lists_intersection [[1;2;3]; [1]];;   
- : int list = [1]
# lists_intersection [[1;2;3]; [1;2]; [1;2]];;
- : int list = [1; 2]

It works!

Best books on Fermat's last theorem

This article was updated on 26th February 2023 to add a summary and new books.


I haven’t published many articles lately, the main reason being that I got drawn into the history of Fermat’s last theorem. This theorem states that the equation

$$x^n + y^n = z^n$$

only has trivial integer solutions (i.e. one of the elements is 0) for $n > 2$. This problem fascinated mathematicians for more than three centuries before it was finally solved! By contrast, it is easy to see that for $n = 2$ there are non-trivial solutions: the triple $(3, 4, 5)$ does the job, since $3^2 + 4^2 = 5^2$.

Regarding the “Maths contents” column, do not be scared: all books are a good fit for motivated undergrads in mathematics. Those with one star are even more accessible.

| Book | Maths contents | Topics | Interest |
|---|---|---|---|
| Invitation to the Mathematics of Fermat-Wiles | ** | An approach to Wiles’ proof that does not leave historical approaches aside | *** |
| Fermat’s Last Theorem: A Genetic Introduction to Algebraic Number Theory | ** | Exposes the theory of algebraic integers, Kummer’s proof for regular primes and many historical details | *** |
| Fermat’s Last Theorem for Amateurs | * | Many partial historical results related to FLT, presented in great detail | *** |
| 13 Lectures on Fermat’s Last Theorem | ** | Many partial historical results related to FLT, presented without details | ** |
| Fermat a-t-il démontré son grand théorème ? L’hypothèse « Pascal » (in French) | * | A historical speculation about a possible proof by Fermat and a detailed analysis of his correspondence | ** |
| Fermat’s Last Theorem by Simon Singh | None | General introduction and history | ** |
| Le théorème de Fermat : son histoire by E. Nogues | ** | A history of research on FLT, written at the beginning of the 20th century! | * |


Top 3

We can split the history of the main approaches to Fermat’s last theorem into three epochs: the arithmetical one (Sophie Germain, Cauchy, Legendre, Pellet…), the algebraic integers epoch (Kummer, Mirimanoff…) and the modular forms (Wiles) epoch. Obviously, this leaves out many other approaches (sieves, analysis…), but it covers a large part of the historical publications.

These three books cover these epochs, and I strongly recommend all three of them for a beautiful journey into the study of FLT.

Invitation to the Mathematics of Fermat-Wiles by Yves Hellegouarch

This book aims to present the proof given by Wiles. Obviously, some details are left out, but this is an amazing introduction. The topics presented include introductory ones (special cases of the exponent), Kummer’s proof… Soon enough, the author jumps to elliptic functions, elliptic curves and modular forms, which are essential.

It contains a lot of exercises and clear proofs. If there is only one book to read on the topic (and you are not afraid of mathematical details), this is probably the best one!

Fermat’s Last Theorem: A Genetic Introduction to Algebraic Number Theory

This one is a detailed study of Kummer’s approach. Ribenboim’s books only scratch the surface of this proof, while here (almost) the whole book is dedicated to a detailed proof of Kummer’s result. Some beautiful relations between Fermat’s equation, the zeta function and Bernoulli numbers are presented in a very clear way, with various numerical examples. Besides, the author also dug into the archives of the French Academy of Sciences, yielding thrilling historical notes about sessions where members wrongly speculated about their possible proofs.

Fermat’s Last Theorem for Amateurs by Paulo Ribenboim

This book presents many special cases of proofs of FLT in great detail. Though the original problem is to find solutions in the set of integers, the author proposes to study this equation in other sets as well: Gaussian integers ($\mathbb{Z}[i]$), $p$-adic numbers (they truly are a hidden gem of number theory)… As the title states, it is for amateurs, though it is presented like a usual math book, with theorems and proofs. Note, however, that the emphasis is put on examples.

Other books

13 Lectures on Fermat’s last theorem by Paulo Ribenboim

This one is interesting: it was written in 1979, before Fermat’s last theorem was actually proven. It summarizes many of the efforts put into trying to prove FLT and features some of the proofs. The historical context is essential, as many of the theorems are presented with notes regarding their history.

A large part of the book is devoted to Kummer’s fascinating proof and its limitations. Many other results of varying importance are also presented: estimates, equivalences of FLT with other statements… One I found particularly interesting is the equivalence between:

  1. The equation has only the trivial solutions in
  2. For every non-zero, the polynomial is irreducible over

A minor regret is that many of the statements are not proved but left “as exercises”, with no hints and no clue regarding their level of difficulty. It is particularly interesting to read it along with Fermat’s Last Theorem for Amateurs, as some of the proofs missing from this book can be found in the other.

However, it remains an amazing introduction if you do not want to dig into all the details but are instead looking for a nice overview of the history, with some details.

Fermat’s Last Theorem by Simon Singh

It requires absolutely no mathematical background. It is a very nice introduction to the problem, mostly focusing on the history of (and around) Fermat’s last theorem. Some mathematical details, along with various interesting puzzles, are presented in the text and in the appendix. Really nice if you want to know what this theorem is about and why it became so important, without digging into the mathematical details.

Fermat a-t-il démontré son grand théorème ? L’hypothèse « Pascal » (in French) by Laurent Hua and Jean Rousseau

This book is split into two parts: the first one is purely historical and does not require much knowledge of mathematics. The authors present the possibility that Fermat may have proved his theorem. Though the consensus seems to be the opposite, no one has ever proved that Fermat did not prove his theorem!

The first part is a thorough analysis of Fermat’s correspondence with other mathematicians of his time, and it becomes quite moving, especially towards the end. This analysis is so detailed that you even get a glimpse of Fermat’s sense of humor in some of the letters presented!

The second part is more mathematical, but a high school student could understand most of it. It gives what could be the starting point of a proof with the tools available to Fermat in his time. Unfortunately, it does not go very far, but the approach is interesting!

Le théorème de Fermat : son histoire by E. Nogues

Written at the beginning of the 20th century, it consists of translations (in French) of all the important articles on Fermat’s last theorem up to that time. It is particularly interesting if you want to dig into the details of the history of the theorem.

Extra readings

In some of the books, you will need (on top of the usual algebra, basic number theory and analysis tools) to know more advanced topics, such as Galois theory. To that end, Galois Theory Through Exercises by Juliusz Brzeziński was my best read on this topic. It has a chapter completely dedicated to cyclotomic fields, which are objects commonly used in the books above!

Number Theory 1: Fermat’s Dream, by Kazuya Kato, Nobushige Kurokawa and Takeshi Saito: this is actually the book that drew me into the study of Fermat’s Last Theorem. It is really well written, with a lot of figures, exercises and solutions.

Not read yet

The books below are on my to-read list. If you have any opinion on them, let me know in the comments!

Three Lectures on Fermat’s Last Theorem by Louis Joel Mordell

Mordell is a name I have now come across many times! I want to read it mostly for its historical value, and to know what he had to say about this theorem about a century ago!

A Course in Arithmetic by Jean-Pierre Serre

Please note that all the links above are affiliate links. However, having read these books, I am confident about the quality of my recommendations!

Benchmark Fossil Demand Forecasting Challenge


Zindi is hosting the Fossil Demand Forecasting Challenge, where competitors have to predict the number of units sold for various products.

Note that the rules state that the metric to optimize is not the usual squared error but the absolute error:

The evaluation metric for this challenge is Mean Absolute Error.

All the models relying on least-squares minimization (usual regressions, random forests with default parameters) are likely to perform poorly, since they return the mean over subsamples, while minimizing the absolute error returns the median of the sample.

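In mathematical terms, for a constant prediction $c$ fitted on observations $y_1, \dots, y_n$:

$$\arg\min_c \sum_{i=1}^n (y_i - c)^2 = \frac{1}{n}\sum_{i=1}^n y_i, \qquad \arg\min_c \sum_{i=1}^n |y_i - c| = \operatorname{median}(y_1, \dots, y_n)$$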
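To make the mean/median argument concrete, here is a small sketch on made-up, right-skewed sales figures (not competition data). It searches a grid for the constant that minimizes each metric, which should land on the median for the absolute error and on the mean for the squared error.

```python
import numpy as np


def mae(y, c):
    """Mean absolute error of the constant prediction c."""
    return np.mean(np.abs(y - c))


def mse(y, c):
    """Mean squared error of the constant prediction c."""
    return np.mean((y - c) ** 2)


# Made-up, right-skewed "units sold": mostly small values and two large orders
y = np.array([0, 1, 1, 2, 2, 3, 4, 50, 120], dtype=float)

# Brute-force search for the best constant under each metric
grid = np.linspace(y.min(), y.max(), 10001)
best_for_mae = grid[np.argmin([mae(y, c) for c in grid])]
best_for_mse = grid[np.argmin([mse(y, c) for c in grid])]

print(f"mean={y.mean():.2f}, median={np.median(y):.2f}")
print(f"best constant for MAE ~ {best_for_mae:.2f}, for MSE ~ {best_for_mse:.2f}")
print(f"MAE(mean)={mae(y, y.mean()):.2f}, MAE(median)={mae(y, np.median(y)):.2f}")
```

The two large orders drag the mean far above most observations, which is precisely what the absolute error punishes.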

A simple benchmark

With that knowledge, the benchmark below simply returns, for each product, the median of units sold over the year 2021. The score should be around 192xxx.

import pandas as pd

train = pd.read_csv("../raw_data/Train.csv")
train["year_month"] = train["year"].astype(str) + "/" + train["month"].astype(str)
train["date"] = pd.to_datetime(train["year_month"], format="%Y/%m")
# Keep only the recent history (from January 2021 on)
train_recent = train[train["date"] >= "2021/01"]

# Median of units sold per product: the constant that minimizes the MAE
medians = train_recent.groupby("sku_name")["sellin"].median().to_dict()

test = pd.read_csv("../raw_data/Test.csv")
# Products without recent history get a prediction of 0
test["Target"] = test["sku_name"].map(medians).fillna(0).astype(int)

test["Item_ID"] = test["sku_name"] + "_" + \
    test["month"].astype(str) + "_" + test["year"].astype(str)
test[["Item_ID", "Target"]].to_csv("./submission_.csv", index=False)
Minimize L1 penalty with (univariate) linear regression


Most of the regression problems I dealt with focused on minimizing the L2 norm of the difference between the predictions of a model and the true values. However, from a business point of view, it may make more sense to directly minimize the L1 norm.

Indeed, a least-squares approach mainly penalizes large errors. But if each unit of error costs the same (once again, from the business perspective), the L1 penalty makes sense as well.

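However, note that the resulting estimators may be vastly different: with a constant model, the optimal intercepts differ, one being the mean and the other the median.

More formally, for a constant model $c$ fitted on $y_1, \dots, y_n$:

$$\arg\min_c \sum_{i=1}^n (y_i - c)^2 = \frac{1}{n}\sum_{i=1}^n y_i, \qquad \arg\min_c \sum_{i=1}^n |y_i - c| = \operatorname{median}(y_1, \dots, y_n)$$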

So far, so good. However, things become complex quite quickly: a theoretical advantage of the L2 penalty is that it is differentiable, and we enjoy many closed-form formulas for various problems relying on L2 penalties.


I will focus on implementing the univariate case, without intercept. Including an intercept or handling the multivariate case requires a (much) more complex optimization.

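Calling $f(b) = \sum_{i=1}^n |y_i - b x_i|$, we note that $f$ is a convex function whose derivative is not continuous.

Its derivative, where it is defined, is:

$$f'(b) = -\sum_{i=1}^n x_i \, \operatorname{sign}(y_i - b x_i)$$

Given that $f$ is convex, $f'$ must be monotonic (non-decreasing). Besides, writing $r_i = y_i / x_i$ (for $x_i \neq 0$), we have $f'(\min_i r_i) \le 0$ and $f'(\max_i r_i) \ge 0$.

Therefore, we will look for a zero of $f'$ using dichotomy (bisection) on $[\min_i r_i, \max_i r_i]$.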


As detailed above, once the penalty (more precisely, its derivative) and the dichotomy algorithm are implemented, there is nothing else to do:

import numpy as np

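def dichotomy(n, f, l, r):
    # Bisection: n halvings of [l, r], assuming f is non-decreasing
    # with f(l) <= 0 <= f(r)
    c = (l + r) / 2
    if n == 0:
        return c
    if f(c) < 0:
        # the zero of f lies to the right of c
        return dichotomy(n - 1, f, c, r)
    return dichotomy(n - 1, f, l, c)

def penalty(x, y, b):
    # Derivative of sum(|y - b * x|) with respect to b
    return -np.sum(x * np.sign(y - b * x))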

class L1Regressor:

    def __init__(self, n_iterations=20):
        self.n_iterations = n_iterations
        self.b = None

    def fit(self, x, y):
        ratios = y / x
        l, r = np.min(ratios), np.max(ratios)
        self.b = dichotomy(self.n_iterations, lambda b: penalty(x, y, b), l, r)
        return self
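As a sanity check for the dichotomy, note that $|y_i - b x_i| = |x_i| \, |y_i / x_i - b|$, so the optimal slope is a weighted median of the ratios $y_i / x_i$ with weights $|x_i|$. The sketch below computes it directly (assuming no $x_i$ is zero; the helper names are mine, not from any library) and checks on a grid that no other slope does better.

```python
import numpy as np


def l1_objective(x, y, b):
    """Sum of absolute residuals of the through-origin model y ~ b * x."""
    return np.sum(np.abs(y - b * x))


def weighted_median_slope(x, y):
    """Minimizer of l1_objective, assuming no x_i is zero.

    Since |y_i - b x_i| = |x_i| * |y_i / x_i - b|, the optimal slope is a
    weighted median of the ratios y_i / x_i with weights |x_i|.
    """
    ratios = y / x
    weights = np.abs(x)
    order = np.argsort(ratios)
    ratios, weights = ratios[order], weights[order]
    cumulative = np.cumsum(weights)
    # First ratio at which the cumulative weight reaches half of the total
    idx = np.searchsorted(cumulative, 0.5 * cumulative[-1])
    return ratios[idx]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.random(100)
    y = 0.5 * x + 0.1 * rng.normal(size=100)
    print(weighted_median_slope(x, y))
```

On the same data, L1Regressor().fit(x, y).b should agree with weighted_median_slope(x, y) up to the bisection tolerance, which shrinks as n_iterations grows.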


If we append the following:

if __name__ == "__main__":

    import matplotlib.pyplot as plt

    x = np.random.rand(100)
    y = x * 0.5 + 0.1 * np.random.normal(size=100)

    slope = L1Regressor().fit(x, y).b

    plt.scatter(x, y)
    plt.plot(x, x * slope, c='r')
    plt.show()

We obtain:

L1 regression

Comparison with L2

We can add some outliers to the observations: we expect the regression minimizing the L1 penalty to be more robust to them.

Below are plotted the two slopes: in red, the L1 penalty is minimized, in green, the L2 penalty.

L1 regression vs L2

    from sklearn.linear_model import LinearRegression

    x = np.random.rand(100)
    y = x * 0.5 + 0.1 * np.random.normal(size=100)

    x[:10] = 0.9 + 0.1 * np.random.rand(10)
    y[:10] = 2 + x[:10] * 0.5 + np.random.normal(size=10)

    slope = L1Regressor().fit(x, y).b
    slopel2 = LinearRegression(fit_intercept=False).fit(x.reshape(-1,1), y).coef_[0]

    plt.scatter(x, y)
    plt.plot(x, x*slope, c = 'r')
    plt.plot(x, x * slopel2, c='g')
    plt.show()
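The comparison can also be made numerically, without plotting and without scikit-learn. This sketch rebuilds the same kind of noisy data with outliers (using a seeded NumPy generator, an assumption made so the run is reproducible), uses the closed form $b = \sum_i x_i y_i / \sum_i x_i^2$ for the L2 slope and a brute-force grid search for the L1 slope.

```python
import numpy as np

rng = np.random.default_rng(42)  # seeded so that the comparison is reproducible

# Same setup as above: true slope 0.5, plus ten outliers near x = 1
x = rng.random(100)
y = 0.5 * x + 0.1 * rng.normal(size=100)
x[:10] = 0.9 + 0.1 * rng.random(10)
y[:10] = 2 + 0.5 * x[:10] + rng.normal(size=10)

# L2 without intercept has a closed form: b = sum(x * y) / sum(x ** 2)
slope_l2 = np.sum(x * y) / np.sum(x ** 2)

# L1 without intercept: brute-force minimization over a grid of candidate slopes
grid = np.linspace(-1.0, 3.0, 4001)
l1_losses = np.abs(y[None, :] - grid[:, None] * x[None, :]).sum(axis=1)
slope_l1 = grid[np.argmin(l1_losses)]

print(f"true slope: 0.5, L1 slope: {slope_l1:.3f}, L2 slope: {slope_l2:.3f}")
```

slope_l2 is the same quantity as LinearRegression(fit_intercept=False).coef_[0] above; with the outliers pulling it up, it should end up much further from 0.5 than the L1 slope.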
Random number generation in Cython


In one of my programs, I had to perform (a lot of) random sampling from Python lists. So much that it ended up being my bottleneck.

Without going into too much detail, the function was mostly generating random numbers and accessing elements in lists. I gave Cython a quick try with something along these lines:

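import random
import cython

def sample(int n_sampling, l):
    a = []
    for i in range(n_sampling):
        # assumed loop body: plain element access, no random numbers
        a.append(l[i % len(l)])
    return a

def rando(int n_sampling, l):
    a = []
    for _ in range(n_sampling):
        # one Python-level random number per iteration
        a.append(random.randrange(3))
    return a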

Cython usually needs the following setup code:

from setuptools import setup
from Cython.Build import cythonize

setup(
    ext_modules=cythonize("fast_sampler.pyx")
)

And the following command will build the code:

python setup.py build_ext --inplace

The timing results:

python -mtimeit -s"import fast_sampler" "fast_sampler.sample(10000,[1,2,3])"
5000 loops, best of 5: 61.8 usec per loop

Accessing elements in a loop seems quick enough.

python -mtimeit -s"import fast_sampler" "fast_sampler.rando(10000,[1,2,3])"
50 loops, best of 5: 5.77 msec per loop

However, the calls to random.randrange() seem to be the bottleneck.

If we add this cimport statement, we can directly call the C standard library's rand():

from libc.stdlib cimport rand

def rando_c(int n_sampling, l):

    a = []
    for _ in range(n_sampling):
        a.append(rand() % 3)
    return a

And finally:

python -mtimeit -s"import fast_sampler" "fast_sampler.rando_c(10000,[1,2,3])"
2000 loops, best of 5: 104 usec per loop

This brings a roughly 50x speedup!

What about the seed?

Usually, it is good practice to seed the random number generators (e.g. with random.seed() and np.random.seed()) in Python code dealing with random number generation, to ensure reproducibility.

By default, rand() returns the same sequence on every run as long as you do not call srand() first (it behaves as if srand(1) had been called), so you do not have to worry about the seed any more! (At least, not in this part of your code.)
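For reference, here is a minimal Python sketch of what such seeding lines do: re-seeding with the same value replays exactly the same draws. Note that these seeds only affect Python's random module and NumPy's global generator; the C rand() called from Cython keeps its own state, controlled by srand().

```python
import random

import numpy as np


def draw(seed, n=5):
    """Seed both generators, then draw n integers in {0, 1, 2} from each."""
    random.seed(seed)     # Python's built-in generator (random module)
    np.random.seed(seed)  # NumPy's legacy global generator
    from_python = [random.randrange(3) for _ in range(n)]
    from_numpy = np.random.randint(0, 3, size=n).tolist()
    return from_python, from_numpy


if __name__ == "__main__":
    # Re-seeding with the same value replays exactly the same draws
    print(draw(0))
    print(draw(0) == draw(0))
```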

Vim for datascience

There are plenty of tutorials here and there to make Python and vim interact beautifully. The point of this one is to provide some simple lines to add to your .vimrc file without having to worry too much about installing more (vim) packages. Having struggled myself to write these lines, I will explain the meaning of each of them.

If you have more tricks for your common datascience tasks in vim, let me know, I will add them!



Here are the things you can do with the following settings:

  • Map the common imports to keypresses,
  • Preview the contents of a csv file in a vim pane,
  • Format JSON files with jq or python files with autopep8,
  • Quickly add print() statements,
  • Fix the usual copy paste into vim (bonus).

If you are familiar with vim, you know that you can do pretty much everything with a sequence of keypresses. Recording these keypresses and mapping them to another key just factors everything you want to do ;)


Python packages: pandas, autopep8, numpy. Other packages: jq.

Data preview

The function in action

I start with the hardest but most satisfying command:

autocmd FileType python map <C-F9> va"y:!python -c "import pandas as pd; df=pd.read_csv('<C-R>"', nrows=5); print(df)" > tmp.tmp<CR>:sv tmp.tmp<CR>:resize 8<CR>

It shows, in a new vim pane, the first five lines of the .csv file whose name is between the quotes surrounding the cursor.


autocmd FileType python is just saying that the mapping which follows will only apply to python files. This avoids accidental application to other languages.

map <C-F9> means map Ctrl + F9 to the following sequence of keypresses

va"y is a way to tell vim:

  • select v

  • around a

  • quotes "

  • copy y (to register)

:! executes a shell command from vim (here, a Python one-liner)

python -c "import pandas as pd; df=pd.read_csv('<C-R>"', nrows=5); print(df)" runs a one-line Python script. The only trick here is <C-R>", which, on the vim command line, inserts the contents of the unnamed register, i.e. what we copied when “pressing” va"y.

> tmp.tmp<CR>:sv tmp.tmp<CR>:resize 8<CR> redirects the output of the Python print statement to a temporary file (tmp.tmp), which is then opened in a horizontal split by vim (with :sv) and resized to 8 lines.

Beautifying files


This one needs autopep8 installed.

autocmd FileType python map <F4> :!autopep8 --in-place --aggressive %<CR>

It formats your Python scripts in place with autopep8, following the PEP 8 guidelines.


This one needs jq installed. Otherwise, the filter will replace everything in the file you are editing with the shell's error message (u undoes it). jq is a tool to manipulate JSON files easily, and I strongly recommend using it.

autocmd FileType json map <F4> :%! jq .<CR>

Upon pressing <F4>, it will indent your file beautifully.



If I want to quickly execute the script I am working on, these two lines let me do it (whether I am in normal or insert mode):

autocmd FileType python map <F5> :wall!<CR>:!python %<CR>
autocmd FileType python imap <F5> <Esc>:wall!<CR>:!python %<CR>

It is ideal when you are used to testing your classes like this:

from collections import defaultdict
from itertools import tee

class MarkovLikelihood:

    def __init__(self, alpha):
        self.alpha_ = alpha
        self.transition_counts_ = defaultdict(lambda: 0)
        # starts at 1 to avoid dividing by zero for unseen words
        self.start_counts = defaultdict(lambda: 1)

    def fit(self, sentences):
        for sentence in sentences:
            self.update_(sentence)
        return self

    def update_(self, sentence):
        words = sentence.split(' ')
        for w1, w2 in self.pairwise_(words):
            self.transition_counts_[f"{w1}_{w2}"] += 1
            self.start_counts[w1] += 1

    def pairwise_(self, iterable):
        # consecutive (overlapping) pairs: (w1, w2), (w2, w3), ...
        a, b = tee(iterable)
        next(b, None)
        return zip(a, b)

    def predict(self, sentence):
        res = 1
        words = sentence.split(' ')
        for w1, w2 in self.pairwise_(words):
            res *= (self.transition_counts_[f"{w1}_{w2}"] + self.alpha_) / self.start_counts[w1]

        return res

if __name__ == "__main__":

    ml = MarkovLikelihood(0.5)
    sentences = [
        "I ate dinner.",
        "We had a three-course meal.",
        "Brad came to dinner with us.",
        "He loves fish tacos.",
        "In the end, we all felt like we ate too much.",
        "We all agreed; it was a magnificent evening."]

    ml.fit(sentences)
    print(ml.predict("dinner with tacos"))
    print(ml.predict("I love tennis"))


The following two lines let you insert the most common imports with a couple of keypresses:

autocmd FileType python map <C-F10> ggiimport pandas as pd<CR>import numpy as np<CR>np.random.seed(0)<CR><Esc>
autocmd FileType python map <C-F11> ggiimport matplotlib.pyplot as plt<CR>import seaborn as sns<CR><Esc>

Pressing Ctrl + F10, then Ctrl + F11, adds the following to the Python file you are working on. Note that gg makes sure to place the cursor at the top of the file first.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(0)

Quick print

autocmd FileType python map <C-p> viwyoprint(<Esc>pA)<Esc>

If your cursor is on a word (my_variable), it simply adds a print(my_variable) statement below the current line. Useful for debugging.

Fixing copy and paste

" Automatically set paste mode in Vim when pasting in insert mode
let &t_SI .= "\<Esc>[?2004h"
let &t_EI .= "\<Esc>[?2004l"

inoremap <special> <expr> <Esc>[200~ XTermPasteBegin()

function! XTermPasteBegin()
  set pastetoggle=<Esc>[201~
  set paste
  return ""
endfunction
Does gradient boosting overfit

What is overfitting?

Overfitting is, roughly, what happens when you “train your model too much”. In that case, you achieve a very good training accuracy, while the test accuracy is usually poor. Think about a one-dimensional regression: you may either fit a straight line, or draw a line that passes through every point. Which one will generalize best?

For some models, specific parameters control the balance between “passing through every point” and “drawing a straight line”. For gradient boosting, this will be the number of trees; for neural networks, the number of iterations of the gradient descent; for support vector machines, a combination of parameters…

A very common figure regarding overfitting is the following:


It applies mostly to neural networks, where the abscissa represents the number of epochs, the blue line the training loss and the red line the validation loss.

The same question applies to gradient boosting, where the number of trees is quite critical and could replace the abscissa on the graph above; see for example this question. Some people claim that it should not overfit (as random forests do not overfit).


A cool python library

While looking for benchmark data, I found pmlb, which stands for Penn Machine Learning Benchmark. Basically, it lets you download various real-world datasets.

More info can be found in the project page.


I used xgboost and simply increased the number of trees.

def benchmark():
    for n_trees in [10, 20, 50, 100, 200, 1000, 1500, 2000, 3000, 4000, 10000, 30000]:
        for max_depth in [5]:
            for learning_rate in [0.001]:
                yield {
                    "name": "XGBClassifier",
                    "parameters": {
                        "n_estimators": n_trees,
                        "max_depth": max_depth,
                        "learning_rate": learning_rate,
                    },
                }

On the following datasets:

['heart_statlog', 'hepatitis', 'horse_colic', 'house_votes_84', 'hungarian', 'hypothyroid', 'ionosphere', 'iris', 'irish', 'kr_vs_kp', 'krkopt', 'labor', 'led24', 'led7', 'letter', 'lupus', 'lymphography', 'magic']

The performance of each model was then evaluated using 5-fold cross-validation for the following metrics: ["roc_auc", "accuracy", "neg_log_loss"].

Does xgboost overfit?

The graphs below suggest that increasing the number of trees may harm the performance of the model. However, in some cases, even a very large number of trees is beneficial to the model.


Figures (xgboost): Accuracy, Log Loss, ROC AUC
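Refitting a model for every number of trees, as in the benchmark, is the straightforward approach. A cheaper way to trace a validation curve is to fit once and evaluate the predictions after each added tree. The sketch below swaps in scikit-learn's GradientBoostingClassifier and its staged_predict_proba method instead of xgboost, on synthetic data rather than pmlb datasets; the sizes and hyperparameters are arbitrary choices.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

# Synthetic binary problem with label noise (flip_y), where overfitting can show up
X, y = make_classification(n_samples=1000, n_features=20, flip_y=0.1, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.3, random_state=0)

n_trees = 300
model = GradientBoostingClassifier(
    n_estimators=n_trees, max_depth=5, learning_rate=0.1, random_state=0
)
model.fit(X_train, y_train)

# staged_predict_proba yields the predictions after 1, 2, ..., n_trees trees,
# so a single fit gives the whole validation curve
valid_losses = np.array(
    [log_loss(y_valid, proba) for proba in model.staged_predict_proba(X_valid)]
)

best_n_trees = int(np.argmin(valid_losses)) + 1
print(f"lowest validation log loss after {best_n_trees} trees: {valid_losses.min():.3f}")
print(f"validation log loss after {n_trees} trees: {valid_losses[-1]:.3f}")
```

Whether the curve goes back up after its minimum depends on the data and on the learning rate, which is exactly what the benchmark explores.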

Comparison with random forest

However, the comparison with random forests is needed to understand what was at stake above:

Figures (random forests): Accuracy, Log Loss, ROC AUC


from pmlb import fetch_data
import pandas as pd
import numpy as np
from sklearn.model_selection import cross_validate
from xgboost import XGBClassifier


def model_builder(model_parameters):
    # Minimal stand-in for the author's model_builder helper (not shown here):
    # assumes every configuration describes an XGBClassifier
    return XGBClassifier(**model_parameters["parameters"])


def run(model_parameters, dataset_key, metrics):
    model = model_builder(model_parameters)
    X, y = fetch_data(dataset_key, return_X_y=True)
    if len(np.unique(y)) != 2:
        print("Problem is not binary")
        return None

    # 5-fold cross-validation (the default of cross_validate)
    res = cross_validate(model, X, y, scoring=metrics)

    row = {
        "fit_time": np.mean(res["fit_time"]),
        "n": X.shape[0],
        "p": X.shape[1],
    }
    for metric in metrics:
        row[metric] = np.mean(res[f"test_{metric}"])
        row[f"{metric}_std"] = np.std(res[f"test_{metric}"])
    return row


def benchmark():
    for n_trees in [10, 20, 50, 100, 200, 1000, 1500, 2000, 3000, 4000, 10000, 30000]:
        for max_depth in [5]:
            for learning_rate in [0.001]:
                yield {
                    "name": "XGBClassifier",
                    "parameters": {
                        "n_estimators": n_trees,
                        "max_depth": max_depth,
                        "learning_rate": learning_rate,
                    },
                }


dataset_keys = ['heart_statlog', 'hepatitis', 'horse_colic', 'house_votes_84', 'hungarian', 'hypothyroid', 'ionosphere', 'iris', 'irish', 'kr_vs_kp', 'krkopt', 'labor', 'led24', 'led7', 'letter', 'lupus', 'lymphography', 'magic']
rows = []

for dataset_key in dataset_keys:
    for model in benchmark():
        row = run(model, dataset_key, ["roc_auc", "accuracy", "neg_log_loss"])
        if row is not None:
            row["model"] = model["name"]
            row["data"] = dataset_key
            row = {**row, **model["parameters"]}
            rows.append(row)

output_data = pd.DataFrame(rows)
output_data.to_csv("./xgb_overfit.csv", index=False)

The graphs can then be produced easily using the nice .plot() methods provided by pandas and matplotlib.

import pandas as pd
import matplotlib.pyplot as plt

benchmark_data = pd.read_csv("xgb_overfit.csv")
benchmark_data.set_index("n_estimators", inplace=True)

for metric in ["roc_auc", "accuracy", "neg_log_loss"]:
    plt.figure()  # one figure per metric
    benchmark_data.groupby("data")[metric].plot(legend=True, logx=True, title=metric)
    plt.legend(loc="lower left")

plt.show()


  • xgboost: one of the best gradient boosting libraries available.

  • pmlb: a Python library providing various benchmark datasets.

Keras memory leak

Keras memory usage keeps increasing

I was having fun, attempting to do some deep learning with a 2M-line dataset (nothing my computer can’t handle: xgboost was running with roughly 15% of my RAM), when suddenly, as I was adding neural networks to my fancy stacked models, the script kept failing, the memory usage went to the moon, etc., etc.

What did I do wrong? Did I introduce a memory leak in my model stacking / neural network factory code? I would be surprised: it worked fine with every other model. And a neural network is more or less a simple vector of floats (in my case, with only hundreds of parameters), so there is no reason for it to be that big.

The only thing I was attempting to do was to cross validate different neural networks, with different architectures.

After a quick search, I found this Stack Overflow question, and also some people mentioning a weird behavior coming from model.predict(). Another GitHub issue is simply called Memory leak. There is also an article titled Dealing with memory leak issue in Keras model training, and the problem is even mentioned on Twitter.

I ended up suspecting that there are actually many memory leaks, coming from different methods of the library. So I gathered the list of workarounds I could find.


Beware: none of them actually fixes the leak. Some just alleviate the pain, but most likely the memory usage will keep increasing. Anyway, the good news is that, by combining many of the tricks I read about, I managed to get my models to run ;)

Garbage collecting

Generally, when you see these lines in the code, it means that the person who wrote it was desperate to make it run, closely monitoring the memory usage of the script and combining tricks to make sure everything fit into memory. Usually, performing tasks in dedicated functions and trusting the garbage collector to do its job at the right time is enough. But sometimes you meet these random del / garbage collector invocations.

import gc

del model     # drop the reference to the Keras model
gc.collect()  # explicitly run the garbage collector

I did put these lines after every I found. They did not help at all in my case.
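Why would gc.collect() ever matter after del? In CPython, del only removes a name: an object is freed as soon as its reference count drops to zero, but objects that reference each other (a cycle) are only reclaimed when the cyclic garbage collector runs. The sketch below uses a hypothetical Node class, unrelated to Keras, to show the difference. Keep in mind that gc.collect() can only reclaim unreachable Python objects; memory that stays referenced somewhere inside a library is out of its reach.

```python
import gc
import weakref


class Node:
    """Hypothetical object that references itself, creating a cycle."""

    def __init__(self):
        self.payload = bytearray(10_000_000)  # ~10 MB, to make the point concrete
        self.me = self  # reference cycle: the refcount never drops to zero on its own


gc.disable()  # keep the automatic collector out of the way for this demo
node = Node()
alive = weakref.ref(node)  # observe the object without keeping it alive

del node
print(alive() is not None)  # True: the cycle still keeps the object in memory

gc.collect()
print(alive() is None)      # True: the cyclic collector reclaimed it
gc.enable()
```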

Force eager evaluation

This one kind of worked for me. It slows down the training (3 times slower in my case), and the memory still keeps increasing for no reason, but much more slowly. Just add the following argument to the model.compile() call:

model.compile([...], run_eagerly=True)

model(x) instead of model.predict(x)

Some people mentioned it. It did not change a thing for me, but I kept it that way. Be careful though: model(x) returns a TensorFlow tensor, while model.predict(x) returns a NumPy array.
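If you switch from model.predict(x) to model(x), a small helper keeps downstream code working with NumPy arrays. This is a sketch: it relies on TensorFlow eager tensors exposing a .numpy() method and falls back to np.asarray for anything else. The FakeTensor class is a hypothetical stand-in so the snippet runs without TensorFlow.

```python
import numpy as np


def to_numpy(output):
    """Return a NumPy array whether output is an eager tensor or array-like."""
    if hasattr(output, "numpy"):
        return output.numpy()  # e.g. a TensorFlow EagerTensor
    return np.asarray(output)  # already an array, or a nested list of numbers


class FakeTensor:
    """Stand-in for an eager tensor, only so this example runs without TensorFlow."""

    def __init__(self, values):
        self._values = np.asarray(values, dtype=np.float32)

    def numpy(self):
        return self._values


print(type(to_numpy(FakeTensor([[0.2], [0.9]]))))  # <class 'numpy.ndarray'>
print(to_numpy([[0.2], [0.9]]).shape)              # (2, 1)
```

With a real model, the call would look like to_numpy(model(X_tf)).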

Run it in a dedicated script

Yes, it is kind of ugly, and it does not solve the issue. But if you run each cross validation in its own Python script, called from the terminal, you can pass the parameters as JSON and hope that no single script hits your memory limit.

In my case, I wrote the following class:

from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
import tensorflow as tf
import gc
import numpy as np

class NNModel:

    def __init__(self, architecture, epochs, loss="binary_crossentropy", optimizer="adam"):
        self._epochs = epochs
        self._loss = loss
        self._optimizer = optimizer
        self._architecture = architecture
        self._model = None

    def fit(self, X, y):

        self._model = self._model_factory(X.shape[1])

        X_tf = tf.convert_to_tensor(X, dtype=tf.float32)
        y_tf = tf.convert_to_tensor(y, dtype=tf.float32)
        self._model.fit(X_tf, y_tf, epochs=self._epochs)
        return self

    def _model_factory(self, input_dim):

        model = Sequential()

        architecture = self._architecture.copy()
        first_layer = architecture.pop(0)

        model.add(Dense(first_layer[0], input_dim=input_dim, activation=first_layer[1]))
        for layer in architecture:
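            model.add(Dense(layer[0], activation=layer[1]))

        # compile with the stored loss and optimizer; run_eagerly=True is slower
        # but made the memory grow much more slowly
        model.compile(loss=self._loss, optimizer=self._optimizer, run_eagerly=True)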
        return model

    def predict(self, X):
        raise NotImplementedError

    def predict_proba(self, X):
        X_tf = tf.convert_to_tensor(X, dtype=tf.float32)
        res = self._model(X_tf).numpy()  # model(x) returns a tensor: convert it to NumPy
        res = np.hstack((1 - res, res))  # columns: [P(class 0), P(class 1)]
        return res
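The last lines of predict_proba turn the single sigmoid output, the probability of the positive class, into the two-column layout that scikit-learn classifiers return from predict_proba (one column per class), which is typically what stacking code expects. A minimal NumPy illustration, with made-up probabilities:

```python
import numpy as np

# Hypothetical sigmoid outputs, shape (n_samples, 1): P(y = 1) for each row
positive = np.array([[0.1], [0.75], [0.5]])

# Same stacking as in predict_proba: [P(y = 0), P(y = 1)] per row
proba = np.hstack((1 - positive, positive))

print(proba.shape)        # (3, 2)
print(proba.sum(axis=1))  # each row sums to 1 (up to floating-point rounding)
```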

The class can then be configured with a JSON file containing the arguments of its constructor:

{
  "epochs": 8,
  "architecture": [[ 12, "relu" ], [ 8, "relu" ], [ 1, "sigmoid" ]]
}

And then I invoke them with:

find ../models/ -name \*.json | xargs --max-args=1 python

This way, I can run my different models while being sure that the memory is completely released between the executions of two scripts.
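The same isolation can also be driven from Python instead of find | xargs, which makes it easy to record which configurations failed. A minimal sketch, assuming a hypothetical training script that takes the path of one JSON config as its only argument: each subprocess.run call starts a fresh interpreter, so whatever TensorFlow allocated is returned to the operating system when that process exits.

```python
import subprocess
import sys
from pathlib import Path


def run_each_config(script, config_dir):
    """Run `script` once per JSON file in config_dir, each in its own process.

    Returns a dict mapping config file name -> exit code of the child process.
    """
    results = {}
    for config in sorted(Path(config_dir).glob("*.json")):
        # sys.executable: same interpreter (and virtualenv) as this launcher
        completed = subprocess.run([sys.executable, str(script), str(config)])
        results[config.name] = completed.returncode
    return results


# Example (train_one.py is a hypothetical script that loads one config with
# json.load() and builds NNModel(**config) from it):
# print(run_each_config("train_one.py", "../models/"))
```

A non-zero exit code flags the configurations that crashed, for instance because they ran out of memory, without stopping the others.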


Quoting MProx from a GitHub issue:

I have managed to get around this error by using model.predict_on_batch() instead of model.predict(). This returns an object of type <class ‘tensorflow.python.framework.ops.EagerTensor’> - not a numpy array as claimed in the docs - but it can be cast by calling np.array(model.predict_on_batch(input_data)) to get the output I want.

Side note: I also noticed a similar memory leak problem with calling in a loop, albeit with a slower memory accumulation, but this can be fixed in a similar way using model.train_on_batch().

I did not try this one, as segregating different models in different scripts and setting run_eagerly did the job.
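For anyone who wants to try it, here is an untested sketch of how the quoted workaround could be applied to a large dataset: predict slice by slice and convert each result with np.asarray. The predict_in_batches helper and its batch_size default are hypothetical; with a real model, predict_fn would be model.predict_on_batch. A plain NumPy function stands in for the model so the sketch runs without TensorFlow.

```python
import numpy as np


def predict_in_batches(predict_fn, X, batch_size=1024):
    """Apply predict_fn to consecutive slices of X and stack the results.

    predict_fn would be model.predict_on_batch in the quoted workaround;
    np.asarray converts whatever it returns (e.g. an EagerTensor) to NumPy.
    Assumes X is non-empty (np.vstack fails on an empty list).
    """
    outputs = []
    for start in range(0, len(X), batch_size):
        batch = X[start:start + batch_size]
        outputs.append(np.asarray(predict_fn(batch)))
    return np.vstack(outputs)


def fake_predict(batch):
    """Stand-in for model.predict_on_batch: one output column per row."""
    return batch.sum(axis=1, keepdims=True)


X = np.arange(10, dtype=np.float32).reshape(5, 2)
print(predict_in_batches(fake_predict, X, batch_size=2).shape)  # (5, 1)
```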

Use tf-nightly

tf-nightly is built more or less every day, with the latest features and fewer tests. Many people claimed that the leak disappeared when using it. But there are many versions, potentially with other bugs.

Reinstall the 1.14 version

This bug has been around for a while: some tickets mention it as early as October 2019, and it is still present in version 2.4.


I look forward to this issue being solved.