##### demo_muller_brown_sNEB.py
##### Copyright: Olli-Pekka Koistinen, Aalto University, 12.2.2020
#####
##### This script shows how to use 'sNEB.py' in a Muller-Brown example.

import numpy as np
import matplotlib.pyplot as plt
import utils
import sNEB
import muller_brown

# ---- Problem definition ----------------------------------------------------

# Potential energy function handed to sNEB.
pot_general = muller_brown.muller_brown

# The two minimum points that the path connects (each stored as a 1x2 array).
min1 = np.array([[-0.5582, 1.4417]])
min2 = np.array([[0.6235, 0.0280]])

# Number of images on the path.
N_im = 10

# Initial path between the two minima.
R_init = utils.initialize_path_linear(min1, min2, N_im)

# ---- Optimizer settings ----------------------------------------------------

# Step method. A function object from `utils` is passed here.
# NOTE(review): the original comment listed "qmVV" or "simple" as examples;
# presumably these correspond to step functions in `utils` -- confirm there.
method_step = utils.step_QMVelocityVerlet

# Parameters for the step method (the time step in case of qmVV).
param_step = 0.1

# Parallel and perpendicular spring constants.
k_par = 1.0
k_perp = 1.0

# ---- Convergence settings --------------------------------------------------

# Convergence threshold for the maximum norm of the force.
T_MEP = 0.01

# Additional convergence threshold for the climbing image.
T_CI = 0.01

# Preliminary convergence threshold after which the climbing image mode is
# turned on (0 if not used at all).
T_CIon = 0.01

# Maximum number of iterations.
num_iter = 10000

# Visualization option (requires a large amount of extra evaluations, so not
# to be used in real applications).
# NOTE(review): this flag is assigned but never referenced again in this
# script, including in the sNEB.sNEB call -- confirm whether it should be
# passed.
visualize = 1

# Run sNEB on the configured problem. The call returns eight values; the
# three `*_acc` histories are the ones plotted further down in this script.
# NOTE(review): the `visualize` flag defined earlier in this script is not
# among the arguments passed here, although a `figs` value is returned --
# confirm against the sNEB.sNEB signature whether it should be forwarded.
(
    R,
    E_R,
    G_R,
    i_CI,
    E_R_acc,
    normF_R_acc,
    normFCI_acc,
    figs,
) = sNEB.sNEB(
    pot_general,
    R_init,
    method_step,
    param_step,
    k_par,
    k_perp,
    T_MEP,
    T_CI,
    T_CIon,
    num_iter,
)

# Plot the histories returned by sNEB in a single figure with two
# side-by-side panels. Both histories are reduced along axis 0 and plotted
# against their second-axis index, which is labelled 'iteration'.
fig = plt.figure()

# Left panel: max and mean of the force-norm history, plus the
# climbing-image curve, all against the same iteration axis.
force_iterations = range(normF_R_acc.shape[1])
sub1 = fig.add_subplot(1, 2, 1)
sub1.set_title('Magnitude of the force on one image')
for reduce_fn, curve_label in ((np.max, 'Max'), (np.mean, 'Mean')):
    sub1.plot(force_iterations, reduce_fn(normF_R_acc, axis=0), label=curve_label)
sub1.plot(force_iterations, normFCI_acc, label='CI')
sub1.set_xlabel('iteration')
sub1.legend()

# Right panel: mean of the energy history along axis 0.
energy_iterations = range(E_R_acc.shape[1])
mean_energy = np.mean(E_R_acc, axis=0)
sub2 = fig.add_subplot(1, 2, 2)
sub2.set_title('Mean energy over the images')
sub2.plot(energy_iterations, mean_energy)
sub2.set_xlabel('iteration')

plt.show()


