
Self-Organizing Maps on the MNIST dataset

This notebook applies a standard Kohonen SOM (KSOM) and a Dynamic SOM (DSOM) to the MNIST dataset. The MNIST training set consists of 60,000 grayscale images (28x28 pixels) of handwritten digits from 0 to 9. The goal is to project those images onto a small 2D map where nearby nodes correspond to similar images.

%reset -f

import jax
import numpy as np
from matplotlib import pyplot as plt
from array2image import array_to_image

import somap as smp

Data

Load the MNIST dataset and show the first elements

data = smp.datasets.MNIST().data

print(f"Shape of dataset: {data.shape}")
print("Some data samples:")

array_to_image(data[:20])
Shape of dataset: (60000, 28, 28)
Some data samples:


Model

Initialize the SOM (Kohonen or DSOM):

# SOM generic parameters:
shape = (12, 12)
topography = "square"
borderless = False
input_shape = (28, 28)

som_type = "ksom"

if som_type == "ksom":  # Kohonen
    params = smp.KsomParams(
        t_f=60000, sigma_i=0.7, sigma_f=0.01, alpha_i=0.1, alpha_f=0.001
    )
    model = smp.Ksom(shape, topography, borderless, input_shape, params)

elif som_type == "dsom":  # Dynamic SOM
    params = smp.DsomParams(alpha=0.001, plasticity=0.02)
    model = smp.Dsom(shape, topography, borderless, input_shape, params)

print("Visualisation of the weight values of each node:")
smp.plot(
    model,
    show_prototypes=True,
    show_activity=False,
)
Visualisation of the weight values of each node:
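
The `KsomParams` fields suggest a neighborhood width and a learning rate that move from an initial value (`sigma_i`, `alpha_i`) to a final value (`sigma_f`, `alpha_f`) over `t_f` steps. As an illustration only, the sketch below uses the geometric interpolation that is common in the SOM literature. Whether somap uses exactly this schedule is an assumption, and `exp_decay` is a hypothetical helper, not part of the library.

```python
import numpy as np


def exp_decay(t, t_f, v_i, v_f):
    """Interpolate geometrically from v_i (at t=0) to v_f (at t=t_f)."""
    frac = np.clip(t / t_f, 0.0, 1.0)
    return v_i * (v_f / v_i) ** frac


# Values taken from the KsomParams call above.
t_f = 60000
steps = np.array([0, 15000, 30000, 60000])
sigma = exp_decay(steps, t_f, 0.7, 0.01)
alpha = exp_decay(steps, t_f, 0.1, 0.001)

for t, s, a in zip(steps, sigma, alpha):
    print(f"t={int(t):>6d}  sigma={s:.4f}  alpha={a:.4f}")
```

Under this assumed schedule, both values shrink quickly at first, so most of the large-scale ordering of the map would happen during the early samples and the later samples would only fine-tune individual prototypes.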

Train the model:

epoch = 1
auxs = []
for _ in range(epoch):
    model, aux = smp.make_steps(model, {"bu_v": data})
    auxs.append(aux)

# Concatenate the 'aux' outputs if there are several
aux = jax.tree_util.tree_map(
    lambda x, *y: np.concatenate((x, *y), axis=0),
    *auxs,
    is_leaf=lambda x: isinstance(x, list),
)

print(f"Nb of data samples viewed by the model: {epoch * data.shape[0]}")

smp.plot(model, show_prototypes=True, show_activity=False, img_inverted_colors=True)
Nb of data samples viewed by the model: 60000
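
For intuition about what a single training step does, here is a minimal NumPy sketch of the classic Kohonen update on a square grid: find the best matching unit (BMU), then pull every prototype toward the input with a strength that decays with grid distance to the BMU. It is an illustration under assumptions (Gaussian neighborhood, grid coordinates normalized to [0, 1], hypothetical function name `kohonen_step`, random stand-in data) and not somap's implementation.

```python
import numpy as np


def kohonen_step(weights, x, sigma, alpha):
    """One Kohonen update.

    weights: (H, W, D) prototypes, x: (D,) input sample.
    Returns the updated weights and the BMU grid index.
    """
    h, w, _ = weights.shape
    # Best matching unit: node whose prototype is closest to the input.
    dists = np.linalg.norm(weights - x, axis=-1)
    bmu = tuple(int(i) for i in np.unravel_index(np.argmin(dists), (h, w)))

    # Grid coordinates normalized to [0, 1] (an assumption of this sketch).
    rows, cols = np.meshgrid(
        np.linspace(0, 1, h), np.linspace(0, 1, w), indexing="ij"
    )
    grid_d2 = (rows - rows[bmu]) ** 2 + (cols - cols[bmu]) ** 2

    # Gaussian neighborhood centered on the BMU.
    neighborhood = np.exp(-grid_d2 / (2 * sigma**2))
    new_weights = weights + alpha * neighborhood[..., None] * (x - weights)
    return new_weights, bmu


rng = np.random.default_rng(0)
weights = rng.random((12, 12, 28 * 28))  # 12x12 map, flattened 28x28 inputs
x = rng.random(28 * 28)

new_weights, bmu = kohonen_step(weights, x, sigma=0.7, alpha=0.1)
before = np.linalg.norm(weights[bmu] - x)
after = np.linalg.norm(new_weights[bmu] - x)
print(f"BMU {bmu}: distance to input {before:.3f} -> {after:.3f}")
```

At the BMU the neighborhood value is 1, so its prototype moves a fraction `alpha` of the way toward the input; nodes farther away on the grid move less.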

Evaluation

Show the quantization error:

def moving_average(x, w):
    return np.convolve(x, np.ones(w), "valid") / w


errors = aux["metrics"]["quantization_error"]
plt.plot(moving_average(errors, 100))
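
Quantization error is commonly defined as the distance between an input and the prototype of its best matching unit; lower values mean the prototypes represent the data more closely. The sketch below computes it with NumPy for a batch of inputs, using random stand-in data. The function name `quantization_error_sketch` and the choice of Euclidean distance are assumptions; somap may define or normalize its metric differently.

```python
import numpy as np


def quantization_error_sketch(prototypes, inputs):
    """Distance from each input to its closest prototype.

    prototypes: (N, D) flattened prototypes, inputs: (B, D) samples.
    Returns an array of shape (B,).
    """
    # Pairwise distances between every input and every prototype: (B, N).
    diffs = inputs[:, None, :] - prototypes[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    return dists.min(axis=1)


rng = np.random.default_rng(0)
prototypes = rng.random((144, 784))  # 12x12 map, flattened 28x28 prototypes
inputs = rng.random((32, 784))

q_err = quantization_error_sketch(prototypes, inputs)
print(f"mean quantization error: {q_err.mean():.3f}")
```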

Show the topographic error:

Note: An error of 1/sqrt(w,h) (where w is the width and h is the height of the map) means that the best and second-best matching nodes are neighbors on the map.

errors = aux["metrics"]["topographic_error"]
plt.plot(moving_average(errors, 100))
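
One common way to reason about topographic error is to look at where the best and second-best matching units sit on the grid: if they are adjacent, the map preserves the local topology for that input. The sketch below returns their grid distance on a square map, in node units, and compares an ordered toy map with a shuffled one. It is an illustrative definition with a hypothetical function name `bmu_grid_distance`; the normalization used by somap's `topographic_error` metric is not reproduced here.

```python
import numpy as np


def bmu_grid_distance(weights, x):
    """Grid distance between the two best matching units for input x.

    weights: (H, W, D) prototypes, x: (D,) input sample.
    A value of 1.0 means the two units are direct (4-connected) neighbors.
    """
    h, w, _ = weights.shape
    dists = np.linalg.norm(weights - x, axis=-1).ravel()
    first, second = np.argsort(dists)[:2]
    r1, c1 = np.unravel_index(first, (h, w))
    r2, c2 = np.unravel_index(second, (h, w))
    return float(np.hypot(r1 - r2, c1 - c2))


# Toy map whose prototype at node (i, j) is simply the vector [i, j],
# so the grid is perfectly ordered with respect to the input space.
rows, cols = np.meshgrid(np.arange(12), np.arange(12), indexing="ij")
organized = np.stack([rows, cols], axis=-1).astype(float)
x = np.array([3.1, 4.4])
d_organized = bmu_grid_distance(organized, x)

# Same prototypes placed at random grid positions.
rng = np.random.default_rng(0)
shuffled = rng.permutation(organized.reshape(-1, 2)).reshape(12, 12, 2)
d_shuffled = bmu_grid_distance(shuffled, x)

print(f"ordered map: {d_organized:.2f}, shuffled map: {d_shuffled:.2f}")
```

On the ordered map the two closest prototypes are grid neighbors, so the distance is 1.0; on the shuffled map they can land anywhere on the grid.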