##### dimer2.py (performs initial rotations)
##### Copyright: Olli-Pekka Koistinen, Aalto University, 9.7.2020
#####
##### This function uses the dimer method to converge to a saddle point,
##### starting from somewhere inside the convergence area.
##### A rotation step rotates the dimer (a pair of images) towards its minimum energy orientation
##### to find the lowest curvature mode of the potential energy.
##### A translation step moves the dimer towards the saddle point by inverting the
##### force component in the direction of the dimer.
#####
##### Input:
#####   pot_general        accurate potential and gradient function
#####                        (takes 'N_im' images as ndarray of shape 'N_im' x 'D',
#####                         and returns the potential energy at those images as ndarray of shape 'N_im' x 1
#####                         and the gradient of the potential energy as ndarray of shape 'N_im' x 'D')
#####   R_init             coordinates of the middle point of the initial dimer (ndarray of shape 1 x 'D')
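#####   orient_init        vector along the direction of the initial dimer (ndarray of shape 1 x 'D'; normalized to unit length internally;
#####                        if it has fewer than 'D' components, e.g. shape 1 x 0, a random direction is drawn instead)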
#####   method_rot         a function defining the rotation step (see, e.g., 'utils_dimer.rot_iter_lbfgsext')
#####   method_trans       a function defining the translation step (see, e.g., 'utils_dimer.trans_iter_lbfgs')
#####   param_trans        parameters of the translation method (shape depends on 'method_trans')
#####   dimer_sep          dimer separation (distance from the middle point of the dimer to the two images) [default 0.01]
#####   T_dimer            final convergence threshold (the algorithm is stopped when the absolute values of
#####                        all components of the force acting on the middle point of the dimer are less than this) [default 0.01]
#####   T_anglerot         convergence threshold for rotation angle (the dimer is
#####                        not rotated when the estimated rotation angle is less than this) [default 0.0873]
#####   num_iter_rot       maximum number of rotation iterations per translation [default 10]
#####   num_iter           maximum number of iterations [default 10000]
#####
##### Output:
#####   R                  coordinates of the middle point of the final dimer (ndarray of shape 1 x 'D')
#####   orient             unit vector along the direction of the final dimer (ndarray of shape 1 x 'D')
#####   E_R                energy at the middle point of the final dimer, relative to the energy at 'R_init' (ndarray of shape 1 x 1)
#####   G_R                gradient at the middle point of the final dimer (ndarray of shape 1 x 'D')
#####   R_all              coordinates of all observation points (ndarray of shape 'N_obs' x 'D')
#####   E_all              energies for all observation points, relative to the energy at 'R_init' (ndarray of shape 'N_obs' x 1)
#####   G_all              gradients for all observation points (ndarray of shape 'N_obs' x 'D')
#####   E_R_acc            energy at the middle point of the dimer for each iteration, relative to the energy at 'R_init' (1-D ndarray)
#####   maxF_R_acc         maximum absolute component of the force at the middle point of the dimer for each iteration (1-D ndarray)
#####   obs_initrot        number of observations required for initial rotations
#####   obs_total          total number of observations
#####   Curv_initrot       curvature after initial rotations
#####   figs               list of matplotlib figures (contains the energy-surface figure only when 'D' == 2, otherwise empty)

import numpy as np
import matplotlib.pyplot as plt
import utils_dimer
import pdb

def _plot_dimer(center, orient, dimer_sep, fmt='y-', **kwargs):
    """Draw a dimer in the current axes as a segment between its two images.

    The segment runs from 'center - dimer_sep*orient' to 'center + dimer_sep*orient',
    using the first two coordinates of the 1 x 'D' arrays 'center' and 'orient'.
    Extra keyword arguments (e.g. 'linewidth') are passed on to 'plt.plot'.
    """
    x = [center[0, 0] - dimer_sep * orient[0, 0], center[0, 0] + dimer_sep * orient[0, 0]]
    y = [center[0, 1] - dimer_sep * orient[0, 1], center[0, 1] + dimer_sep * orient[0, 1]]
    # lowercase property names: matplotlib >= 3.5 no longer lower-cases mixed-case kwargs
    plt.plot(x, y, fmt, markerfacecolor='y', **kwargs)


def dimer(pot_general, R_init, orient_init, method_rot, method_trans, param_trans, dimer_sep=0.01, T_dimer=0.01, T_anglerot=0.0873, num_iter_rot=10, num_iter=10000):
    """Converge a dimer towards a saddle point of 'pot_general' with the dimer method.

    Each outer iteration:
      1. evaluates image 1 at R + dimer_sep*orient;
      2. rotates the dimer with 'method_rot';
      3. translates it with 'method_trans'.
    Rotation runs for up to max('D', 'num_iter_rot') iterations in the first outer
    iteration and up to 'num_iter_rot' afterwards. The loop stops when all
    |force components| at the middle point are below 'T_dimer', or after 'num_iter'
    iterations. Energies are reported relative to the energy at 'R_init'.

    Args:
        pot_general: callable mapping an 'N_im' x 'D' ndarray to (energies 'N_im' x 1,
            gradients 'N_im' x 'D').
        R_init: middle point of the initial dimer, ndarray of shape 1 x 'D'.
        orient_init: direction of the initial dimer (normalized here). If None or
            with fewer than 'D' components, a random direction is drawn.
        method_rot, method_trans, param_trans: rotation/translation step functions
            and translation parameters (see 'utils_dimer').
        dimer_sep, T_dimer, T_anglerot, num_iter_rot, num_iter: see the file header.

    Returns:
        The tuple (R, orient, E_R, G_R, R_all, E_all, G_all, E_R_acc, maxF_R_acc,
        obs_initrot, obs_total, Curv_initrot, figs). Returns None after printing an
        error if 'pot_general' does not return two-dimensional ndarrays.
    """

    ###
    ### THIS INFORMATION IS ASSUMED TO BE KNOWN BEFORE BEGINNING
    ###

    # dimension of the space (scalar):
    D = R_init.shape[1]
    # if 'orient_init' is missing or has fewer than 'D' components, draw a random direction:
    if orient_init is None or np.size(orient_init) < D:
        orient_init = np.random.normal(size=(1, D))
    orient_init = np.reshape(orient_init, (1, D))
    # normalize (creates a new array, so the caller's input is not modified):
    orient_init = orient_init / np.sqrt(np.sum(np.square(orient_init)))

    ###
    ### THE ALGORITHM BEGINS HERE
    ###

    # coordinates of the middle point of the dimer:
    R = R_init.copy()
    # unit vector along the direction of the dimer:
    orient = orient_init.copy()
    # energy and gradient at the middle point of the dimer:
    E_R, G_R = pot_general(R)
    # np.ndim also works for plain Python scalars, which have no '.ndim' attribute
    if np.ndim(E_R) < 2:
        print('ERROR: Modify your energy function so that it returns two-dimensional ndarrays (of shape ''N_im'' x 1 and ''N_im'' x ''D''), even if there is only one image in the input (''N_im'' = 1)!')
        return
    # set zero level of biased potential to the energy of the middle point of the initial dimer:
    Elevel = E_R[0, 0]

    def pot_biased(X):
        # biased potential with zero level at 'Elevel'
        return utils_dimer.subtract_Elevel(pot_general, X, Elevel)

    E_R = E_R - Elevel

    # coordinates of all observation points:
    R_all = R.copy()
    # energy for all observation points:
    E_all = E_R.copy()
    # gradient for all observation points:
    G_all = G_R.copy()
    # vector gathering the energy at the middle point of the dimer for each iteration:
    E_R_acc = np.ndarray(shape=(0))
    # vector gathering the maximum component of the force at the middle point of the dimer for each iteration:
    maxF_R_acc = np.ndarray(shape=(0))
    obs_initrot = 0
    Curv_initrot = 0

    # rotational force of the previous rotation iteration:
    rotinfo = {'F_rot_old': np.zeros((1, D))}
    # modified rotational force of the previous rotation iteration (in conjugated gradients method):
    rotinfo['F_modrot_old'] = np.zeros((1, D))
    # unit vector perpendicular to 'orient' within the rotation plane of the previous rotation iteration (in conjugated gradients method):
    rotinfo['orient_rot_oldplane'] = np.zeros((1, D))
    # number of conjugated rotation iterations (in conjugated gradients method):
    rotinfo['cgiter_rot'] = 0
    # maximum number of conjugated rotation iterations before resetting the conjugate directions (in conjugated gradients method):
    rotinfo['num_cgiter_rot'] = D
    # change of orientation in m previous rotation iterations (in L-BFGS):
    rotinfo['deltaR_mem'] = np.ndarray(shape=(0, D))
    # change of rotational force in m previous rotation iterations excluding the last one (in L-BFGS):
    rotinfo['deltaF_mem'] = np.ndarray(shape=(0, D))
    # maximum number of previous rotation iterations kept in memory (in L-BFGS):
    rotinfo['num_lbfgsiter_rot'] = D
    # gradient at image 1 that 'method_rot' may hand back (empty when not provided):
    rotinfo['G1'] = np.ndarray(shape=(0, D))
    transinfo = {'potential': pot_biased}
    # translational force of the previous translation iteration (in conjugated gradients method):
    transinfo['F_trans_old'] = np.zeros((1, D))
    # modified translational force of the previous translation iteration (in conjugated gradients method):
    transinfo['F_modtrans_old'] = np.zeros((1, D))
    # velocity of the middle point of the dimer in the previous translation iteration (in quick-min velocity-Verlet):
    transinfo['V_old'] = np.zeros((1, D))
    # indicator if zero velocity used (in quick-min velocity-Verlet):
    transinfo['zeroV'] = 1
    # number of conjugated translation iterations (in conjugated gradients method):
    transinfo['cgiter_trans'] = 0
    # maximum number of conjugated translation iterations before resetting the conjugate directions (in conjugated gradients method):
    transinfo['num_cgiter_trans'] = D
    # change of location in m previous translation iterations (in L-BFGS):
    transinfo['deltaR_mem'] = np.ndarray(shape=(0, D))
    # change of translational force in m previous translation iterations excluding the last one (in L-BFGS):
    transinfo['deltaF_mem'] = np.ndarray(shape=(0, D))
    # maximum number of previous translation iterations kept in memory (in L-BFGS):
    transinfo['num_lbfgsiter_trans'] = D

    figs = []
    # in case of 2D space, plot the energy surface (on a fixed grid):
    if D == 2:
        X1, X2 = np.meshgrid(np.arange(-2, 0.5, 0.01), np.arange(-0.4, 1.8, 0.01))
        E_true = np.zeros(X1.shape)
        for j in range(X1.shape[1]):
            E_truej, _ = pot_biased(np.hstack((X1[:, j][np.newaxis].T, X2[:, j][np.newaxis].T)))
            E_true[:, j] = E_truej[:, 0]
        fig1 = plt.figure(1)
        plt.contourf(X1, X2, E_true, 100)
        plt.axis('equal')
        plt.axis('tight')
        plt.jet()
        plt.colorbar()
        figs.append(fig1)

    for ind_iter in range(num_iter + 1):

        # in case of 2D space, plot the dimer:
        if D == 2:
            _plot_dimer(R, orient, dimer_sep)

        # stop the algorithm if converged:
        E_R_acc = np.hstack((E_R_acc, E_R[0, 0]))
        maxF_R = np.max(np.abs(G_R))
        maxF_R_acc = np.hstack((maxF_R_acc, maxF_R))
        if maxF_R < T_dimer:
            print('Stopped the algorithm: converged after {:g} iterations ({:g} image evaluations).\n'.format(ind_iter,E_all.shape[0]))
            break

        # stop the algorithm if maximum number of iterations reached:
        if ind_iter == num_iter:
            print('Stopped the algorithm: maximum number of iterations ({:g}) reached.\n'.format(ind_iter))
            break

        # evaluate energy and gradient at image 1:
        R1 = R + dimer_sep * orient
        E1, G1 = pot_biased(R1)
        R_all = np.vstack((R_all, R1))
        E_all = np.vstack((E_all, E1))
        G_all = np.vstack((G_all, G1))

        # if necessary, rotate the dimer and re-evaluate energy and gradient at image 1
        # (the first outer iteration allows at least 'D' rotation iterations):
        num_iter_rot2 = max(D, num_iter_rot) if ind_iter < 1 else num_iter_rot
        # reset the curvature estimate so that a value from the previous location is
        # never reused (e.g. when no rotation iteration runs because num_iter_rot == 0):
        Curv = None
        for ind_iter_rot in range(1, num_iter_rot2 + 1):
            orient_old = orient.copy()
            orient, Curv, R_obs, E_obs, G_obs = method_rot(R, orient, np.vstack((G_R, G1)), pot_biased, dimer_sep, T_anglerot, 1, rotinfo)
            if R_obs.shape[0] < 1:
                break
            R_all = np.vstack((R_all, R_obs))
            E_all = np.vstack((E_all, E_obs))
            G_all = np.vstack((G_all, G_obs))
            # in case of 2D space, plot the test dimer and the rotated dimer:
            if D == 2:
                plt.plot([2 * R[0, 0] - R_obs[0, 0], R_obs[0, 0]], [2 * R[0, 1] - R_obs[0, 1], R_obs[0, 1]], 'y-', markerfacecolor='y')
                _plot_dimer(R, orient, dimer_sep)
            # clip: rounding can push the dot product of unit vectors slightly past +-1,
            # where arccos returns NaN and the angle test would never succeed
            cos_angle = np.clip(np.dot(orient[0, :], orient_old[0, :]), -1.0, 1.0)
            if ind_iter_rot == num_iter_rot2 or np.arccos(cos_angle) < T_anglerot:
                break
            if rotinfo['G1'].shape[0] < 1:
                R1 = R + dimer_sep * orient
                E1, G1 = pot_biased(R1)
                R_all = np.vstack((R_all, R1))
                E_all = np.vstack((E_all, E1))
                G_all = np.vstack((G_all, G1))
            else:
                # reuse the image-1 gradient supplied by 'method_rot' and consume it
                G1 = rotinfo['G1'].copy()
                rotinfo['G1'] = np.ndarray(shape=(0, D))
        rotinfo['deltaR_mem'] = np.ndarray(shape=(0, D))
        rotinfo['deltaF_mem'] = np.ndarray(shape=(0, D))

        # translate the dimer and re-evaluate energy and gradient at the middle point:
        if Curv is None:
            # finite-difference curvature along 'orient' from the middle point and image 1
            Curv = np.dot((-G_R[0, :] + G1[0, :]), orient[0, :]) / dimer_sep
        if ind_iter < 1:
            Curv_initrot = Curv
            obs_initrot = E_all.shape[0]
        R, R_obs, E_obs, G_obs = method_trans(R, orient, -G_R, Curv, param_trans, transinfo)
        if R_obs.shape[0] > 0:
            R_all = np.vstack((R_all, R_obs))
            E_all = np.vstack((E_all, E_obs))
            G_all = np.vstack((G_all, G_obs))
            # in case of 2D space, plot the test dimer:
            if D == 2:
                _plot_dimer(R_obs, orient, dimer_sep)
        # in case of 2D space, plot the dimer:
        if D == 2:
            _plot_dimer(R, orient, dimer_sep)
        E_R, G_R = pot_biased(R)
        R_all = np.vstack((R_all, R))
        E_all = np.vstack((E_all, E_R))
        G_all = np.vstack((G_all, G_R))

    # in case of 2D space, emphasize the final dimer and plot all observation points:
    if D == 2:
        _plot_dimer(R, orient, dimer_sep, 'r-', linewidth=2)
        plt.plot(R_all[:, 0], R_all[:, 1], 'r+', markerfacecolor='r')

    obs_total = E_all.shape[0]

    return R, orient, E_R, G_R, R_all, E_all, G_all, E_R_acc, maxF_R_acc, obs_initrot, obs_total, Curv_initrot, figs
