Implementing k Nearest Neighbours in OCaml



Here and there, people wonder whether functional programming can be used for machine learning. I decided to give it a try with a few learning algorithms and noticed that there are actually various external libraries available (unfortunately, none getting close to scikit-learn’s maturity).

A quick reminder about k-nearest neighbours: the method was first described in the early 1950s and is often referred to as a “lazy learner”, as it merely stores the training data and waits to be given points to classify.

Then, in a metric space \((X,d)\), given a point \(y \in X\) and \(n\) labelled points \(S = ((x_i,l_i))_{1 \le i \le n} \in (X \times \{-1,1\})^n\), it returns the most common label among the \(k\) points of \(S\) closest to \(y\). Simple, isn’t it? Actually, nothing more is needed to implement it. So let’s get started.
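Written out, with \(N_k(y)\) the set of indices of the \(k\) points of \(S\) closest to \(y\) (ties in distance broken arbitrarily), the rule is a majority vote. When the labels are encoded as \(-1\) and \(1\), as in the code below, the vote reduces to the sign of a sum, with a zero sum sent to \(-1\) as predict does:

\[
N_k(y) \subseteq \{1,\dots,n\}, \quad |N_k(y)| = k, \quad d(x_i,y) \le d(x_j,y) \ \text{ for all } i \in N_k(y),\ j \notin N_k(y)
\]
\[
\hat{l}(y) = \arg\max_{c} \bigl|\{\, i \in N_k(y) : l_i = c \,\}\bigr|, \qquad
\hat{l}(y) = \begin{cases} 1 & \text{if } \sum_{i \in N_k(y)} l_i > 0, \\ -1 & \text{otherwise,} \end{cases} \ \text{ when } l_i \in \{-1,1\}.
\]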

The expressivity of OCaml

For fun, let’s see how easy it is to implement k-nearest neighbours in OCaml. All we need is to retrieve, from an array of points, the k points closest to a given point; the function find_nearest_neighbours does this. Note how generic it is: the points can have any type (float array, string list…) as long as the distance function operates on that type, and OCaml infers this polymorphic type on its own. Think of all the templates you would have to write in other languages. And the compiler tells me at compile time if types are incompatible, whereas Python would only fail at runtime, when the faulty line is executed.
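To illustrate the genericity claim, here is a self-contained sketch of the same neighbour search, written with explicit polymorphic annotations and applied to two kinds of points: strings and float arrays. The names k_nearest, hamming and squared_dist are hypothetical, introduced for this example only. For brevity, the sketch sorts all indices by distance with Array.stable_sort instead of running the k bubble-sort passes of get_smallest_elements_i.

```ocaml
(* Hypothetical polymorphic neighbour search: 'a can be any type, as long
   as [distance] accepts two values of that type. Returns the indices of
   the k closest points, closest first. *)
let k_nearest (query : 'a) (points : 'a array) (k : int)
    (distance : 'a -> 'a -> float) : int array =
  let dists = Array.map (fun p -> distance p query) points in
  let indices = Array.init (Array.length points) (fun i -> i) in
  (* Sort indices by increasing distance to the query *)
  Array.stable_sort (fun i j -> compare dists.(i) dists.(j)) indices;
  Array.sub indices 0 k

(* Hypothetical Hamming-style distance on strings: number of differing
   positions, plus the length difference for strings of unequal length *)
let hamming (a : string) (b : string) : float =
  let la = String.length a and lb = String.length b in
  let diff = ref (abs (la - lb)) in
  for i = 0 to min la lb - 1 do
    if a.[i] <> b.[i] then incr diff
  done;
  float_of_int !diff

(* Squared Euclidean distance on float arrays *)
let squared_dist (a : float array) (b : float array) : float =
  Array.fold_left ( +. ) 0.
    (Array.mapi (fun i ai -> (ai -. b.(i)) *. (ai -. b.(i))) a)

(* Points are strings: "cat" and "cut" are one letter away from "cot" *)
let words = [| "cat"; "car"; "dog"; "cut"; "cog" |]
let nearest = k_nearest "cot" words 2 hamming

(* Points are float arrays: same function, different distance *)
let nearest_pt =
  k_nearest [| 0.; 0. |] [| [| 1.; 1. |]; [| 0.1; 0. |]; [| 5.; 5. |] |] 1 squared_dist
```

Only the distance function changes between the two calls: the type checker instantiates 'a as string in the first and as float array in the second, and would reject a call that mixed the two.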

(* Returns the indices of the k smallest elements of an array:
   after pass i, positions 0..i hold the i+1 smallest elements, in order *)
let get_smallest_elements_i input_array k = 
  let n = Array.length input_array in
  let indices = Array.init n (fun x -> x) in
  for i = 0 to (k-1) do
    for j = (n-1) downto (i+1) do
      if input_array.(indices.(j-1)) > input_array.(indices.(j)) then begin
        let b = indices.(j-1) in
        indices.(j-1) <- indices.(j);
        indices.(j) <- b
      end
    done
  done;
  Array.sub indices 0 k

(* Returns the indices of the k closest points to current_point in all_points *)
let find_nearest_neighbours current_point all_points k distance = 
  let distances = Array.map (fun x -> distance x current_point) all_points in
  get_smallest_elements_i distances k

(* Returns the most common label among the neighbours: labels are 1. or -1.,
   so the sign of their sum gives the majority (ties go to -1.) *)
let predict nearest_neighbours labels = 
  let sum a b = a +. b in
  let k = Array.length nearest_neighbours in
  if Array.fold_left sum 0. (Array.init k (fun i -> labels.(nearest_neighbours.(i)))) > 0. then 1. else ~-.1.
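predict relies on the labels being 1. or -1., so that the sign of their sum is the majority. For labels of any type (several classes, or string labels, for instance), a minimal sketch of an explicit vote could look like this. majority_vote is a hypothetical helper, not part of the original code, and it assumes at least one neighbour (k ≥ 1).

```ocaml
(* Hypothetical helper: returns the most frequent label among the
   neighbours. [neighbours] holds indices into [labels]; labels can be of
   any type. Among tied labels, the one appearing first in [neighbours]
   wins. Assumes [neighbours] is non-empty. *)
let majority_vote (neighbours : int array) (labels : 'a array) : 'a =
  let counts = Hashtbl.create 8 in
  (* Count the votes received by each label *)
  Array.iter
    (fun i ->
      let l = labels.(i) in
      let c = try Hashtbl.find counts l with Not_found -> 0 in
      Hashtbl.replace counts l (c + 1))
    neighbours;
  (* Scan the neighbours in order, keeping the first label with the
     highest count (strict comparison, so earlier labels win ties) *)
  let best = ref labels.(neighbours.(0)) in
  Array.iter
    (fun i ->
      let l = labels.(i) in
      if Hashtbl.find counts l > Hashtbl.find counts !best then best := l)
    neighbours;
  !best
```

Since get_smallest_elements_i returns the neighbour indices sorted by increasing distance, this tie-breaking rule favours the tied label whose representative is closest to the query, rather than always favouring -1. as predict does.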

Now we need a dataset to try the algorithm on. Nothing fancy here.

(* Toy data: points are drawn uniformly in [0, max_length) x [0, max_length) *)
let max_length = 1.

(* Chessboard pattern: the label flips every 0.25 along each axis *)
let chessboard_boundary x y = if ((mod_float x 0.5) -. 0.25) *. ((mod_float y 0.5) -. 0.25) > 0. then 1. else ~-.1.

(* Label -1 inside the circle x^2 + y^2 <= 0.5, 1 outside *)
let circle_boundary x y = if (x**2. +. y**2.) > 0.5 then 1. else ~-.1.

(* Placeholder label for test points, overwritten by the predictions *)
let unlabelled_boundary _x _y = 2. ;;

(* Given a decision boundary, returns a data set and the associated labels *)
let make_data n_points decision_boundary =
  let output_data = Array.init n_points (fun _ -> Array.make 2 0.) in
  let output_label = Array.make n_points 0. in
  for i = 0 to (n_points-1) do
    output_data.(i).(0) <- Random.float max_length;
    output_data.(i).(1) <- Random.float max_length;
    output_label.(i) <- decision_boundary output_data.(i).(0) output_data.(i).(1)
  done;
  output_data, output_label
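As a sanity check on the circle dataset: points are uniform in the unit square and labelled \(-1\) when \(x^2 + y^2 \le 0.5\), i.e. inside a quarter disc of radius \(\sqrt{0.5} \approx 0.71\), which fits entirely in the square. The expected share of \(-1\) labels is therefore the area of that quarter disc:

\[
P(l = -1) = \frac{1}{4}\,\pi\,\bigl(\sqrt{0.5}\bigr)^2 = \frac{\pi}{8} \approx 0.39
\]

So a trivial classifier always answering \(1\) would have an error rate of roughly 39%, a useful baseline when reading the error rates obtained below.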

Now that points are defined as arrays of floats, we need to implement distances on them.

let sum a b = a +. b

(* Squared Euclidean distance: the square root is omitted since it does not change the ranking of neighbours *)
let euclide_distance x y =
  let squares_diff = Array.init (Array.length x) (fun i -> (x.(i) -. y.(i))**2.) in
  Array.fold_left sum 0. squares_diff

(* Manhattan distance: sum of absolute coordinate differences *)
let manhattan_distance x y =
  let abs_diff = Array.init (Array.length x) (fun i -> abs_float (x.(i) -. y.(i))) in
  Array.fold_left sum 0. abs_diff
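A quick worked example helps check the two distances. The self-contained sketch below restates them in a slightly more compact form (same results, using d *. d instead of ** 2.) and evaluates them on the points (0, 0) and (3, 4): euclide_distance returns 25, the squared distance, rather than 5. Since the square root is increasing, this does not change which points are the nearest neighbours.

```ocaml
(* Squared Euclidean distance, restated compactly *)
let euclide_distance x y =
  Array.fold_left ( +. ) 0.
    (Array.mapi (fun i xi -> let d = xi -. y.(i) in d *. d) x)

(* Manhattan distance: sum of absolute coordinate differences *)
let manhattan_distance x y =
  Array.fold_left ( +. ) 0.
    (Array.mapi (fun i xi -> abs_float (xi -. y.(i))) x)

let () =
  let a = [| 0.; 0. |] and b = [| 3.; 4. |] in
  (* prints: squared euclidean: 25, euclidean: 5, manhattan: 7 *)
  Printf.printf "squared euclidean: %g, euclidean: %g, manhattan: %g\n"
    (euclide_distance a b) (sqrt (euclide_distance a b)) (manhattan_distance a b)
```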

Gluing all the pieces together:

open Knn
open Distances
open ToyDataset

(* Number of points in the training set *)
let n_points = int_of_string Sys.argv.(1) ;;

(* Parameter k of the kNN algorithm *)
let k = int_of_string Sys.argv.(2) ;;

(* Number of points in the test set *)
let n_test_points = 50 ;;

(* Train and test data *)
let train_data, labels = make_data n_points circle_boundary ;;
let test_data, pseudo_labels = make_data n_test_points unlabelled_boundary ;;

(* For each point in the test set, stores the indices of the nearest neighbours *)
let nearest_neighbours = Array.map (fun x -> find_nearest_neighbours x train_data k euclide_distance) test_data ;;

(* Evaluates and prints the error rate of the model *)
let mismatches = ref 0. ;;

for l = 0 to (n_test_points-1) do
  pseudo_labels.(l) <- predict nearest_neighbours.(l) labels;
  if pseudo_labels.(l) <> circle_boundary test_data.(l).(0) test_data.(l).(1) then (mismatches := !mismatches +. 1.)
done ;;

print_string ("Error rate : " ^ string_of_float (100. *. !mismatches /. float_of_int n_test_points) ^ "%\n") ;;

Now I recommend using ocamlbuild: it will save you loads of time, especially on large projects. Assuming the last part is saved as main.ml (ocamlbuild builds main.byte from it), simply enter this in the terminal:

me$ ls

me$ ocamlbuild main.byte
Finished, 9 targets (1 cached) in 00:00:00.

Now you just have to run the produced bytecode file, with the number of training points to generate as the first argument and the parameter \(k\) as the second.

me$ ./main.byte 100 5
Error rate : 4.%
me$ ./main.byte 1000 5
Error rate : 2.%
me$ ./main.byte 3000 5
Error rate : 0.%

What about performance?

I’ll leave this for another post: PyPy vs OCaml for streaming learning, coming soon :)

More about kNN

If you are interested in this method and further developments, you may find the following articles interesting:

[1] S. Cost and S. Salzberg, “A weighted nearest neighbor algorithm for learning with symbolic features,” Machine Learning, vol. 10, no. 1, pp. 57–78, Jan. 1993.

[2] J. Wang, P. Neskovic, and L. N. Cooper, “Improving nearest neighbor rule with a simple adaptive distance measure,” Pattern Recognition Letters, vol. 28, no. 2, pp. 207–213, Jan. 2007.

[3] K. Yu, L. Ji, and X. Zhang, “Kernel Nearest-Neighbor Algorithm,” Neural Processing Letters, vol. 15, no. 2, pp. 147–156, Apr. 2002.
