##### NEB.py
##### Copyright: Olli-Pekka Koistinen, Aalto University, 12.2.2020
#####
##### This function uses the nudged elastic band (NEB)
##### algorithm with a climbing image option to find a minimum energy path
##### and a saddle point between two minimum points.
#####
##### Input:
#####   pot_general        accurate potential and gradient function
#####                        (takes 'N_im' images as ndarray of shape 'N_im' x 'D',
#####                         and returns the potential energy at those images as ndarray of shape 'N_im' x 1
#####                         and the gradient of the potential energy as ndarray of shape 'N_im' x 'D')
#####   R_init             coordinates for the images on the initial path (ndarray of shape 'N_im' x 'D');
#####                        the first and last rows are taken as the two minimum points
#####   method_step        a function defining the following step during path relaxation (see, e.g., 'utils.step_QMVelocityVerlet')
#####   param_step         parameters of the path relaxation method (shape depends on 'method_step')
#####   k_par              parallel spring constant
#####   T_MEP              convergence threshold (the algorithm is stopped when
#####                        the norm of NEB force is less than this for all images)
#####   T_CI               additional convergence threshold for the climbing image
#####   T_CIon             preliminary convergence threshold after which the climbing image
#####                        mode is turned on (use 0 if CI not used at all)
#####   num_iter           maximum number of iterations
#####   visualize          if > 0: plots the true energy along a cubic spline interpolation of the final path
#####                        (and of the path when the climbing image was turned on);
#####                        requires extra evaluations, so not to be used in real applications
#####
##### Output:
#####   R                  coordinates for the images on the final path (ndarray of shape 'N_im' x 'D')
#####   E_R                energy at the images on the final path (ndarray of shape 'N_im' x 1)
#####   G_R                gradient at the images on the final path (ndarray of shape 'N_im' x 'D')
#####   i_CI               index of the climbing image among the intermediate images of the final path
#####   E_R_acc            energies of the images for each iteration
#####   normF_R_acc        norm of the NEB force acting on each intermediate image for each iteration
#####   normFCI_acc        norm of the NEB force acting on the climbing image for each iteration (0 if CI is off)
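#####   figs               list of matplotlib figures (energy surface with the paths when 'D' = 2,
#####                        energy profile along the path when 'visualize' > 0)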

import numpy as np
import utils
import matplotlib.pyplot as plt

def _plot_energy_surface(pot_general, min1, min2, Elevel):
    """Draw the energy surface of a 2D problem as a filled contour plot in figure 1.

    The plotted window covers both end points. It is padded by a quarter of their
    separation along each axis (twice that on the low side of the first axis) and
    sampled with a step of 1/20 of the padding. The surface is evaluated one grid
    column at a time with 'pot_general' and shifted by 'Elevel'.

    Returns the matplotlib figure.
    """
    scale1 = np.abs(min2[0, 0] - min1[0, 0]) / 4
    scale2 = np.abs(min2[0, 1] - min1[0, 1]) / 4
    # If the end points share a coordinate, the padding along that axis would be
    # zero and np.arange would fail on a zero step; borrow the other axis's scale.
    if scale1 == 0:
        scale1 = scale2
    if scale2 == 0:
        scale2 = scale1
    x_lo, x_hi = min(min1[0, 0], min2[0, 0]), max(min1[0, 0], min2[0, 0])
    y_lo, y_hi = min(min1[0, 1], min2[0, 1]), max(min1[0, 1], min2[0, 1])
    X1, X2 = np.meshgrid(
        np.arange(x_lo - 2 * scale1, x_hi + scale1, scale1 / 20),
        np.arange(y_lo - scale2, y_hi + scale2, scale2 / 20),
    )
    E_true = np.zeros(X1.shape)
    # evaluate column by column to keep each call to pot_general small
    for j in range(X1.shape[1]):
        E_col, _ = pot_general(np.column_stack((X1[:, j], X2[:, j])))
        E_true[:, j] = E_col[:, 0] - Elevel
    fig = plt.figure(1)
    plt.contourf(X1, X2, E_true, 100)
    plt.axis('equal')
    plt.axis('tight')
    plt.jet()
    plt.colorbar()
    return fig


def _spline_energy(pot_general, R_path, s_eval, Elevel):
    """Return the true energy along a cubic spline through the images of 'R_path'.

    The images are parameterized uniformly on [0, 1], and the spline is sampled
    at 's_eval'. The energies (shape len(s_eval) x 1) are shifted by 'Elevel'.
    """
    from scipy.interpolate import CubicSpline
    n_images = R_path.shape[0]
    spline = CubicSpline(np.arange(0, n_images) / (n_images - 1), R_path)
    E_spline, _ = pot_general(spline(s_eval))
    return E_spline - Elevel


def NEB(pot_general, R_init, method_step, param_step=0.01, k_par=1.0, T_MEP=0.1, T_CI=0.1, T_CIon=0.0, num_iter=10000, visualize=0):
    """Relax a path with the nudged elastic band method, optionally with a climbing image.

    Parameters
    ----------
    pot_general : callable
        Takes an ndarray of images (N x D). Returns (energies N x 1, gradients N x D).
        Both returns must be two-dimensional, even for a single image.
    R_init : ndarray, shape (N_im, D)
        Initial path. The first and last rows are taken as the two minimum
        points. Their energies are evaluated only once, so 'method_step' is
        presumably expected to keep them fixed.
    method_step : callable
        Called as ``method_step(R, F_R, param_step, F_R_old, V_old, zeroV)``.
        Returns ``(R_new, V_new)``.
    param_step : parameters forwarded to 'method_step'.
    k_par : float
        Parallel spring constant, forwarded to ``utils.force_NEB``.
    T_MEP : float
        Convergence threshold on the max norm of the NEB force on intermediate images.
    T_CI : float
        Additional convergence threshold on the climbing-image force norm.
    T_CIon : float
        Climbing-image mode is turned on once the max NEB force norm drops below
        this value. Use 0 to never turn it on.
    num_iter : int
        Maximum number of relaxation steps.
    visualize : int
        If > 0, plot the true energy along cubic-spline interpolations of the
        final path (and of the path when CI was turned on). This costs extra
        evaluations of 'pot_general'.

    Returns
    -------
    tuple
        A tuple ``(R, E_R, G_R, i_CI, E_R_acc, normF_R_acc, normFCI_acc, figs)``.
        Energies are relative to the energy of the first image. 'figs' is a list
        of the matplotlib figures created: the energy surface when D == 2 and the
        energy profile when visualize > 0. If 'pot_general' returns energies with
        fewer than two dimensions, an error is printed and None is returned.
    """
    figs = []

    ###
    ### THIS INFORMATION IS ASSUMED TO BE KNOWN BEFORE BEGINNING
    ###

    # number of images on the path and dimension of the space (scalars):
    N_im = R_init.shape[0]
    D = R_init.shape[1]
    # minimum point 1 (ndarray of shape 1 x 'D') and its energy/gradient:
    min1 = R_init[:1, :]
    E_min1, G_min1 = pot_general(min1)
    if E_min1.ndim < 2:
        print('ERROR: Modify your energy function so that it returns two-dimensional ndarrays (of shape ''N_im'' x 1 and ''N_im'' x ''D''), even if there is only one image in the input (''N_im'' = 1)!')
        return
    # minimum point 2 (ndarray of shape 1 x 'D') and its energy/gradient:
    min2 = R_init[-1:, :]
    E_min2, G_min2 = pot_general(min2)
    # zero level of energy is set to minimum point 1:
    Elevel = E_min1
    E_min1 = E_min1 - Elevel
    E_min2 = E_min2 - Elevel

    ###
    ### THE ALGORITHM BEGINS HERE
    ###

    # coordinates, energies ('N_im' x 1) and gradients ('N_im' x 'D') of the images;
    # the end-point rows are filled once here, the intermediate rows every iteration:
    R = R_init.copy()
    E_R = np.vstack((E_min1, np.zeros((N_im - 2, 1)), E_min2))
    G_R = np.vstack((G_min1, np.zeros((N_im - 2, D)), G_min2))

    # in case of 2D space, plot the energy surface:
    if D == 2:
        figs.append(_plot_energy_surface(pot_general, min1, min2, Elevel))

    # path (and its energies) at the moment the climbing image was turned on:
    R_CIon = np.ndarray(shape=(0, D))
    E_R_CIon = None
    # per-iteration histories (one column / entry per iteration):
    E_R_acc = np.ndarray(shape=(N_im, 0))
    normF_R_acc = np.ndarray(shape=(N_im - 2, 0))
    normFCI_acc = np.ndarray(shape=(0))

    # climbing image mode is off in the beginning:
    CI_on = 0
    # velocities of the intermediate images (given as an output of the previous step):
    V_old = np.zeros((N_im - 2, D))
    # NEB forces on the intermediate images on the previous path:
    F_R_old = np.zeros((N_im - 2, 1))
    # indicator that the step method should start from zero velocity:
    zeroV = 1

    for iters in range(num_iter + 1):

        # calculate energy and gradient on the new path:
        E_R[1:-1, :], G_R[1:-1, :] = pot_general(R[1:-1, :])
        E_R[1:-1, :] = E_R[1:-1, :] - Elevel
        F_R, normFCI, i_CI = utils.force_NEB(R, E_R, G_R, k_par, CI_on)
        normF_R = np.linalg.norm(F_R, axis=1, keepdims=True)

        # turn climbing image mode on and correct the NEB force accordingly if sufficiently relaxed:
        if CI_on <= 0 and np.max(normF_R) < T_CIon:
            CI_on = 1
            R_CIon = R.copy()
            E_R_CIon = E_R.copy()
            F_R, normFCI, i_CI = utils.force_NEB(R, E_R, G_R, k_par, CI_on)
            normF_R = np.linalg.norm(F_R, axis=1, keepdims=True)
            # restart the step method from rest after the force definition changed
            zeroV = 1
            print('Climbing image turned on after {:g} iterations.\n'.format(iters))

        E_R_acc = np.hstack((E_R_acc, E_R))
        normF_R_acc = np.hstack((normF_R_acc, normF_R))
        normFCI_acc = np.hstack((normFCI_acc, normFCI))

        # stop relaxation if converged (only after CI is on, when CI is requested):
        if (T_CIon <= 0 or CI_on > 0) and np.max(normF_R) < T_MEP and normFCI < T_CI:
            print('Stopped relaxation: converged after {:g} iterations ({:g} image evaluations).\n'.format(iters, (N_im - 2) * (iters + 1)))
            break

        # stop relaxation if maximum number of iterations reached:
        if iters == num_iter:
            print('Stopped relaxation: maximum number of iterations ({:g}) reached.\n'.format(iters))
            break

        # in case of 2D space, plot the path:
        if D == 2:
            plt.plot(R[:, 0], R[:, 1], 'yo', markerfacecolor='y')

        # move the path one step along the NEB force according to the chosen method:
        R_new, V_old = method_step(R, F_R, param_step, F_R_old, V_old, zeroV)
        zeroV = 0

        # accept the step and continue relaxation:
        R = R_new.copy()
        F_R_old = F_R.copy()

    # in case of 2D space, plot the relaxed path:
    if D == 2:
        plt.figure(1)
        if CI_on > 0:
            plt.plot(R_CIon[:, 0], R_CIon[:, 1], 'ro', markerfacecolor='r')
            plt.plot(R[:, 0], R[:, 1], 'ko', markerfacecolor='k')
        else:
            plt.plot(R[:, 0], R[:, 1], 'ro', markerfacecolor='r')

    # visualize the true energy along the spline interpolation of the final
    # path and the path when the climbing image mode was set on
    # (requires large amount of extra evaluations, so not to be used in real applications)
    if visualize > 0:
        fig2 = plt.figure(2)
        csr = np.arange(0, N_im * 10) / (10 * N_im - 1)
        image_axis = csr * (N_im - 1) + 1
        if CI_on > 0:
            E_spline_CIon = _spline_energy(pot_general, R_CIon, csr, Elevel)
            plt.plot(image_axis, E_spline_CIon, 'r', linewidth=2)
            plt.plot(np.arange(1, N_im + 1), E_R_CIon, 'o', markeredgecolor='r', markerfacecolor='r', label='Path when CI turned on')
        E_spline_final = _spline_energy(pot_general, R, csr, Elevel)
        plt.plot(image_axis, E_spline_final, 'b', linewidth=2)
        plt.plot(np.arange(1, N_im + 1), E_R, 'o', markeredgecolor='b', markerfacecolor='b', label='Final path')
        plt.title('True energy along cubic spline interpolation of the path')
        plt.xlabel('image number')
        plt.legend()
        figs.append(fig2)

    return R, E_R, G_R, i_CI, E_R_acc, normF_R_acc, normFCI_acc, figs

