34 lines
832 B
Python
34 lines
832 B
Python
|
import utils
|
||
|
import matplotlib.pyplot as plt
|
||
|
from sklearn.cluster import KMeans
|
||
|
|
||
|
def kmeans(n_clusters, slab_width, slab_height, *,
           preview_slab=15, random_state=None):
    """Cluster per-pixel spectra with k-means and plot the result.

    Loads slabs via ``utils.load_slabs`` and treats each pixel as one sample,
    i.e. one spectrum plotted against ``utils.energies``. It fits ``KMeans``
    and shows a three-panel figure:

    0. every pixel spectrum (faint blue) with the cluster centres (black);
    1. the image of slab ``preview_slab``, scaled to the ``i_min`` / ``i_max``
       range returned by the loader;
    2. the cluster label of every pixel, shown as an image.

    Args:
        n_clusters: Number of k-means clusters.
        slab_width: Forwarded to ``utils.load_slabs``; also used as the first
            dimension of the label image.
        slab_height: Forwarded to ``utils.load_slabs``; also used as the
            second dimension of the label image.
        preview_slab: Index of the slab drawn in the middle panel. Defaults
            to 15, the value that was previously hard-coded.
        random_state: Forwarded to ``KMeans`` so a run can be made
            reproducible. ``None`` keeps the previous unseeded behaviour.

    Raises:
        IndexError: If fewer than ``preview_slab + 1`` slabs are loaded.
        ValueError: If the number of pixels is not
            ``slab_width * slab_height`` (the label reshape fails).
    """
    slabs, i_min, i_max = utils.load_slabs(slab_width, slab_height)

    # as_np_array is expected to return (n_freqs, n_pixels), per the original
    # variable naming. Transpose so each row is one pixel's spectrum, which is
    # the (n_samples, n_features) layout KMeans expects.
    spectra = utils.as_np_array(slabs).T

    km_estimator = KMeans(n_clusters=n_clusters, random_state=random_state)
    km_estimator.fit(spectra)
    centers = km_estimator.cluster_centers_

    # NOTE(review): this treats slab_width as the row count of the label
    # image. If utils flattens each slab as (height, width) in C order, this
    # should be (slab_height, slab_width) -- confirm against
    # utils.as_np_array. The result is only visibly wrong when the two differ.
    labels = km_estimator.labels_.reshape((slab_width, slab_height))

    _, axes = plt.subplots(1, 3)

    # A 2-D y argument makes matplotlib draw one line per column. Passing the
    # (n_freqs, n_pixels) / (n_freqs, n_clusters) arrays plots every spectrum
    # and every centre in a single call, instead of a Python loop with one
    # plot() call per pixel.
    axes[0].plot(utils.energies, spectra.T, c='blue', alpha=0.1)
    axes[0].plot(utils.energies, centers.T, c='black', alpha=1)

    axes[1].imshow(slabs[preview_slab]["data"], vmin=i_min, vmax=i_max)
    axes[2].imshow(labels)

    plt.show()
|
||
|
|