/*
 bwdpr1 - backward solve with a product of identity-plus-rank-1 factors.
    y = bwdpr1(Lden, b)  solves
    "(PROD_k L(pk,betak))' * y(Lden.rowperm,:) = b", where
    L(p,beta) = eye(n) + tril(p*beta',-1).

    This file is part of SeDuMi 1.03BETA
    Copyright (C) 1999 Jos F. Sturm
    Dept. Quantitative Economics, Maastricht University, the Netherlands.
    Affiliations up to SeDuMi 1.02 (AUG1998):
      CRL, McMaster University, Canada.
      Supported by the Netherlands Organization for Scientific Research (NWO).
  
    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.
  
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
  
    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

*/

#include <math.h>
#include <string.h>
#include "mex.h"
#include "blksdp.h"

/* y = bwdpr1(Lden,b) */
/* Output argument: the solution y. */
#define Y_OUT plhs[0]
#define NPAROUT 1           /* maximum number of output arguments */

/* Input arguments: the dense-update structure Lden and the rhs b. */
#define LDEN_IN prhs[0]
#define B_IN prhs[1]
#define NPARIN 2            /* minimum number of input arguments */

/* ------------------------------------------------------------
   PROCEDURE bwipr1 - I(dentity) P(lus) R(ank)1 backward solve.
   Solves the transposed system
       L(p,beta)' * yNEW = yOLD,  L(p,beta)' = eye(m) + triu(beta*p',1),
   in place, by a single sweep from the last updated row to row 0.
   INPUT:
   p    - length m floats
   beta - length n floats
   m, n - order of p and beta, resp. (n <= m)
   UPDATED:
   y - Length m. On input, contains the rhs. On output, the solution.
       Only y(0:n-1) is modified; y(n:m-1) is read but left as is.
   ------------------------------------------------------------ */
void bwipr1(double *y, const double *p, const double *beta,
            const int m, const int n)
{
  int j;
  double acc;

  if(n <= 0)                /* L = I: y already is the solution */
    return;
  /* acc = p(n:m-1)' * y(n:m-1): the rows without a beta entry. */
  acc = realdot(p + n, y + n, m - n);
  /* Row j of L' reads yNEW(j) + beta(j) * acc = yOLD(j), where
     acc = p(j+1:m-1)' * yNEW(j+1:m-1). Sweep j = n-1 down to 0 and
     fold each freshly solved entry into acc. */
  for(j = n; j-- > 0; ){
    y[j] -= beta[j] * acc;
    acc += p[j] * y[j];
  }
}

/* ------------------------------------------------------------
   PROCEDURE bwipr1o - I(dentity) P(lus) R(ank)1 backward solve, O(rdered).
   Same solve as bwipr1, but the rows are visited through perm: logical
   row i lives at y[perm[i]] and p[perm[i]], while beta is indexed by the
   logical row i directly.
   INPUT:
   perm - length m permutation on p and y.
   p    - length m floats (accessed through perm)
   beta - length n floats (accessed by logical row)
   m, n - order of p and beta, resp. (n <= m)
   UPDATED:
   y - Length m. On input, contains the rhs. On output, the solution to
       L(p,beta)'*yNEW = yOLD in the permuted row order. Only the
       entries y[perm[0:n-1]] are changed.
   ------------------------------------------------------------ */
void bwipr1o(double *y, const int *perm, const double *p, const double *beta,
            const int m, const int n)
{
  int row = m - 1;
  double acc = 0.0;

  if(n <= 0)                /* L = I: y already is the solution */
    return;
  /* Tail: accumulate p'*y over logical rows m-1 down to n. */
  while(row >= n){
    const int at = perm[row];
    acc += p[at] * y[at];
    --row;
  }
  /* Head: row now equals min(m,n)-1. Solve each remaining row as
     yNEW = yOLD - beta(row)*acc and fold the new value into acc. */
  while(row >= 0){
    const int at = perm[row];
    y[at] -= beta[row] * acc;
    acc += p[at] * y[at];
    --row;
  }
}

/* ************************************************************
   PROCEDURE bwprodform - Solves (PROD_j L(pj,betaj))' * yNEW = yOLD.
   The transposed product is inverted factor by factor, last factor
   first: for k = n-1 down to 0, solve L(pk,betak)' * yNEXT = yPREV.
   INPUT
     xsuper  - xsuper[j+1] is the number of entries of column j of P
               (the length of pj).
     perm    - concatenated row orders; every column j with
               ordered[j] != 0 owns xsuper[j+1] consecutive entries.
     p       - concatenated columns pj of P, column j has xsuper[j+1]
               entries.
     beta    - concatenated vectors betaj; betaj starts at beta+betajc[j]
               and has betajc[j+1]-betajc[j] entries.
     betajc  - length n+1 offsets into beta.
     ordered - ordered[j] != 0 iff p(:,j) and betaj are stored in the
               row order given by perm.
     n       - number of columns in p, beta.
     pnnz    - total number of entries in p (must be consumed exactly).
     permnnz - total number of entries in perm (the final assertion
               tolerates one left-over entry).
   UPDATED
     y - On input, the rhs. On output the solution to
         (PROD_j L(pj,betaj))' * yNEW = yOLD.
   ************************************************************ */
