function [varargout] = tf_solve(w,x,y,ss,xt,filteronly)
% TF_SOLVE - Solve TP regression problem by filtering
%
% Syntax:
%   [e,eg]                  = tf_solve(w,x,y,ss)
%   [Eft,Varft,Covft,lb,ub] = tf_solve(w,x,y,ss,xt,filteronly)
%
% In:
%   w     - Log-parameters (nu, sigma2, theta)
%   x     - Training inputs
%   y     - Training outputs
%   ss    - State space model function handle, [F,L,Qc,...] = @(x,theta) 
%   xt    - Test inputs (default: empty)
%   filteronly - Run only filter (default: false)
%
% Out (if xt is empty or not supplied):
%
%   e     - Negative log marginal likelihood
%   eg    - ... and its gradient w.r.t. w (log-parameters)
%
% Out (if xt is not empty):
%
%   Eft   - Predicted mean (column vector, one entry per element of xt)
%   Varft - Predicted marginal variance (column vector)
%   Covft - Predicted joint covariance matrix ([] if filteronly is set)
%   lb    - 95% confidence lower bound
%   ub    - 95% confidence upper bound
%
%   Covft, lb and ub are only computed when more than two outputs are
%   requested (lb/ub use TINV).
% 
% Description:
%   Consider the following TP regression [1] problem:
%
%       f ~ TP(0,k(x,x'),nu),
%     y_k = f(x_k),  k=1,2,...,n,
%
%   where k(x,x') = k_theta(x,x') + sigma2*delta(x,x').
%   The state space model is specified by the function handle 'ss' such
%   that it returns the state space model matrices
%
%     [F,L,Qc,H,Pinf,dF,dQc,dPinf] = ss(x,theta),
%
%   where theta holds the hyperparameters. See the paper [1] for details.
%   NOTE: This code is proof-of-concept, not optimized for speed.
%
%   Training and test inputs are merged and passed through UNIQUE, so
%   repeated input locations are collapsed to one point (the first
%   occurrence, i.e. a training point wins over a test point).
% 
% References:
%
%   [1] A. Solin and S. Sarkka. State space methods for efficient 
%       inference in Student-t process regression. In Proceedings of the 
%       18th International Conference on Artificial Intelligence and 
%       Statistics (AISTATS), volume 38 of JMLR W&CP, pages 885-893, 2015.
% 
% Copyright:
%   2014-2015   Arno Solin
%
%  This software is distributed under the GNU General Public
%  License (version 3 or later); please refer to the file
%  License.txt, included with the software, for details.

%% Check defaults

  % Is there test data
  if nargin < 5, xt = []; end
  
  % Is 'filteronly' set
  if nargin < 6 || isempty(filteronly), filteronly = false; end
  
  % Jitter sigma2 for added numerical stability
  jitterSigma2 = 1e-9;  
  
  
%% Figure out the correct way of dealing with the data

  % Combine observations and test points (test outputs are marked NaN)
  xall = [x(:); xt(:)];
  yall = [y(:); nan(numel(xt),1)];
    
  % Make sure the points are unique and in ascending order
  [xall,sort_ind,return_ind] = unique(xall,'first');
  yall = yall(sort_ind);
    
  % Only return test indices (positions of xt within the sorted points)
  return_ind = return_ind(end-numel(xt)+1:end);
  
  
%% Set up model

  % Log transformed parameters
  param = exp(w);
  
  % Extract values
  nu     = param(1);
  sigma2 = param(2);
  
  % Form the state space model (all eight outputs are still requested
  % from 'ss'; L, Qc and dQc are not needed below)
  [Fs,~,~,Hf,Pinf,dF2,~,dPinf2] = ss(x,param(3:end));
  
  % Dimensions: latent state and state augmented with the noise component
  ds = size(Fs,1);
  d  = ds + 1;
  
  % Augment noise model. The measurement noise is carried as an extra
  % first state with stationary variance sigma2 and no memory between
  % distinct inputs. Its transition is therefore written explicitly as 0
  % below instead of passing a -inf pole through expm (-inf*0 and inf*0
  % products turn into NaN).
  Pinf  = blkdiag(sigma2,Pinf);
  Hs    = [0 Hf];   % measures the latent function only
  H     = [1 Hf];   % measures latent function + noise
  
  % Stack derivatives: slice 1 is w.r.t. sigma2, slices 2:end w.r.t. theta.
  % Sizes are given explicitly so that a model with a single
  % hyperparameter (2-D derivative arrays) is handled as well.
  dF    = zeros(d,d,size(dF2,3)+1);
  dF(2:end,2:end,2:end) = dF2;
  dPinf = zeros(d,d,size(dPinf2,3)+1);
  dPinf(1) = 1;
  dPinf(2:end,2:end,2:end) = dPinf2;

  
%% Prediction of test inputs (filtering and smoothing)

  % Check that we are predicting
  if ~isempty(xt)

    % Number of (unique) time steps
    steps = numel(yall);
      
    % Initialize
    beta = 0;    % accumulated (unscaled) squared innovations
    neta = 0;    % number of observations processed so far
    gamma = 1;   % current Student-t scale factor
    gammas = zeros(1,steps);
    
    % Set initial state
    m = zeros(d,1);
    P = Pinf;
    
    % Allocate space for results
    MS = zeros(d,steps);
    PS = zeros(d,d,steps);
    A = zeros(d,d,steps);
    Q = zeros(d,d,steps);
    
    % Initial dt
    dt = inf;
    
    
    % ### Forward filter
    
    % The filter recursion
    for k=1:steps
        
        if (k>1)
            
            % Discrete-time solution (only for stable systems)
            dt_old = dt;
            dt = xall(k)-xall(k-1);
            
            % Should we calculate a new discretization?
            if abs(dt-dt_old) < 1e-12
                A(:,:,k) = A(:,:,k-1);
                Q(:,:,k) = Q(:,:,k-1);
            else
                % Noise state has zero transition; latent part via expm
                A(:,:,k) = blkdiag(0,expm(Fs*dt));
                Q(:,:,k) = Pinf - A(:,:,k)*Pinf*A(:,:,k)';
            end
            
            % Prediction step
            m = A(:,:,k) * m;
            P = A(:,:,k) * P * A(:,:,k)' + Q(:,:,k)*gamma;
            
        end
        
        % Update step (skipped at test points, which carry NaN)
        if ~isnan(yall(k))
            S = H*P*H';
            K = P*H'/S;
            v = yall(k,:)'-H*m;
            neta = neta+1;
            
            % Update beta (S already contains the factor gamma)
            beta = beta + v'*(S\v)*gamma;
            
            % The new scale factor
            newgamma = (nu-2+beta)/(nu-2+neta);
            
            % Update m and P
            m = m + K*v;
            P = (P - K*H*P)*newgamma/gamma;
            
            % Update factor
            gamma = newgamma;
            
        end
        
        % Store factors
        gammas(k) = gamma;
        
        % Store estimate
        MS(:,k)   = m;
        PS(:,:,k) = P;
        
    end
    
       
    % ### Backward smoother

    % Should we run the smoother?
    if ~filteronly
      
      % Allocate space for storing the smoother gain matrix
      GS = zeros(d,d,steps);

      % Rauch-Tung-Striebel smoother
      for k=steps-1:-1:1
      
        % Smoothing step (using Cholesky for stability)
        PSk = PS(:,:,k);
      
        % Pseudo-prediction
        PSkp = A(:,:,k+1)*PSk*A(:,:,k+1)'+Q(:,:,k+1)*gammas(k);
      
        % Solve the Cholesky factorization
        [Lc,notposdef] = chol(PSkp,'lower');
      
        % Numerical problems in Cholesky, retry with jitter
        if notposdef>0
            jitter = sqrt(jitterSigma2)*diag(rand(size(A,1),1));
            Lc = chol(PSkp+jitter,'lower');
        end
      
        % Continue smoothing step
        G = PSk*A(:,:,k+1)'/Lc'/Lc;
        m = MS(:,k) + G*(m-A(:,:,k+1)*MS(:,k));
      
        % The Student-t case (the Gaussian case would be
        % P = PSk + G*(P-PSkp)*G')
        P = (PSk - G*PSkp*G')*(gammas(end)/gammas(k)) + G*P*G';
      
        % Store estimate
        MS(:,k)   = m;
        PS(:,:,k) = P;
        GS(:,:,k) = G;
      
      end
  
    end

    % Estimate the joint covariance matrix if requested
    if nargout > 2 && ~filteronly
      
      % Allocate space for results
      Covft = zeros(steps);
          
      % Lower triangular: cross-covariances via products of smoother gains
      for k = 1:steps-1
        GSS = GS(:,:,k);
        for j=1:steps-k
          Covft(k+j,k) = Hs*(GSS*PS(:,:,k+j))*Hs';
          GSS = GSS*GS(:,:,k+j);
        end
      end
      
      % Marginal variances of all points for the diagonal
      varall = zeros(steps,1);
      for k = 1:steps
        varall(k) = Hs*PS(:,:,k)*Hs';
      end
      
      % Symmetrize over ALL points before picking the test points, so
      % that repeated test inputs get their variance as cross-covariance
      Covft = Covft + Covft' + diag(varall);
    
    end
  
    % These indices shall remain to be returned
    MS = MS(:,return_ind);
    PS = PS(:,:,return_ind);
    
    % Return mean
    Eft = Hs*MS;
    
    % Return variance (always computed: it is cheap and it is needed to
    % build varargout even when only one output is requested)
    Varft = zeros(size(MS,2),1);
    for k=1:size(MS,2)
        Varft(k) = Hs*PS(:,:,k)*Hs';
    end
    
    % Return values
    varargout = {Eft(:),Varft(:)};
    
    % Also return joint covariance and upper/lower 95% bounds
    if nargout > 2
        
        if ~filteronly
            
            % Joint covariance of the test points
            Covft = Covft(return_ind,return_ind);
            
            % Smoothed estimates are conditioned on every observation
            % that went through the filter
            dof = nu + neta;
            
        else
            
            Covft = [];
            
            % Filtered estimates only see observations before each xt
            dof = nu + arrayfun(@(xi) sum(x(:)<xi),xt);
            
        end
        
        % The bounds
        dof = dof(:);
        lb = Eft(:) + tinv(0.025,dof).*sqrt(Varft(:));
        ub = Eft(:) + tinv(0.975,dof).*sqrt(Varft(:));
        varargout = {Eft(:),Varft(:),Covft,lb(:),ub(:)};
        
    end

  end
  
%% Evaluate negative log marginal likelihood and its gradient

  if isempty(xt)

    % Size of inputs (gradient is taken w.r.t. sigma2 and theta here; the
    % nu-derivative is added in closed form after the loop)
    nparam = numel(param)-1;
    steps = numel(yall);
      
    % Allocate space for results
    edata = 0;
    beta = 0;
    gdata = zeros(1,nparam);
    gamma = 1;
    dgamma = zeros(1,nparam);
    newdgamma = dgamma;
    
    % Set up
    Zs = zeros(ds);
    m  = zeros(d,1);
    P  = Pinf;
    dm = zeros(d,nparam);
    dP = dPinf;
    dt = -inf;
    
    % Allocate space for expm results. Rows/columns belonging to the noise
    % state (1 and d+1) stay zero: zero transition and zero derivative.
    AA = zeros(2*d,2*d,nparam);
    lat = [2:d, d+2:2*d];
    
    % Loop over all observations
    for k=1:steps
        
        % The previous time step
        dt_old = dt;
        
        % The time discretization step length
        if (k>1)
            dt = xall(k)-xall(k-1);
        else
            dt = 0;
        end
        
        % Loop through all parameters (Kalman filter prediction step)
        for j=1:nparam
            
            % Should we recalculate the matrix exponential?
            if abs(dt-dt_old) > 1e-9
                
                % The first matrix for the matrix fraction decomposition
                % (latent states only)
                FF = [ Fs                Zs;
                      dF(2:end,2:end,j)  Fs];
                
                % Solve the matrix exponential
                AA(lat,lat,j) = expm(FF*dt);
                
            end
            
            % Solve the differential equation
            foo     = AA(:,:,j)*[m; dm(:,j)];
            mm      = foo(1:d,:);
            dm(:,j) = foo(d+(1:d),:);
            
            % The discrete-time dynamical model
            if (j==1)
                A  = AA(1:d,1:d,j);
                Q  = Pinf - A*Pinf*A';
                PP = A*P*A' + Q*gamma;
            end
            
            % The derivatives of A and Q
            dA = AA(d+1:end,1:d,j);
            dAPinfAt = dA*Pinf*A';
            dQ = dPinf(:,:,j) - dAPinfAt - A*dPinf(:,:,j)*A' - dAPinfAt';
            
            % The derivatives of P
            dAPAt = dA*P*A';
            dP(:,:,j) = dAPAt + A*dP(:,:,j)*A' + dAPAt' + dQ*gamma + Q*dgamma(j); 
        end
        
        % Set predicted m and P
        m = mm;
        P = PP;
        
        % Start the filter update step and precalculate variables
        S = H*P*H';
        [LS,notposdef] = chol(S,'lower');
        
        % If matrix is not positive definite, add jitter
        if notposdef>0
            jitter = jitterSigma2*diag(rand(size(S,1),1));
            [LS,notposdef] = chol(S+jitter,'lower');
            
            % Return nan to the optimizer
            if notposdef>0
                varargout = {nan*edata,[nan nan*gdata]};
                return;
            end
        end
        
        % Continue update
        HtiS = H'/LS/LS';
        iS   = eye(size(S))/LS/LS';
        K    = P*HtiS;
        v    = yall(k) - H*m;
        vtiS = v'/LS/LS';

        % The new scale factor
        beta = beta + gamma*(vtiS*v);
        newgamma = gamma/(nu+k-2)*((nu+k-1-2)+vtiS*v);
        
        % Precalculate
        PmKSKt = P - K*S*K';
        
        % Loop through all parameters
        for j=1:nparam
            
            % Innovation mean and covariance derivative
            dv = -H*dm(:,j);
            dS = H*dP(:,:,j)*H';
            
            % Evaluate the derivative
            gdata(j) = gdata(j) ...
                + .5*sum(iS(:).*dS(:)) ...
                + (nu+k)/((nu+k-1)-2+(vtiS*v)) * ...
                  (vtiS*dv - .5*vtiS*dS*vtiS');
            
            % Derivative of scaling factor
            newdgamma(j) = dgamma(j)/(nu+k-2)*((nu+k-1-2)+vtiS*v) + ...
               gamma/(nu+k-2)*(2*vtiS*dv - vtiS*dS*vtiS');
            
            
            % Kalman filter update step derivatives
            dK        = dP(:,:,j)*HtiS - P*HtiS*dS/LS/LS';
            dm(:,j)   = dm(:,j) + dK*v - K*H*dm(:,j);
            dKSKt     = dK*S*K';
            dP(:,:,j) = (dP(:,:,j) - dKSKt - K*dS*K' - dKSKt')*newgamma/gamma ...
                        + 1/gamma*(newdgamma(j)-newgamma/gamma*dgamma(j))*PmKSKt;
                    
        end
        
        % Evaluate the energy (Student-t)
        edata = edata ... 
                + .5*log((nu-2)*pi) ...
                + sum(log(diag(LS))) ... % .5*log(det(S))
                + gammaln((nu+k-1)/2) ...
                - gammaln((nu+k)/2) ...
                + .5*log((nu+k-1)-2) - .5*log(nu-2) ...
                + .5*(nu+k)*log(1+(vtiS*v)/((nu+k-1)-2));
        
        % Finish filter update step
        m = m + K*v;
        P = PmKSKt*newgamma/gamma;
        
        % Update factor
        gamma = newgamma;
        dgamma = newdgamma;
        
    end
    
    % Compute the derivative w.r.t. nu. The count used here is the number
    % of observations that were actually filtered ('steps'), so that the
    % gradient matches the energy accumulated above even when repeated
    % inputs were collapsed.
    gdata = [nan gdata];
    
    gdata(1) = ...
          steps/(2*(nu-2)) ...
        - .5*psi((nu+steps)/2) ...
        + .5*psi(nu/2) ... 
        + .5*log(1 + beta/(nu-2)) ...
        - beta*(nu+steps)/2/(nu-2)/(beta+nu-2);
    
    % Account for log-scale
    gdata = gdata.*exp(w);
    
    % Return negative log marginal likelihood and gradient
    varargout = {edata,gdata};

  end
  
  
  
