SRS/kmeans.py
2024-12-18 17:07:43 +01:00

33 lines
832 B
Python

import utils
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
def kmeans(n_clusters, slab_width, slab_height, preview_slab=15, random_state=None):
    """Cluster per-pixel spectra with k-means and plot the result.

    Loads slabs via ``utils.load_slabs``, treats every pixel's spectrum as one
    sample, fits ``sklearn.cluster.KMeans`` and shows a three-panel figure:
    all spectra with the cluster centres overlaid, one slab image, and the
    per-pixel cluster labels drawn as an image.

    Args:
        n_clusters: Number of k-means clusters.
        slab_width: Forwarded to ``utils.load_slabs`` and used to reshape the
            label vector into an image.
        slab_height: Used the same way as ``slab_width``.
        preview_slab: Index of the loaded slab drawn in the middle panel
            (previously hard-coded to 15).
        random_state: Seed forwarded to ``KMeans`` for reproducible
            clustering; ``None`` keeps scikit-learn's default behaviour.

    Returns:
        ``(centers, labels)``: the fitted cluster centres (one row per
        cluster) and the reshaped label image.
    """
    slabs, i_min, i_max = utils.load_slabs(slab_width, slab_height)

    # NOTE(review): the layout returned by as_np_array is not visible here;
    # the original code named its axes (n_freqs, n_pixels) — confirm.
    # Transposing makes each row one pixel's spectrum, i.e. one k-means sample.
    spectra = utils.as_np_array(slabs).T

    km_estimator = KMeans(n_clusters=n_clusters, random_state=random_state)
    km_estimator.fit(spectra)
    centers = km_estimator.cluster_centers_
    # NOTE(review): this reshape requires n_pixels == slab_width * slab_height.
    # Whether (width, height) or (height, width) matches the pixel order of
    # as_np_array, and the orientation of the slab image in the middle panel,
    # cannot be checked from this file — confirm.
    labels = km_estimator.labels_.reshape((slab_width, slab_height))

    _fig, axes = plt.subplots(1, 3)
    # A 2-D y argument draws one line per column. That gives the same lines as
    # the old per-pixel loop of plot() calls, but in a single call.
    axes[0].plot(utils.energies, spectra.T, c='blue', alpha=0.1)
    # Centres are drawn after the spectra so they sit on top.
    axes[0].plot(utils.energies, centers.T, c='black', alpha=1)
    axes[1].imshow(slabs[preview_slab]["data"], vmin=i_min, vmax=i_max)
    axes[2].imshow(labels)
    plt.show()
    return centers, labels