"""Spectral discretization of Hamiltonian for interacting fermions in 1D"""

__author__ = "Christian Clason <christian.clason@uni-graz.at>",\
             "Greg von Winckel < gregory.von-winckel@uni-graz.at>"
__date__ = "February 23, 2012"
__version__ = "1.0"

import numpy as np
import scipy as sp
from scipy.sparse import csr_matrix as csr
from scipy.linalg import eig_banded
from itertools import izip, combinations as combs

def assembleMatrix(p, N, Ufun, Vfun, bc):
    """
    Assemble stiffness and potential matrices for many-particle Hamiltonian

    The N-particle basis consists of all strictly increasing index tuples
    alpha = (alpha_1 < ... < alpha_N) drawn from the p one-dimensional basis
    functions, i.e. Np = C(p, N) elements, numbered in the order produced by
    itertools.combinations.

    Parameters
    ----------
    p : int
        number of spectral basis functions per dimension
    N : int
        number of particles
    Ufun : function
        description of confinement potential; called once with the
        (Np, N) array of node coordinates of all basis tuples and summed
        over axis 1, so it should act elementwise
        (e.g., Ufun = lambda x: x)
    Vfun : function
        description of pairwise interaction potential; called once per
        particle pair with an (Np, 2) array of node pairs [x,y] and expected
        to return a length-Np array
        (e.g., Vfun = lambda x: numpy.cos(x[:,1]-x[:,0]))
    bc : string ['periodic' | 'dirichlet']
        choose periodic or homogeneous Dirichlet boundary conditions

    Returns
    -------
    K : compressed sparse row matrix
        stiffness matrix
    U : compressed sparse row matrix
        confinement potential matrix (diagonal)
    V : compressed sparse row matrix
        interaction potential matrix (diagonal; zero when N < 2)

    Raises
    ------
    NotImplementedError
        if bc is not supported (raised by computeLGL)

    Examples
    --------
    >>> K, U, V = fermi.assembleMatrix(15, 4, lambda x: x, lambda x: 0*x[:,0],
    ...                                'dirichlet')

    >>> K, U, V = fermi.assembleMatrix(8, 3, lambda x: 0*x,
    ...                     lambda x: numpy.cos(x[:,0]-x[:,1]), 'periodic')
    """
    # Precompute
    xi, K1d = computeLGL(p, bc)       # 1D nodes and 1D stiffness matrix
    diagK1d = np.diag(K1d)            # diagonal elements of 1D stiffness matrix
    pairs = list(combs(range(N), 2))  # all pairs of interacting particles
    T = list(combs(range(p), N))      # nD basis elements as sorted tuples
    Np = len(T)                       # number of nD basis elements
    T_d = dict(zip(T, range(Np)))     # dictionary for fast lookup of tuples
    # Integer index matrix with an explicit (Np, N) shape, so indices stay
    # integral and the array is 2-D even when T is empty.
    T_a = np.array(T, dtype=int).reshape(Np, N)
    # Append p as sentinel G_N; G[j, n] is the number of free indices
    # strictly between alpha_n and alpha_{n+1} (or p for the last entry).
    T_g = np.hstack((T_a, p*np.ones((Np, 1), dtype=int)))
    G = np.diff(T_g, axis=1) - 1      # G_alpha (length of gaps in alpha)
    vNp = np.arange(Np)               # vector (0, ..., Np-1)

    # Compute potential terms (diagonal: evaluated at the node tuples).
    # Accumulating from zeros makes N < 2 (no pairs) give V = 0 instead of
    # failing on an empty stack of pair contributions.
    Vdiag = np.zeros(Np)
    for pair in pairs:
        Vdiag = Vdiag + Vfun(xi[T_a[:, list(pair)]])
    V = csr((Vdiag, (vNp, vNp)), shape=(Np, Np))
    U = csr((Ufun(xi[T_a]).sum(1), (vNp, vNp)), shape=(Np, Np))

    # Compute stiffness matrix: only the strictly upper off-diagonal part Ko
    # is assembled; the lower part is added as Ko.T below.
    Nk = Np*N*(p - N)//2              # number of nonzero upper entries
    row = np.empty(Nk, dtype=int)     # row indices of Ko
    col = np.empty(Nk, dtype=int)     # column indices of Ko
    val = np.empty(Nk)                # entries of Ko
    ind = 0                           # running index for above vectors

    for jj, alpha in enumerate(T):                # loop over all tuples
        alpha = list(alpha)                       # tuple for column
        for nn in np.flatnonzero(G[jj] > 0):      # positions with a gap
            nn = int(nn)
            start = alpha[nn] + 1
            for b in range(start, start + int(G[jj, nn])):  # values in gap
                # beta' = alpha with b inserted right after alpha_n (sorted)
                betap = alpha[:nn+1] + [b] + alpha[nn+1:]
                sgn = -1 if nn % 2 else 1         # sign of K_jk: (-1)^(m+n)
                for mm in range(nn + 1):          # remove alpha_m, m <= n
                    beta = betap[:mm] + betap[mm+1:]          # tuple for row
                    row[ind] = jj                             # j = tau(alpha)
                    col[ind] = T_d[tuple(beta)]               # k = tau(beta)
                    val[ind] = sgn*K1d[alpha[mm], b]          # K_jk
                    ind += 1
                    sgn = -sgn                    # flip sign as m increases

    Ko = csr((val, (row, col)), shape=(Np, Np))
    Kd = csr((diagK1d[T_a].sum(1), (vNp, vNp)), shape=(Np, Np))
    K = Kd + Ko + Ko.T

    return K, U, V
    
