Self-Organizing Maps on the MNIST dataset
This notebook demonstrates a standard Kohonen SOM (KSOM) and a Dynamic SOM (DSOM) on the MNIST dataset. The MNIST training set consists of 60,000 28x28 grayscale images of handwritten digits from 0 to 9. The goal is to project those images onto a small 2D map in which nearby nodes correspond to similar images.
Data
Load the MNIST dataset and show the first few images:
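A minimal sketch of one way to obtain the `data` array used further down, assuming the uncompressed IDX file from the original MNIST distribution (`train-images-idx3-ubyte`) is available locally. The helper name `parse_idx_images` and the scaling of pixels to [0, 1] are assumptions made for illustration, not something this notebook prescribes.

```python
import struct

import numpy as np


def parse_idx_images(raw: bytes) -> np.ndarray:
    """Decode an IDX3 image file into a float32 array of shape (n, rows, cols)."""
    # Header: magic number 2051, image count, rows, cols (big-endian uint32).
    magic, n, rows, cols = struct.unpack(">IIII", raw[:16])
    if magic != 2051:
        raise ValueError(f"unexpected magic number {magic}")
    pixels = np.frombuffer(raw, dtype=np.uint8, offset=16, count=n * rows * cols)
    # Assumption: scale the 0..255 pixel values to [0, 1].
    return pixels.reshape(n, rows, cols).astype(np.float32) / 255.0


# Hypothetical usage, assuming the file sits in the working directory:
# with open("train-images-idx3-ubyte", "rb") as f:
#     data = parse_idx_images(f.read())
```

The first few images can then be displayed with any image plotting routine, for example matplotlib's `imshow`.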
Model
Initialize the SOM (Kohonen or DSOM):
# SOM generic parameters:
shape = (12, 12)
topography = "square"
borderless = False
input_shape = (28, 28)
som_type = "ksom"
if som_type == "ksom":  # Kohonen
    params = smp.KsomParams(
        t_f=60000, sigma_i=0.7, sigma_f=0.01, alpha_i=0.1, alpha_f=0.001
    )
    model = smp.Ksom(shape, topography, borderless, input_shape, params)
elif som_type == "dsom":  # Dynamic SOM
    params = smp.DsomParams(alpha=0.001, plasticity=0.02)
    model = smp.Dsom(shape, topography, borderless, input_shape, params)
print("Visualisation of the weight values of each node:")
smp.plot(
    model,
    show_prototypes=True,
    show_activity=False,
)
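The Kohonen parameters above (`t_f`, `sigma_i`, `sigma_f`, `alpha_i`, `alpha_f`) suggest a neighborhood width and a learning rate that shrink from an initial to a final value over `t_f` steps. The sketch below shows the exponential decay commonly used in the classical Kohonen formulation. Whether the library applies exactly this schedule is an assumption, and `exp_decay` is a hypothetical helper.

```python
import numpy as np


def exp_decay(t, t_f, v_i, v_f):
    """Interpolate exponentially from v_i at t=0 to v_f at t=t_f."""
    frac = np.clip(t / t_f, 0.0, 1.0)
    return v_i * (v_f / v_i) ** frac


t_f = 60000
for t in (0, 30000, 60000):
    sigma = exp_decay(t, t_f, 0.7, 0.01)   # neighborhood width
    alpha = exp_decay(t, t_f, 0.1, 0.001)  # learning rate
    print(f"t={t:>5}  sigma={sigma:.4f}  alpha={alpha:.4f}")
```

With `t_f=60000`, the schedule would reach its final values after exactly one pass over the 60,000 training images.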
Train the model:
epoch = 1
auxs = []
for i in range(0, epoch):
    model, aux = smp.make_steps(model, {"bu_v": data})
    auxs.append(aux)

# Concatenate the 'aux' outputs if there are several
aux = jax.tree_util.tree_map(
    lambda x, *y: np.concatenate((x, *y), axis=0),
    *auxs,
    is_leaf=lambda x: isinstance(x, list),
)
print(f"Nb of data samples viewed by the model: {epoch * data.shape[0]}")
smp.plot(model, show_prototypes=True, show_activity=False, img_inverted_colors=True)
Evaluation
Show the quantization error:
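For reference, the quantization error is usually defined as the mean distance between each input and the prototype of its best-matching node. The following is a plain NumPy sketch of that definition, not the library's own routine; the function name and the assumption that prototypes come as an array of shape `(n_nodes, 28, 28)` are hypothetical.

```python
import numpy as np


def quantization_error(data, prototypes):
    """Mean Euclidean distance between each sample and its closest prototype."""
    x = data.reshape(len(data), -1)
    w = prototypes.reshape(len(prototypes), -1)
    # Squared distances via ||x - w||^2 = ||x||^2 + ||w||^2 - 2 x.w,
    # which avoids building an (n_samples, n_nodes, n_features) array.
    d2 = (x**2).sum(1)[:, None] + (w**2).sum(1)[None, :] - 2.0 * x @ w.T
    d = np.sqrt(np.maximum(d2, 0.0))  # clip tiny negatives from rounding
    return d.min(axis=1).mean()
```

A lower value means the prototypes represent the inputs more closely; it says nothing about whether the map preserves the topology of the data.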
Show the topographic error:
Note: an error of 1/sqrt(w,h) (where w is the width and h is the height of the map) means that the two best-matching nodes are neighbors.
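For comparison, the classical topographic error is a binary measure: the fraction of samples whose best and second-best matching nodes are not adjacent on the map. The sketch below implements that classical variant in NumPy, which differs from the distance-valued error described in the note, so the two numbers are not directly comparable. It assumes a square grid with 4-neighbor adjacency and prototypes stored in row-major node order; the function name is hypothetical.

```python
import numpy as np


def topographic_error(data, prototypes, grid_shape):
    """Fraction of samples whose two best-matching nodes are not grid neighbors."""
    x = data.reshape(len(data), -1)
    w = prototypes.reshape(len(prototypes), -1)
    d2 = (x**2).sum(1)[:, None] + (w**2).sum(1)[None, :] - 2.0 * x @ w.T
    # Flat indices of the best and second-best nodes for every sample.
    best2 = np.argsort(d2, axis=1)[:, :2]
    rows, cols = np.unravel_index(best2, grid_shape)
    # Manhattan distance on the grid; 1 means the nodes are 4-neighbors.
    grid_dist = np.abs(rows[:, 0] - rows[:, 1]) + np.abs(cols[:, 0] - cols[:, 1])
    return float((grid_dist > 1).mean())
```

For the 12x12 map used here, `grid_shape` would be `(12, 12)`.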