void bwprodform(double *y, const int *xsuper, const int *perm,
                const double *p, const double *beta, const int *betajc,
                const char *ordered, const int n, int pnnz,
                int permnnz)
{
  int k;

  for(k = n; k-- > 0; ){
    const int len = xsuper[k+1];                  /* length of column k */
    const int nbeta = betajc[k+1] - betajc[k];    /* length of betak */
    const double *betak = beta + betajc[k];
    const double *pk;

    pnnz -= len;            /* columns are stored in order; walk backwards */
    pk = p + pnnz;
    if(!ordered[k]){
      bwipr1(y, pk, betak, len, nbeta);
    }
    else{
      permnnz -= len;
      bwipr1o(y, perm + permnnz, pk, betak, len, nbeta);
    }
  }
  mxAssert(pnnz == 0,"");
  mxAssert(permnnz == 0 || permnnz == 1,"");
}

/* ============================================================
   MAIN: MEXFUNCTION
   ============================================================ */
/* ************************************************************
   PROCEDURE mexFunction - Entry for Matlab
   y = bwdpr1(Lden, b).
   INPUT
     Lden - structure with fields betajc, p, pivperm, xsuper, beta,
            rowperm and dopiv. Only betajc is read when it has at most
            one entry (no dense columns).
     b    - full m x n matrix.
   OUTPUT
     y    - m x n matrix. For each column j, the solution x of
            (PROD_k L(pk,betak))' * x = b(:,j) is stored as
            y(rowperm,j) = x. Without dense columns, y = b.
   ************************************************************ */
