import utils
import numpy as np
from sklearn.decomposition import FastICA
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

def ica(n_components, slab_width, slab_height, sample_index=15):
    """Run FastICA on a stack of spectral slabs and plot the result.

    The slabs returned by ``utils.load_slabs`` are stacked into a matrix
    (transposed so that, presumably, each row is one pixel and each column
    one slab -- TODO confirm against ``utils.as_np_array``) and decomposed
    with FastICA. A figure is then shown containing:

    * a scatter plot of the first two independent components,
    * the image of slab ``sample_index`` for reference,
    * one image per independent component, reshaped to
      ``(slab_height, slab_width)``.

    Args:
        n_components: Number of independent components to extract. Must be
            at least 2, since the first two are plotted against each other.
        slab_width: Width of each slab in pixels; forwarded to
            ``utils.load_slabs`` and used to reshape the components.
        slab_height: Height of each slab in pixels; forwarded to
            ``utils.load_slabs`` and used to reshape the components.
        sample_index: Index of the slab shown as the reference image.
            Defaults to 15, the slab that used to be hard-coded here.

    Raises:
        ValueError: If ``n_components`` is less than 2.
    """
    if n_components < 2:
        # The scatter plot needs two components; fail before the
        # (potentially expensive) load rather than with an IndexError later.
        raise ValueError(f"n_components must be at least 2, got {n_components}")

    # NOTE(review): slabs appear to be dicts with "data" (an image) and "e"
    # (labelled in cm-1) keys; i_min/i_max are used as the shared colour range.
    slabs, i_min, i_max = utils.load_slabs(slab_width, slab_height)

    X = utils.as_np_array(slabs).T

    p = FastICA(n_components=n_components, whiten="arbitrary-variance")
    Y = p.fit_transform(X)  # one column per independent component

    component_images = [
        Y[:, i].reshape((slab_height, slab_width)) for i in range(n_components)
    ]

    # Column 0 holds the scatter plot (top) and the reference slab (bottom);
    # the components fill the remaining columns, two rows high, row-major.
    # For 3 or 4 components this reproduces the original 2x3 layout, and only
    # as many component axes are created as there are components.
    comp_cols = (n_components + 1) // 2
    f = plt.figure(layout="constrained")
    gs = GridSpec(2, 1 + comp_cols, figure=f)

    scatter_ax = f.add_subplot(gs[0, 0])
    scatter_ax.scatter(Y[:, 0], Y[:, 1], s=20, alpha=0.02)

    sample = slabs[sample_index]
    data_ax = f.add_subplot(gs[1, 0])
    data_ax.imshow(sample["data"], vmin=i_min, vmax=i_max)
    # Single quotes inside the f-string: reusing the enclosing quote type is
    # only legal from Python 3.12 onward (PEP 701).
    data_ax.set_title(f"data @ {sample['e']}cm-1")

    for c, image in enumerate(component_images):
        row, col = divmod(c, comp_cols)
        ax = f.add_subplot(gs[row, 1 + col])
        ax.imshow(image)
        ax.set_title(f"Component #{c+1}")

    plt.show()
# Example usage: ica(3, 512, 512)