/*
 y = forward(p,fi,lm,b,b0)

    This file is part of SeDuMi 1.03BETA
    Copyright (C) 1999 Jos F. Sturm
    Dept. Quantitative Economics, Maastricht University, the Netherlands.
    Affiliations up to SeDuMi 1.02 (AUG1998):
      CRL, McMaster University, Canada.
      Supported by the Netherlands Organization for Scientific Research (NWO).
  
    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.
  
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
  
    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

*/

#include <math.h>
#include <string.h>
#include "mex.h"
#include "blksdp.h"

/*  y = forward(p,fi,lm,b,b0) */
#define Y_OUT plhs[0]

#define P_IN prhs[0]
#define FI_IN prhs[1]
#define LM_IN prhs[2]
#define B_IN prhs[3]
#define B0_IN prhs[4]

/* ************************************************************
   PROCEDURE innerfw - forward solve with the dense-column L-factor
   INPUT
     fi - order m+1 sparse vector; only its pattern (fi.jc, fi.ir) is
          read. Consecutive nonzero positions ir[k] < ir[k+1] delimit the
          segments processed below. The last nonzero position only marks
          the end of the final segment; it is not itself updated.
     p  - length m vector: dense column. p[ir[k]] is used as a divisor
          for every nonzero position of fi except the last, so it should
          be nonzero there.
     lm - length m+1 vector: the dense bottom row of the L-factor,
          L = [L11, 0; lm'], with lm[m] on the diagonal.
          L11 is unit-diagonal, and mostly the identity matrix, except
          where we pivot on the dense column (fi.ir positions): there
          it's p / p(i).
     m  - order of p.
   UPDATED
     y  - on input, the order m+1 rhs, on output, L*yOUT = yIN.
   NOTE(review): rows of y[0..m-1] at or beyond the last nonzero of fi
     are left untouched; presumably callers always store a nonzero at
     position m of fi as an end marker -- TODO confirm with caller.
   ************************************************************ */
void innerfw(double *y,jcir fi,const double *p,const double *lm,const int m)
{
  int k, row, stop;
  double ratio, ypiv;

  /* ratio = (original y at previous pivot) / p(previous pivot);
     there is no previous pivot before the first one, hence 0. */
  ratio = 0.0;
  k = fi.jc[0];
  row = fi.ir[k];
  while(++k < fi.jc[1]){
    stop = fi.ir[k];
    /* Pivot row: eliminate using the previous pivot's ratio, then
       derive this pivot's ratio from its value BEFORE the update. */
    ypiv = y[row];
    y[row] = ypiv - ratio * p[row];
    ratio = ypiv / p[row];
    /* Non-pivot rows strictly between this pivot and the next one. */
    while(++row < stop)
      y[row] -= ratio * p[row];
  }
  /* Bottom row of L: dense lm' with diagonal entry lm[m]. */
  y[m] = (y[m] - realdot(lm,y,m)) / lm[m];
}


/* ============================================================
   MAIN: MEXFUNCTION
   ============================================================ */
/* ************************************************************
   PROCEDURE isfulldouble - nonzero iff A is a non-sparse double array,
     i.e. (for non-empty A) mxGetPr(A) addresses mxGetM(A)*mxGetN(A)
     contiguous doubles in column-major order.
   ************************************************************ */
static int isfulldouble(const mxArray *A)
{
  return mxIsDouble(A) && !mxIsSparse(A);
}

/* ************************************************************
   PROCEDURE mexFunction - Entry for Matlab
   y = forward(p,fi,lm,b,b0)
   INPUT
     p  - dense double vector of length m (row or column shape).
     fi - (m+1) x 1 sparse vector; only its sparsity pattern is used
          (see innerfw).
     lm - dense double vector of length m+1.
     b  - dense double m x N matrix of right-hand sides.
     b0 - dense double 1 x N row, stacked below b.
   OUTPUT
     y  - (m+1) x N; column j is [b(:,j); b0(j)] after innerfw has
          been applied to it in place.
   p, lm, b and b0 are read directly through mxGetPr, so they must be
   full double arrays; a sparse or non-double argument would make the
   reads below run past the end of its data, so it is rejected.
   ************************************************************ */
void mexFunction(const int nlhs, mxArray *plhs[],
  const int nrhs, const mxArray *prhs[])
{
  int m,j,N;
  const double *p,*lm,*b,*b0;
  jcir fi;
  double *y, *yj;

/* ------------------------------------------------------------
   Check for proper number of arguments
   ------------------------------------------------------------ */
  if(nrhs < 5)
    mexErrMsgTxt("forward requires 5 input arguments.");
  if(nlhs > 1)
    mexErrMsgTxt("forward generates 1 output argument.");
/* ------------------------------------------------------------
   The dense operands are accessed as raw double arrays below.
   ------------------------------------------------------------ */
  if(!isfulldouble(P_IN) || !isfulldouble(LM_IN)
     || !isfulldouble(B_IN) || !isfulldouble(B0_IN))
    mexErrMsgTxt("forward: p, lm, b and b0 must be full double arrays.");
 /* ------------------------------------------------------------
    Get input vectors p, fi, lm, rhs m x N matrix  b and 1xN row  b0.
    ------------------------------------------------------------ */
  m = mxGetM(P_IN) * mxGetN(P_IN);
  p = mxGetPr(P_IN);
  if(mxGetM(LM_IN) * mxGetN(LM_IN) != m+1)
    mexErrMsgTxt("lm size mismatch.");
  lm = mxGetPr(LM_IN);
  b = mxGetPr(B_IN);
  if(mxGetM(B_IN) != m)
    mexErrMsgTxt("b size mismatch.");
  N = mxGetN(B_IN);
  b0 = mxGetPr(B0_IN);
  if(mxGetN(B0_IN) != N || mxGetM(B0_IN) != 1)
    mexErrMsgTxt("b0 size mismatch.");
  if(!mxIsSparse(FI_IN))
    mexErrMsgTxt("fi should be sparse.");
  if(mxGetM(FI_IN) != m+1 || mxGetN(FI_IN) != 1)
    mexErrMsgTxt("fi size mismatch.");
  fi.jc = mxGetJc(FI_IN);
  fi.ir = mxGetIr(FI_IN);
  fi.pr = mxGetPr(FI_IN);
/* ------------------------------------------------------------
   Create output matrix y(m+1,N)
   ------------------------------------------------------------ */
  Y_OUT = mxCreateDoubleMatrix(m+1, N, mxREAL);
  y = mxGetPr(Y_OUT);
/* ------------------------------------------------------------
   For each column j: initialize y(:,j) = [b(:,j); b0(j)], then
   solve L(p,lm) * yNEW = yOLD in place.
   Column offsets are computed (rather than advancing b) and the copy
   is skipped when m == 0: b may then be NULL, and memcpy or pointer
   arithmetic on a NULL pointer is undefined even for a zero length.
   ------------------------------------------------------------ */
  for(j = 0; j < N; j++){
    yj = y + (size_t) j * (size_t) (m+1);
    if(m > 0)
      memcpy(yj, b + (size_t) j * (size_t) m, (size_t) m * sizeof(double));
    yj[m] = b0[j];
    innerfw(yj,fi,p,lm,m);
  }
}

