r"""
This module solves the parabolic measure space control problem
     min 1/2 |y-yd|_L2^2 + alpha |u|_L2(M)  s.t.   y_t-\Delta y = u
using a semismooth Newton method, as described in the paper
    "Parabolic control problems in measure spaces with sparse solutions"
by Eduardo Casas, Christian Clason, and Karl Kunisch, submitted to
SIAM Journal on Control and Optimization.
Besides NumPy, SciPy is required in v0.11.0.dev or later (for sp.diags).
"""

__author__ = "Christian Clason <christian.clason@uni-graz.at>"
__date__ = "June 11, 2012"

import numpy as np
from scipy import sparse as sp
from scipy.sparse.linalg import spsolve
from matplotlib import pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

# problem and solver parameters
Nh = 64         # number of grid points in space
Nt = 256        # number of grid points in time
alpha = 0.1     # penalty parameter
nu = 0.1        # diffusion coefficient
maxit = 20      # max iterations in semismooth Newton method

# uniform grids in space ([-1,1]) and time ([0,2]); the meshgrid arrays have
# shape (Nt,Nh), i.e. one row per time step
x = np.linspace(-1, 1, Nh)
t = np.linspace(0, 2, Nt)
xx, tt = np.meshgrid(x, t)

# target t*(1-|x|), flattened row by row so that the Nh spatial values of
# each time step are stored contiguously
yd = (tt * (1 - np.abs(xx))).flatten()

# mesh sizes and total number of space-time unknowns
Ns = Nh * Nt
h = x[2] - x[1]
tau = t[2] - t[1]

# one-dimensional operators: tridiagonal stiffness and mass matrices in
# space, backward difference and identity in time
D2 = sp.diags([-1 / h, 2 / h, -1 / h], [-1, 0, 1],
              shape=(Nh, Nh), format='csr')
Mx = sp.diags([h / 6, 2. / 3 * h, h / 6], [-1, 0, 1],
              shape=(Nh, Nh), format='csr')
Dt = sp.diags([-1 / tau, 1 / tau], [-1, 0],
              shape=(Nt, Nt), format='csr')
It = sp.identity(Nt, format='csr')

# space-time operators as Kronecker products (time index outermost)
Is = sp.identity(Ns, format='csr')
Ms = sp.kron(It, Mx, format='csr')
# II maps a vector with one entry per time step to the full space-time
# grid by repeating each entry Nh times
II = sp.kron(It, np.ones((Nh, 1)), format='csr')
# discrete operator: time difference applied to the mass matrix plus
# nu times the stiffness matrix
L = sp.kron(Dt, Mx, format='csr') + nu * sp.kron(It, D2, format='csr')

# initial iterate: zero state y and adjoint p, constant bound c = alpha,
# scalar multiplier l = 1
y = np.zeros(Ns)
p = np.zeros(Ns)
c = alpha * np.ones(Nt)
cq = II * c     # c with each entry repeated over the Nh spatial nodes
l = 1

# 0/1 indicator vectors of the sets where p > cq and where p < -cq
Ap = (p > cq).astype(float)
Am = (p < -cq).astype(float)

# active sets of the previous iterate, read by the Newton termination test
Ap_old = Ap
Am_old = Am

# continuation strategy: solve the regularized problem for gamma = 1, 10,
# ..., 1e11, starting each solve from the iterate of the previous one
# (float base: 10**k in the default integer type overflows on platforms
# where that type is 32 bit)
for gamma in 10.0**np.arange(12):
    print('Solving for gamma = %1.0e' % (gamma))
    print('Iter \t AS change \t |c|^2 \t\t Gradient norm \t Feasibility')

    # semismooth Newton iteration
    for it in range(maxit):
        # negative residual, stacked in the order of the unknowns (y,p,c,l);
        # the last entry is the scalar constraint tau*sum(c^2) = alpha^2
        F = - np.hstack([L.T*p - Ms*(y-yd),
                         L*y + gamma*(Ap*(p-cq) + Am*(p+cq)),
                         2*l*c + gamma*(II.T*(Am*(p+cq) - Ap*(p-cq))),
                         tau * np.sum(c**2) - alpha**2])

        # blocks of the Newton matrix that depend on the current active sets
        Lpp = gamma*sp.spdiags(Ap+Am,0,Ns,Ns,format='csr')
        Lpc = gamma*sp.spdiags(Am-Ap,0,Ns,Ns,format='csr')*II
        Lcc = sp.spdiags(2*l + gamma*(II.T*(Ap+Am)),0,Nt,Nt,format='csr')

        H = sp.bmat([[ -Ms  , L.T   , None      , None],
                    [  L    , Lpp   , Lpc       , None],
                    [  None , Lpc.T , Lcc       , 2*c.reshape((Nt,1))],
                    [  None , None  , tau*2*c , None]], format='csr')

        # Newton step for (y,p,c,l); cq is c repeated over the spatial nodes
        dx = spsolve(H,F)
        y += dx[:Ns]
        p += dx[Ns:2*Ns]
        c += dx[2*Ns:-1];  cq = II*c
        l += dx[-1]

        # active sets of the new iterate
        Ap = (p >  cq).astype(float)
        Am = (p < -cq).astype(float)

        # terminate Newton iteration if active sets no longer change;
        # absolute values are needed because an index that moves directly
        # from one active set to the other contributes -1 and +1, which
        # would cancel in a plain sum and hide the change
        change = np.abs(Ap-Ap_old) + np.abs(Am-Am_old)
        update = np.count_nonzero(change)
        ngrad = np.sqrt(np.dot(F,F))
        cviol = np.max(np.abs(np.maximum(0,p-cq)+np.minimum(0,p+cq)))
        print("%d \t %d \t\t %1.2e \t %e \t %e"
              % (it+1,update,tau*np.sum(c**2),ngrad,cviol))
        if update == 0: break

        Ap_old,Am_old = Ap,Am

    # terminate continuation if Newton iteration converged in one step
    if it == 0: break