def computeLGL(p, bc):
    """
    Compute 1D nodes and nodal stiffness matrix

    Parameters
    ----------
    p : int
        number of 1D basis functions (= number of returned nodes)
    bc : string ['periodic' | 'dirichlet']
        boundary condition

    Returns
    -------
    xi : ndarray, shape (p,)
        nodes; 'periodic': equispaced points 2*pi*j/p, j = 0, ..., p-1;
        'dirichlet': the p interior Legendre-Gauss-Lobatto nodes in (-1, 1)
        (the endpoints +-1 carry the homogeneous boundary values)
    K : ndarray, shape (p, p)
        nodal stiffness matrix; 'periodic': F^{-1} diag(k^2) F with F the
        DFT; 'dirichlet': basis functions are scaled by 1/sqrt(w) so that
        the Lobatto-quadrature mass matrix is the identity

    Raises
    ------
    NotImplementedError
        for any other value of bc
    """
    if bc == 'periodic':
        xi = 2./p*np.pi*np.arange(p)
        # squared wavenumbers in FFT ordering:
        # 0, 1, ..., ceil(p/2)-1, floor(p/2), ..., 1
        Kf = np.diag(np.hstack((np.arange(np.ceil(p/2.)),
                               -np.arange(-np.floor(p/2.), 0)))**2)
        # Nodal stiffness matrix.  numpy.fft is used explicitly: the old
        # top-level scipy.fft/scipy.ifft function aliases of these NumPy
        # routines are deprecated and gone in current SciPy (scipy.fft is
        # now a module).
        K = np.real(np.fft.fft(np.fft.ifft(Kf).T))
    elif bc == 'dirichlet':
        # Golub-Welsch: nodes and weights from the Jacobi matrix of the
        # Legendre recurrence (zero diagonal), in lower banded storage.
        k = np.arange(1, p+2)                  # index
        a_band = np.zeros((2, p+2))
        a_band[1, 0:p+1] = k*k/(4*k*k-1.)
        a_band[1, p] = (p+1.)/(2*p+1)          # modify coeffs for Lobatto nodes
        a_band = np.sqrt(a_band)
        x, V = eig_banded(a_band, lower=True)  # p+2 Lobatto nodes, incl. +-1
        xi = x[1:-1]                           # keep interior nodes only
        w = 2*(V[0, :]**2)                     # Lobatto weights
        # Barycentric differentiation matrix (D = -(standard D)^T).
        e = np.ones(p+2)
        Xdiff = np.outer(x, e) - np.outer(e, x) + np.identity(p+2)
        W = np.outer(1/Xdiff.prod(1), e)
        D = W/np.multiply(W.T, Xdiff)
        # Zero the diagonal: only interior columns are kept below, and for
        # interior LGL nodes the differentiation-matrix diagonal vanishes.
        D.flat[::D.shape[1]+1] = 0
        # Derivatives of the scaled interior basis functions at all nodes.
        Di = np.dot(-D.T, np.diag(1/np.sqrt(w)))[:, 1:p+1]
        K = np.dot(np.dot(Di.T, np.diag(w)), Di)
    else:
        raise NotImplementedError('Boundary condition "%s" is not implemented.'
                                  % (bc,))

    return xi, K

if __name__ == '__main__':
    # Test routine if called as script
    import argparse
    # portable timer; time.clock was deprecated in 3.3 and removed in 3.8
    from timeit import default_timer as timer
    from scipy.sparse.linalg import eigsh

    parser = argparse.ArgumentParser(description='Test spectral discretization')
    parser.add_argument('-N', '--particles', dest='N', type=int, required=True,
                        help='number of particles')
    parser.add_argument('-p', '--degree', dest='p', type=int, required=True,
                        help='number of basis functions')
    args = parser.parse_args()
    N = args.N
    # eigsh(..., 2) needs more than two basis elements: C(p, N) > 2 for
    # test 1 and C(p, 3) > 2 for test 2; p > max(N, 3) guarantees both.
    if N < 1 or args.p <= max(N, 3):
        parser.error('need N >= 1 and p > max(N, 3)')
    Ufun = lambda x: 0*x
    Vfun = lambda x: np.cos(x[:,0]-x[:,1])

    print("\nTest 1: N=%d, U(x)=0, V(x,y)=0, Dirichlet boundary conditions" % N)
    tic = timer()
    K, U, V = assembleMatrix(args.p, N, Ufun, Vfun, 'dirichlet')
    print("Time for set up:\t%s seconds" % (timer() - tic))
    print("\nComputing eigenvalues of stiffness matrix")
    evals = eigsh(K, 2, which='SM', return_eigenvectors=False)
    evals.sort()
    print("Computed eigenvalues:\t%s" % (evals,))
    # Two lowest N-fermion levels of -u'' on [-1, 1] with u(+-1) = 0:
    # sum of (k*pi/2)^2 over k = 1..N, and over k = 1..N-1, N+1.
    exact_evals = np.array([N*(N+1)*(2*N+1), (2*N+1)*(N**2+N+6)])*np.pi**2/24
    print("Exact eigenvalues:\t%s" % (exact_evals,))
    print("Error:\t\t\t%s" % np.linalg.norm(evals-exact_evals))

    print("\nTest 2: N=3, U(x)=0, V(x,y)=cos(x-y), periodic boundary conditions")
    tic = timer()
    K, U, V = assembleMatrix(args.p, 3, Ufun, Vfun, 'periodic')
    print("Time for set up:\t%s seconds" % (timer() - tic))
    print("\nComputing eigenvalues of Hamiltonian matrix")
    evals = eigsh(K+U+V, 2, which='SM', return_eigenvectors=False)
    evals.sort()
    print("Computed eigenvalues:\t%s" % (evals,))
    print("Reference eigenvalues:\t%s" % (np.array([0.96420064, 3.96420064]),))

