/*
  [beta,ux,d,perm] = qrpfac(x,K);

 x should consist only of the SDP part (length sum(K.s.^2)).

 qrpfac: Q*D*R factorization with column pivoting, applied to each
 real symmetric SDP block of x.
*/

#include <string.h>
#include <math.h>
#include "mex.h"
#include "blksdp.h"

/* Output arguments (plhs): [beta, ux, d, perm] */
#define BETA_OUT plhs[0]
#define UX_OUT plhs[1]
#define D_OUT plhs[2]
#define PERM_OUT plhs[3]

#define NPAROUT 4      /* maximum number of output arguments accepted */

/* Input arguments (prhs): (x, K) */
#define X_IN prhs[0]
#define K_IN prhs[1]
#define NPARIN 2       /* minimum number of input arguments required */


/* ************************************************************
   PROCEDURE bwrealdot - "backward" real inner product.
   Returns sum_{i=0}^{n-1} x[i]*y[i], accumulated from the last
   entry (i = n-1) down to the first (i = 0). Yields 0.0 if n <= 0.
   ************************************************************ */
double bwrealdot(const double *x, const double *y, const int n)
{
  double sum = 0.0;
  int idx = n;

  while(idx > 0){
    --idx;                       /* visit n-1, n-2, ..., 0 */
    sum += x[idx] * y[idx];
  }
  return sum;
}

/* ************************************************************
   PROCEDURE qrpivot - Q*D*R factorization for nxn matrix,
     with column pivotting. Yields R with 1 >= \|R(i,i+1:n)\| forall i.
   INPUT
     n - order of matrix to be factored
   UPDATED
     u - Full nxn. On input, u is matrix to be factored. On output,
       triu(u(:,perm),1) = uppertriangular factor, implicit unit-diagonal;
       tril(u(:,perm),0) = Householder reflections.
   OUTPUT
     beta - length n vector. kth Householder reflection is
        Qk = I-vk*vk' / beta[k],   where vk = u(k:n-1,perm[k]).
     d    - length n vector; X = Q * diag(d) * U, the positive diagonal,
        in decreasing order (for stability).
     perm - Column permutation: U=eye(n)+triu(u(:,perm)) is the U-factor.
   ************************************************************ */
/* ************************************************************
   PROCEDURE qrpivot - Q*D*R factorization for nxn matrix,
     with column pivoting. Pivoting on the largest remaining column
     norm yields |U(i,j)| <= 1 for all j > i.
   INPUT
     n - order of matrix to be factored
   UPDATED
     u - Full nxn, column-major (column j starts at u + j*n).
       On input, u is the matrix to be factored. On output,
       triu(u(:,perm),1) = uppertriangular factor, implicit unit-diagonal;
       tril(u(:,perm),0) = Householder reflections.
   OUTPUT
     beta - length n vector. kth Householder reflection is
        Qk = I-vk*vk' / beta[k],   where vk = u(k:n-1,perm[k]).
     d    - length n vector; X = Q * diag(d) * U, the positive diagonal,
        in decreasing order (for stability).
     perm - Column permutation: U=eye(n)+triu(u(:,perm),1) is the U-factor.
   NOTE(review): if the pivot column is numerically zero (dk == 0),
     betak becomes 0 and ukui/betak below yields NaN; rank-deficient
     input is not specially handled.
   ************************************************************ */
void qrpivot(double *beta, double *u,double *d,int *perm, const int n)
{
  int i,j,k,imax, nmink, icol;
  double *uk, *rowuk;
  double dk, betak, ukui, v1;
/* ------------------------------------------------------------
   Initialize: d(j) = ssqr(u(:,j)) for j=0:n-1, perm = 0:n-1.
   Column j of the column-major matrix starts at u + j*n.
   ------------------------------------------------------------ */
  for(j = 0; j < n; j++)
    d[j] = realssqr(u + j * n, n);
  for(j = 0; j < n; j++)
    perm[j] = j;
/* ------------------------------------------------------------
   Pivot in step k=0:n-1 on imax.
   From here on d[] and perm[] are indexed by POSITION: d[i] is the
   remaining squared norm (rows k:n-1) of original column perm[i].
   ------------------------------------------------------------ */
  for(k = 0; k < n; k++){
/* ------------------------------------------------------------
   Let [imax,dk] = max(d(k:n-1))
   ------------------------------------------------------------ */
    dk = d[k]; imax = k;
    for(i = k + 1; i < n; i++)
      if(d[i] > dk){
        imax = i;
        dk = d[i];
      }
/* ------------------------------------------------------------
   k-th pivot is j=perm[imax]. Swap positions k and imax;
   d[k] receives its final value (the pivot norm) below.
   ------------------------------------------------------------ */
    d[imax] = d[k];
    j = perm[imax];                     /* original node number */
    uk = u + j * n;
    rowuk = u + k;
    perm[imax] = perm[k];
    perm[k] = j;
/* ------------------------------------------------------------
    Let u(0:k-1,j) ./= d(0:k-1), turning R's row entries into
    unit-diagonal U entries (d(0:k-1) are final diagonal values).
    ------------------------------------------------------------ */
    realHadadiv(uk,uk, d, k);        /* uk(0:k-1) ./= d(0:k-1) */
/* ------------------------------------------------------------
   Store kth Householder reflection in v = x(k:n-1);
   dk = sqrt(dk), s = sign(xkk) * dk, v1 = xkk+s, betak = s*v1.
   d[k] = dk is the positive diagonal entry for this pivot.
   ------------------------------------------------------------ */
    dk = sqrt(dk);
    d[k] = dk;
    uk += k;
    v1 = uk[0];
    betak = SIGN(v1) * dk;             /* use betak to store s */
    v1 += betak;
    betak *= v1;                       /* final betak */
    uk[0] = v1;
    beta[k] = betak;
/* ------------------------------------------------------------
   Reflect columns k+1:n-1, i.e.
   ui -= (uk'*ui / betak) * uk, where ui = u(k:n-1, perm[i]).
   The reflection preserves ||ui||, so the squared norm of rows
   k+1:n-1 is the old d[i] minus the new row-k entry squared.
   ------------------------------------------------------------ */
    nmink = n-k;
    betak = -betak;
    for(i = k + 1; i < n; i++){
      icol = perm[i] * n;
      ukui = bwrealdot(uk, rowuk+icol, nmink);
      addscalarmul(rowuk+icol, ukui/betak, uk, nmink);
      d[i] -= SQR(rowuk[icol]);        /* d is indexed by position i */
    }
  }
}