# compute starting point for control, active sets
u = -gamma*(np.maximum(0,p-cq)+np.minimum(0,p+cq))
Ap = (-u + p > cq).astype(float)
Am = (-u + p < -cq).astype(float)
# the first termination test below compares against these starting sets,
# not against whatever the continuation loop left in Ap_old/Am_old
Ap_old,Am_old = Ap,Am

# compute optimal measure space control
print('Solving original problem')
print('Iter \t AS change \t |c|^2 \t\t Gradient norm \t Feasibility')

# semismooth Newton iteration
for it in range(maxit):
    # negative residual, stacked in the order of the unknowns (u,y,p,c,l);
    # the last entry is the scalar constraint tau*sum(c^2) = alpha^2
    F = - np.hstack([L*y - u,
                     L.T*p - Ms*(y-yd),
                     u + (Ap*(-u+p-cq) + Am*(-u+p+cq)),
                     2*l*c + II.T*(Am*(-u+p+cq) - Ap*(-u+p-cq)),
                     tau * np.sum(c**2) - alpha**2])

    # blocks of the Newton matrix that depend on the current active sets
    Lpp = sp.spdiags(Ap+Am,0,Ns,Ns,format='csr')
    Lpc = sp.spdiags(Am-Ap,0,Ns,Ns,format='csr')*II
    Lcc = sp.spdiags(2*l + II.T*(Ap+Am),0,Nt,Nt,format='csr')

    H = sp.bmat([[ -Is    , L    , None  , None    , None ],
                [  None   , -Ms  , L.T   , None    , None ],
                [  Is-Lpp , None , Lpp   , Lpc     , None],
                [  -Lpc.T , None , Lpc.T , Lcc     , 2*c.reshape((Nt,1))],
                [  None   , None , None  , tau*2*c , None]] , format='csr')

    # Newton step for (u,y,p,c,l); cq is c repeated over the spatial nodes
    dx = spsolve(H,F)
    u += dx[:Ns]
    y += dx[Ns:2*Ns]
    p += dx[2*Ns:3*Ns]
    c += dx[3*Ns:-1]; cq = II*c
    l += dx[-1]

    # active sets of the new iterate
    Ap = (-u + p >  cq).astype(float)
    Am = (-u + p < -cq).astype(float)

    # terminate Newton iteration if active sets no longer change;
    # absolute values are needed because an index that moves directly from
    # one active set to the other contributes -1 and +1, which would cancel
    # in a plain sum and hide the change
    change = np.abs(Ap-Ap_old) + np.abs(Am-Am_old)
    update = np.count_nonzero(change)
    ngrad = np.sqrt(np.dot(F,F))
    cviol = np.max(np.abs(np.maximum(0,p-cq)+np.minimum(0,p+cq)))
    print("%d \t %d \t\t %1.2e \t %e \t %e"
          % (it+1,update,tau*np.sum(c**2),ngrad,cviol))
    if update == 0: break

    Ap_old,Am_old = Ap,Am

# plot target, state, linear interpolation of optimal control
def surf(f,title):
    """Draw the grid function f as a 3d surface over (t,x) in a new figure.

    f     -- array with Nt*Nh entries; it is reshaped to (Nt,Nh) to match
             the module-level meshgrid arrays tt and xx
    title -- string used as the axes title
    """
    # add_subplot(..., projection='3d') is used instead of
    # gca(projection='3d'): recent Matplotlib releases no longer accept
    # keyword arguments in gca(), while add_subplot works on old and new ones
    ax = plt.figure().add_subplot(111, projection='3d')
    ax.plot_surface(tt,xx,f.reshape((Nt,Nh)),linewidth=0,rstride=1,cstride=1,
            cmap=cm.summer)
    ax.set_xlabel('t')
    ax.set_ylabel('x')
    ax.set_title(title)

surf(yd,'target')
surf(y,'state')
surf(u,'control (linear interpolation)')

# plot non-zero coefficients of optimal control
# (exact zeros are overwritten with NaN in place, after the surface plot of
# u above; np.nan is the spelling available in all NumPy versions)
u[u==0] = np.nan
# add_subplot instead of gca(projection=...), which recent Matplotlib
# releases no longer support
ax = plt.figure().add_subplot(111, projection='3d')
ax.scatter(tt.flatten(),xx.flatten(),u,c=u,marker='.',cmap=cm.summer)
ax.set_xlabel('t')
ax.set_ylabel('x')
ax.set_title('control (non-zero coefficients)')

# plot optimal bound c over the time grid
plt.figure()
plt.plot(t,c)
plt.xlabel('t')     # the abscissa is the time grid t
# raw string: '\s' is not a valid escape sequence in a normal string literal
plt.ylabel(r'c_\sigma(t)')
plt.title('norm bound')
plt.show()
