Self-Organizing Maps on the MNIST dataset
This notebook shows how to train either a standard Kohonen SOM (KSOM) or a Dynamic SOM (DSOM) on the MNIST dataset. The MNIST training set consists of 60,000 grayscale images of handwritten digits (0 to 9), each 28x28 pixels. The goal is to project these images onto a small 2D map of nodes, so that nearby nodes respond to similar images.
Data
Load the MNIST dataset and display the first few images:
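The sketch below shows a minimal preprocessing step. It assumes two things: the raw images arrive as a `uint8` array of shape `(N, 28, 28)`, and the model works best with floating-point pixels in [0, 1]. The helper `normalize_images` is hypothetical, and the random array only stands in for the real MNIST images.

```python
import numpy as np

def normalize_images(images: np.ndarray) -> np.ndarray:
    """Scale uint8 pixel values (0..255) to float32 values in [0, 1]."""
    images = np.asarray(images)
    if images.ndim != 3 or images.shape[1:] != (28, 28):
        raise ValueError(f"expected shape (N, 28, 28), got {images.shape}")
    return images.astype(np.float32) / 255.0

# Hypothetical stand-in for the MNIST training images.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
images = normalize_images(raw)
print(images.shape, images.dtype)  # (100, 28, 28) float32
```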
Model
Initialize the SOM (Kohonen or DSOM):
import jax
import numpy as np
import somap as smp

# SOM generic parameters:
shape = (12, 12)  # map of 12 x 12 nodes
topography = "square"
borderless = False
input_shape = (28, 28)  # one MNIST image per input
som_type = "ksom"  # "ksom" or "dsom"
if som_type == "ksom":  # Kohonen SOM
    params = smp.KsomParams(
        t_f=60000, sigma_i=0.7, sigma_f=0.01, alpha_i=0.1, alpha_f=0.001
    )
    model = smp.Ksom(shape, topography, borderless, input_shape, params)
elif som_type == "dsom":  # Dynamic SOM
    params = smp.DsomParams(alpha=0.001, plasticity=0.02)
    model = smp.Dsom(shape, topography, borderless, input_shape, params)
else:
    raise ValueError(f"Unknown som_type: {som_type!r}")
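The parameter names suggest the usual learning schedules. For the KSOM, the neighborhood width sigma and the learning rate alpha decay from their initial (`_i`) to their final (`_f`) values over `t_f` steps. For the DSOM, a constant learning rate and an elasticity parameter replace that time schedule. The NumPy sketch below implements the textbook single-input update rules under some assumptions: exponential decay for the KSOM, and the DSOM rule of Rougier and Boniface with `alpha` as the learning rate and `plasticity` as the elasticity. Node positions are assumed to be scaled by 1/h and 1/w. This is an illustration of the algorithms, not somap's implementation. The names `decay`, `grid_positions`, `ksom_step` and `dsom_step` are hypothetical.

```python
import numpy as np

def decay(v_i, v_f, t, t_f):
    """Exponential interpolation from v_i at t=0 to v_f at t=t_f."""
    return v_i * (v_f / v_i) ** (t / t_f)

def grid_positions(shape):
    """(h*w, 2) node coordinates on a square grid, scaled by 1/h and 1/w (assumption)."""
    h, w = shape
    ys, xs = np.meshgrid(np.arange(h) / h, np.arange(w) / w, indexing="ij")
    return np.stack([ys.ravel(), xs.ravel()], axis=1)

def ksom_step(weights, positions, x, t, t_f=60000, sigma_i=0.7, sigma_f=0.01,
              alpha_i=0.1, alpha_f=0.001):
    """One Kohonen update for a flattened input x; weights has shape (n_nodes, dim)."""
    sigma = decay(sigma_i, sigma_f, t, t_f)
    alpha = decay(alpha_i, alpha_f, t, t_f)
    bmu = int(np.argmin(np.linalg.norm(weights - x, axis=1)))  # best-matching node
    d2 = np.sum((positions - positions[bmu]) ** 2, axis=1)  # squared grid distances
    h = np.exp(-d2 / (2 * sigma**2))  # Gaussian neighborhood around the winner
    return weights + alpha * h[:, None] * (x - weights)

def dsom_step(weights, positions, x, alpha=0.001, plasticity=0.02):
    """One DSOM update: no time decay, neighborhood scaled by the winner's error."""
    dists = np.linalg.norm(weights - x, axis=1)
    bmu = int(np.argmin(dists))
    if dists[bmu] == 0.0:
        return weights  # the winner already matches x: no update
    d2 = np.sum((positions - positions[bmu]) ** 2, axis=1)
    h = np.exp(-d2 / (plasticity**2 * dists[bmu] ** 2))
    return weights + alpha * dists[:, None] * h[:, None] * (x - weights)

# Usage on random stand-in data (hypothetical): 12 x 12 map, flattened 28 x 28 inputs.
rng = np.random.default_rng(0)
positions = grid_positions((12, 12))
weights = rng.random((144, 28 * 28))
x = rng.random(28 * 28)
w_ksom = ksom_step(weights, positions, x, t=0)
w_dsom = dsom_step(weights, positions, x)
```

At `t=0` the winner moves 10% of the way toward `x` (`alpha_i=0.1`, neighborhood value 1). The DSOM step depends only on the current input and the map state, which is why it has no `t_f`.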
print("Visualisation of the weight values of each node:")
smp.plot(
    model,
    show_prototypes=True,
    show_activity=False,
)
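`smp.plot` draws the prototypes for us. The sketch below shows what such a picture contains, assuming the weights can be reshaped to `(12, 12, 28, 28)`. It tiles each node's 28x28 prototype at the node's grid position in a 336x336 mosaic. `prototypes_mosaic` is a hypothetical helper. Its `invert` flag assumes pixel values in [0, 1] and only approximates what `img_inverted_colors` appears to do.

```python
import numpy as np

def prototypes_mosaic(weights, shape=(12, 12), img_shape=(28, 28), invert=False):
    """Tile the node prototypes into one (h*ih, w*iw) image, row by row."""
    h, w = shape
    ih, iw = img_shape
    tiles = np.asarray(weights, dtype=float).reshape(h, w, ih, iw)
    # (h, w, ih, iw) -> (h, ih, w, iw) so that rows of pixels line up across nodes.
    mosaic = tiles.transpose(0, 2, 1, 3).reshape(h * ih, w * iw)
    return 1.0 - mosaic if invert else mosaic

# Usage with stand-in weights in [0, 1] (hypothetical data).
rng = np.random.default_rng(0)
weights = rng.random((144, 28 * 28))
mosaic = prototypes_mosaic(weights, invert=True)
print(mosaic.shape)  # (336, 336)
```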
Train the model:
epoch = 1  # number of passes over the whole training set
auxs = []
for _ in range(epoch):
    model, aux = smp.make_steps(model, {"bu_v": data})
    auxs.append(aux)
# Concatenate the per-epoch 'aux' outputs along the first axis (if there are several)
aux = jax.tree_util.tree_map(
    lambda x, *y: np.concatenate((x, *y), axis=0),
    *auxs,
    is_leaf=lambda x: isinstance(x, list),
)
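When each `aux` is a flat dictionary of arrays, the `tree_map` call above reduces to a per-key concatenation. The plain-NumPy sketch below shows that equivalent. The helper `concat_aux` and the key `"bmu"` are hypothetical and not necessarily what somap returns. Unlike this sketch, `tree_map` also handles nested structures.

```python
import numpy as np

def concat_aux(auxs):
    """Concatenate a list of flat dicts of arrays along axis 0, key by key."""
    return {key: np.concatenate([aux[key] for aux in auxs], axis=0) for key in auxs[0]}

# Two hypothetical epochs of per-step outputs.
auxs = [{"bmu": np.array([3, 7])}, {"bmu": np.array([5])}]
merged = concat_aux(auxs)
print(merged["bmu"])  # [3 7 5]
```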
print(f"Nb of data samples viewed by the model: {epoch * data.shape[0]}")
smp.plot(model, show_prototypes=True, show_activity=False, img_inverted_colors=True)
Evaluation
Show the quantization error:
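A common definition of the quantization error is the mean Euclidean distance between each input and the prototype of its best-matching node. The NumPy sketch below computes it using the expansion ||x - w||^2 = ||x||^2 - 2 x.w + ||w||^2. This avoids building an `(N, n_nodes, 784)` array. The function name is hypothetical, and the metric may differ in detail from the one the notebook reports.

```python
import numpy as np

def quantization_error(weights, data):
    """Mean distance between each input and the prototype of its best-matching node."""
    X = np.asarray(data, dtype=float).reshape(len(data), -1)  # (N, dim)
    W = np.asarray(weights, dtype=float).reshape(-1, X.shape[1])  # (n_nodes, dim)
    # Squared distances via ||x||^2 - 2 x.w + ||w||^2, clipped at 0 against round-off.
    d2 = (X**2).sum(axis=1)[:, None] - 2.0 * X @ W.T + (W**2).sum(axis=1)[None, :]
    return float(np.sqrt(np.clip(d2.min(axis=1), 0.0, None)).mean())

# Toy example: one prototype at the origin, inputs at distances 0 and 5.
print(quantization_error(np.array([[0.0, 0.0]]), np.array([[0.0, 0.0], [3.0, 4.0]])))  # 2.5
```

Because `weights` is reshaped to `(-1, dim)`, a `(12, 12, 28, 28)` weight array and a `(N, 28, 28)` image array can be passed directly.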
Show the topographic error:
Note: a topographic error of 1/sqrt(w*h), where w is the width and h is the height of the map, means that the best and second-best matching nodes are direct neighbors.
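Two common ways to measure topographic error are sketched below, assuming the nodes are stored row by row on an `(h, w)` grid. The first is the classical fraction of inputs whose best and second-best matching nodes are not 4-neighbors. The second is the mean grid distance between those two nodes divided by sqrt(w*h), so that adjacent winners contribute exactly 1/sqrt(w*h), consistent with the note above. The function names and the choice of a 4-neighborhood are assumptions.

```python
import numpy as np

def two_best_nodes(weights, data):
    """Indices of the best and second-best matching nodes for each input."""
    X = np.asarray(data, dtype=float).reshape(len(data), -1)
    W = np.asarray(weights, dtype=float).reshape(-1, X.shape[1])
    d2 = (X**2).sum(axis=1)[:, None] - 2.0 * X @ W.T + (W**2).sum(axis=1)[None, :]
    order = np.argsort(d2, axis=1)
    return order[:, 0], order[:, 1]

def topographic_errors(weights, data, shape):
    """Return (fraction of non-adjacent winner pairs, mean normalized winner distance)."""
    h, w = shape
    first, second = two_best_nodes(weights, data)
    r1, c1 = np.divmod(first, w)  # row, column of the best node
    r2, c2 = np.divmod(second, w)  # row, column of the second-best node
    grid_dist = np.hypot(r1 - r2, c1 - c2)  # Euclidean distance in grid units
    te = np.mean(np.abs(r1 - r2) + np.abs(c1 - c2) > 1)  # not 4-neighbors
    normalized = np.mean(grid_dist / np.sqrt(w * h))
    return float(te), float(normalized)

# Toy 2 x 2 map with 1-D prototypes, stored row by row (hypothetical values).
weights = np.array([[0.0], [1.0], [10.0], [11.0]])
data = np.array([[0.4], [10.4]])
print(topographic_errors(weights, data, shape=(2, 2)))  # (0.0, 0.5)
```

Here both inputs have adjacent winners, so the classical error is 0 and the normalized distance equals 1/sqrt(2*2) = 0.5.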