/* ============================================================
   MAIN: MEXFUNCTION
   ============================================================ */
/* ************************************************************
   PROCEDURE mexFunction - Entry for Matlab
   [beta,U,d,perm] = qrpfacK(x,K)
   ************************************************************ */
/* ************************************************************
   PROCEDURE mexFunction - Entry for Matlab
   [beta,U,d,perm] = qrpfac(x,K)
   INPUT
     x - SDP part of a vector, length sum(K.s.^2); must be double.
     K - cone structure (parsed by conepars).
   OUTPUT
     beta, d, perm - length rLen+hLen (see qrpivot), blocks stacked.
     U             - copy of x overwritten block-wise by qrpivot.
     perm is returned 1-based (Fortran/Matlab indexing).
   Only real symmetric SDP blocks are supported.
   NOTE(review): a sparse x is not expanded; mxGetPr then only
     exposes the stored nonzeros. Confirm callers always pass full x.
   ************************************************************ */
void mexFunction(const int nlhs, mxArray *plhs[],
  const int nrhs, const mxArray *prhs[])
{
  coneK cK;
  int i,k,nk,nksqr, sdplen,sdpdim;
  double *ux, *beta, *d, *permPr;
  int *iwork, *perm;
/* ------------------------------------------------------------
   Check for proper number of arguments
   ------------------------------------------------------------ */
  if(nrhs < NPARIN)
    mexErrMsgTxt("qrpfac requires more input arguments");
  if(nlhs > NPAROUT)
    mexErrMsgTxt("qrpfac produces less output arguments");
/* ------------------------------------------------------------
   Disassemble cone K structure. Reject complex blocks before
   doing any work: only the first rsdpN (real) blocks are handled.
   ------------------------------------------------------------ */
  conepars(K_IN, &cK);
  if(cK.rsdpN < cK.sdpN)
    mexErrMsgTxt("complex x not yet supported");
/* ------------------------------------------------------------
   Compute statistics: sdpdim = rdim+hdim, sdplen = rlen + hlen.
   ------------------------------------------------------------ */
  sdpdim = cK.rDim + cK.hDim;
  sdplen = cK.rLen + cK.hLen;
/* ------------------------------------------------------------
   Check input vector x. It must hold doubles: mxGetPr on any other
   class would be reinterpreted and memcpy'd below as doubles.
   ------------------------------------------------------------ */
  if(!mxIsDouble(X_IN))
    mexErrMsgTxt("qrpfac: x must be a double array");
  if(mxGetM(X_IN) * mxGetN(X_IN) != sdpdim)
    mexErrMsgTxt("size mismatch x");
/* ------------------------------------------------------------
   Allocate output UX(sdpdim), beta(sdplen), d(sdplen), perm(sdplen),
   and let ux = x.
   ------------------------------------------------------------ */
  UX_OUT = mxCreateDoubleMatrix(sdpdim, 1, mxREAL);
  ux = mxGetPr(UX_OUT);
  memcpy(ux, mxGetPr(X_IN), sdpdim * sizeof(double));
  BETA_OUT =  mxCreateDoubleMatrix(sdplen, 1, mxREAL);
  beta = mxGetPr(BETA_OUT);
  D_OUT =  mxCreateDoubleMatrix(sdplen, 1, mxREAL);
  d = mxGetPr(D_OUT);
  PERM_OUT =  mxCreateDoubleMatrix(sdplen, 1, mxREAL);
  permPr = mxGetPr(PERM_OUT);
/* ------------------------------------------------------------
   Allocate working array iwork(sdplen) for the 0-based permutation.
   ------------------------------------------------------------ */
  iwork = (int *) mxCalloc(sdplen, sizeof(int));
  perm = iwork;
/* ------------------------------------------------------------
   The actual job is done here: factor each real symmetric block,
   advancing ux by nk^2 entries and beta/perm/d by nk entries.
   ------------------------------------------------------------ */
  for(k = 0; k < cK.rsdpN; k++){                /* real symmetric */
    nk = cK.sdpNL[k];
    qrpivot(beta,ux,d,perm,nk);
    nksqr = SQR(nk);
    ux += nksqr;
    beta += nk; perm += nk; d += nk;
  }
/* ------------------------------------------------------------
   Convert "perm" to Fortran-index in doubles.
   ------------------------------------------------------------ */
  for(i = 0; i < sdplen; i++)
    permPr[i] = 1.0 + iwork[i];
/* ------------------------------------------------------------
   Release working array:
   ------------------------------------------------------------ */
  mxFree(iwork);
}
