import utils
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans


def kmeans(n_clusters, slab_width, slab_height, slab_index=15):
    """Cluster slab spectra with k-means and show three diagnostic panels.

    The slabs are loaded through ``utils.load_slabs``, converted to an array
    with ``utils.as_np_array``, transposed, and fitted with scikit-learn's
    ``KMeans``.  A 1x3 figure is then displayed:

    * panel 0: every fitted sample (blue, faint) with the cluster centers
      drawn on top (black), both plotted against ``utils.energies``;
    * panel 1: the ``"data"`` image of one reference slab, colour-scaled
      with the ``i_min`` / ``i_max`` values returned by ``utils.load_slabs``;
    * panel 2: the cluster label of each sample, reshaped to
      ``(slab_width, slab_height)``.

    Args:
        n_clusters: Number of clusters passed to ``KMeans``.
        slab_width: Forwarded to ``utils.load_slabs``; also the first
            dimension of the label image.
        slab_height: Forwarded to ``utils.load_slabs``; also the second
            dimension of the label image.
        slab_index: Index into the loaded slabs of the slab shown in the
            middle panel.  Defaults to 15, the value that used to be
            hard-coded.

    Returns:
        None. The function blocks in ``plt.show()`` until the figure is
        dismissed (for interactive backends).
    """
    slabs, i_min, i_max = utils.load_slabs(slab_width, slab_height)

    # KMeans.fit treats each row as one sample, so transpose the array.
    # Presumably as_np_array returns (n_freqs, n_pixels), making each row of
    # `spectra` one pixel's spectrum -- TODO confirm against utils.
    spectra = utils.as_np_array(slabs).T

    estimator = KMeans(n_clusters=n_clusters)
    estimator.fit(spectra)
    centers = estimator.cluster_centers_

    # NOTE(review): images are usually laid out (rows, cols) = (height,
    # width); if the pixels were flattened that way, this reshape is only
    # correct for square slabs.  Verify against utils.as_np_array.
    labels = estimator.labels_.reshape((slab_width, slab_height))

    # Create the figure only after the fit succeeded so that a failed fit
    # does not leave an empty figure behind.
    _, axes = plt.subplots(1, 3)

    for spectrum in spectra:
        axes[0].plot(utils.energies, spectrum, c='blue', alpha=0.1)
    # Centers are drawn last so they stay visible above the faint spectra.
    for center in centers:
        axes[0].plot(utils.energies, center, c='black', alpha=1)

    axes[1].imshow(slabs[slab_index]["data"], vmin=i_min, vmax=i_max)
    axes[2].imshow(labels)
    plt.show()