void mexFunction(const int nlhs, mxArray *plhs[],
  const int nrhs, const mxArray *prhs[])
{
  const mxArray *L_FIELD;
  char *ordered;
  int m,n,nden, i,j, permnnz, nadd, pnnz;
  const double *beta, *betajcPr, *permPr, *b, *orderedPr, *xsuperPr,
    *pivpermPr, *p;
  int *betajc, *perm, *pivperm, *xsuper;
  double *y, *fwork;
/* ------------------------------------------------------------
   Check for proper number of arguments
   ------------------------------------------------------------ */
  if(nrhs < NPARIN)
    mexErrMsgTxt("bwdpr1 requires more input arguments.");
  if(nlhs > NPAROUT)
    mexErrMsgTxt("bwdpr1 generates less output arguments.");
/* ------------------------------------------------------------
   Get input b
   ------------------------------------------------------------ */
  if(mxIsSparse(B_IN))
    mexErrMsgTxt("b should be full");
  m = mxGetM(B_IN);
  n = mxGetN(B_IN);
  b = mxGetPr(B_IN);
/* ------------------------------------------------------------
   Disassemble dense-update structure Lden (p,xsuper,beta,betajc,rowperm).
   Element counts are converted to int BEFORE subtracting 1, so that an
   empty field yields -1 instead of relying on unsigned wrap-around.
   ------------------------------------------------------------ */
  if(!mxIsStruct(LDEN_IN))
    mexErrMsgTxt("Parameter `Lden' should be a structure.");
  if( (L_FIELD = mxGetField(LDEN_IN,0,"betajc")) == NULL)      /* betajc */
    mexErrMsgTxt("Missing field Lden.betajc.");
  nden = (int) (mxGetM(L_FIELD) * mxGetN(L_FIELD)) - 1;
  if(nden > 0){
    betajcPr = mxGetPr(L_FIELD);
    if( (L_FIELD = mxGetField(LDEN_IN,0,"p")) == NULL)            /* p */
      mexErrMsgTxt("Missing field Lden.p.");
    p = mxGetPr(L_FIELD);
    pnnz = (int) (mxGetM(L_FIELD) * mxGetN(L_FIELD));
    if( (L_FIELD = mxGetField(LDEN_IN,0,"pivperm")) == NULL)      /* pivperm */
      mexErrMsgTxt("Missing field Lden.pivperm.");
    pivpermPr = mxGetPr(L_FIELD);
    permnnz = (int) (mxGetM(L_FIELD) * mxGetN(L_FIELD));
    if( (L_FIELD = mxGetField(LDEN_IN,0,"xsuper")) == NULL)       /* xsuper */
      mexErrMsgTxt("Missing field Lden.xsuper.");
    nadd = (int) (mxGetM(L_FIELD) * mxGetN(L_FIELD)) - 1;
    /* An empty xsuper (nadd < 0) would leave the fill value j used
       below uninitialized, so reject it together with the too-long case. */
    if(nadd < 0 || nadd > nden)
      mexErrMsgTxt("Size mismatch xsuper.");
    xsuperPr = mxGetPr(L_FIELD);
    if( (L_FIELD = mxGetField(LDEN_IN,0,"beta")) == NULL)          /* beta */
      mexErrMsgTxt("Missing field Lden.beta.");
    beta = mxGetPr(L_FIELD);
    if( (L_FIELD = mxGetField(LDEN_IN,0,"rowperm")) == NULL)      /* rowperm */
      mexErrMsgTxt("Missing field Lden.rowperm.");
    if(m != mxGetM(L_FIELD) * mxGetN(L_FIELD))
      mexErrMsgTxt("Size mismatch Lden.rowperm.");
    permPr = mxGetPr(L_FIELD);
    if( (L_FIELD = mxGetField(LDEN_IN,0,"dopiv")) == NULL)         /* dopiv */
      mexErrMsgTxt("Missing field Lden.dopiv.");
    if(mxGetM(L_FIELD) * mxGetN(L_FIELD) != nden)
      mexErrMsgTxt("Size mismatch Lden.dopiv.");
    orderedPr = mxGetPr(L_FIELD);
  }
/* ------------------------------------------------------------
   Create output y. Without dense columns (nden <= 0, which includes an
   empty Lden.betajc) the product is the identity, so y = b.
   ------------------------------------------------------------ */
  Y_OUT = mxCreateDoubleMatrix(m, n, mxREAL);
  y = mxGetPr(Y_OUT);
  if(nden <= 0){
    memcpy(y, b, (size_t) m * n * sizeof(double));
    return;
  }
/* ------------------------------------------------------------
   Allocate working arrays betajc(nden+1), perm(m), pivperm(permnnz),
   ordered(nden), xsuper(nden+1), and float working array fwork(m).
   (mxCalloc memory is released by MATLAB if mexErrMsgTxt aborts below.)
   ------------------------------------------------------------ */
  betajc = (int *) mxCalloc(nden + 1,sizeof(int));
  ordered = (char *) mxCalloc(nden,sizeof(char));
  perm = (int *) mxCalloc(MAX(m,1),sizeof(int));
  pivperm = (int *) mxCalloc(MAX(permnnz,1),sizeof(int));
  xsuper = (int *) mxCalloc(nden+1,sizeof(int));
  fwork = (double *) mxCalloc(MAX(m,1), sizeof(double));
/* ------------------------------------------------------------
   Convert xsuper to integer (subtracting 1). Entries nadd+1..nden are
   filled with xsuper[nadd]: the phase-2 columns all have that length.
   ------------------------------------------------------------ */
  for(i = 0; i <= nadd; i++){
    j = (int) xsuperPr[i];
    xsuper[i] = --j;
  }
  while(i <= nden)
    xsuper[i++] = j;
/* ------------------------------------------------------------
   Convert betajcPr, ordered, pivperm, permPr to C-style
   ------------------------------------------------------------ */
  for(i = 0; i <= nden; i++){
    j = (int) betajcPr[i];
    betajc[i] = --j;
  }
  for(i = 0; i < nden; i++)
    ordered[i] = (char) orderedPr[i];
  /* NOTE(review): unlike betajc, xsuper and rowperm, pivperm is copied
     without subtracting 1, so Lden.pivperm presumably already holds
     0-based indices -- confirm against the producer of Lden. */
  for(i = 0; i < permnnz; i++)
    pivperm[i] = (int) pivpermPr[i];
  L_FIELD = mxGetField(LDEN_IN,0,"beta");
  if((int) (mxGetM(L_FIELD) * mxGetN(L_FIELD)) != betajc[nden])
    mexErrMsgTxt("Size mismatch Lden.beta.");
  for(i = 0; i < m; i++){
    /* perm is used as a write index y[perm[i]] below, so reject
       out-of-range entries (tested on the double, before the cast). */
    if(!(permPr[i] >= 1.0 && permPr[i] <= m))
      mexErrMsgTxt("Lden.rowperm out of range.");
    j = (int) permPr[i];
    perm[i] = --j;
  }
/* ------------------------------------------------------------
   The actual job is done here: y(perm) = PROD_L'\b, column by column.
   ------------------------------------------------------------ */
  for(j = 0; j < n; j++){
    memcpy(fwork, b, (size_t) m * sizeof(double));
    bwprodform(fwork, xsuper, pivperm, p, beta, betajc, ordered, nden,
               pnnz, permnnz);
    for(i = 0; i < m; i++)            /* y(perm) = fwork */
      y[perm[i]] = fwork[i];
    y += m; b += m;
  }
/* ------------------------------------------------------------
   Release working arrays
   ------------------------------------------------------------ */
  mxFree(fwork);
  mxFree(xsuper);
  mxFree(pivperm);
  mxFree(perm);
  mxFree(ordered);
  mxFree(betajc);
}
