/*
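   y = vectril(x,K)
   The LP (K.l) and Lorentz (K.q) parts of x are copied to y unchanged.
   For the PSD submatrices, we let Yk = diag(diag(Xk)) + tril(Xk+Xk',-1),
   i.e. the diagonal plus the strictly lower triangle of Xk+Xk'. For complex
   data, Xk' is the conjugate transpose and the imaginary part of the
   diagonal is discarded.
   Complex numbers are stored in Matlab's format (pr,pi).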

 %  
 %   This file is part of SeDuMi 1.01e   (16JUN1998)
 %   Copyright (C) 1998 Jos F. Sturm
 %   CRL, McMaster University, Canada.
 %   Supported by the Netherlands Organization for Scientific Research (NWO).
 % 
 %   This program is free software; you can redistribute it and/or modify
 %   it under the terms of the GNU General Public License as published by
 %   the Free Software Foundation; either version 2 of the License, or
 %   (at your option) any later version.
 % 
 %   This program is distributed in the hope that it will be useful,
 %   but WITHOUT ANY WARRANTY; without even the implied warranty of
 %   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 %   GNU General Public License for more details.
 % 
 %   You should have received a copy of the GNU General Public License
 %   along with this program; if not, write to the Free Software
 %   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 %

*/

#include <math.h>
#include <string.h>
#include "mex.h"
#include "blksdp.h"

#define Y_OUT plhs[0]
#define X_IN prhs[0]
#define K_IN prhs[1]

/* ------------------------------------------------------------
   Type definition:
   ------------------------------------------------------------ */
/* ------------------------------------------------------------
   jcirprpi - pointers into a Matlab sparse matrix (compressed column
   storage): jc = column start pointers, ir = row indices, pr = real
   parts, pi = imaginary parts (only valid for complex data).
   ------------------------------------------------------------ */
typedef struct{
 double *pr, *pi;
 int *jc, *ir;
    } jcirprpi;

/* ************************************************************
   PROCEDURE takexji - Looks up X(j,i), the transposed partner of
     position (i,j), as the next unconsumed entry of column i of X.
   INPUT
     x, xjc, n, startk - as in sptotril; xjc[k] is the start of column k
       of X inside x.{pr,pi,ir}.
     i, j - row and column of the Y-entry that is being built.
   UPDATED
     iwork - iwork[i] points to the first not yet consumed entry of
       column i; it is advanced past X(j,i) when that entry is found.
   RETURNS
     index of X(j,i) in x.{pr,pi,ir}, or -1 if X(j,i) is not stored.
   ************************************************************ */
static int takexji(const jcirprpi x, const int *xjc, int *iwork,
                   const int n, const int startk, const int i, const int j)
{
  const int knz = iwork[i];

  if(knz < xjc[i+1] && x.ir[knz] == startk + i*n + j){
    iwork[i]++;
    return knz;
  }
  return -1;
}

/* ************************************************************
   PROCEDURE sptotril - For sparse x=vec(X), lets
     y = vec(diag(diag(X)) + tril(X+X',-1)).
   INPUT
     x - sparse vector, x.jc[0] can be nonzero. Entries x.jc[0]..x.jc[1]-1
       are scanned; the ones with row index < startk + n*n form X.
       Row indices are assumed sorted ascending (Matlab sparse invariant)
       and >= startk -- the caller must guarantee the latter.
     n - order of matrix, if we see x as a vectorized n x n matrix.
     startk - would be index of X(1,1) in vectorized form (to allow
       x as a subvector)
     xcpx - Flag. If true, then x is complex, and x.pi is imaginary part.
   UPDATED
     y - On input, has enough space allocated (nnz(y)<=nnz(x)) and y.jc[0]
       points to the point in y.{pr,pi,ir} where we should start writing.
       On output, contains vec(diag(diag(X)) + tril(X+X',-1)), where for
       complex X the transpose is conjugated and the imaginary part of the
       diagonal is dropped. Entries that sum to exactly zero are not stored.
       When xcpx, y.pi is written for every stored entry.
       NB: we don't change in y.jc.
   OUTPUT
     ynnz - number of nonzeros written in y.
   WORKING ARRAY
     iwork - 2 * (n+1) integer array.
   RETURNS
     number of nonzeros from x that we processed (i.e. belonging to nxn mat).
   ************************************************************ */
int sptotril(jcirprpi y,int *ynnz,jcirprpi x,const int n,const int startk,
             int *iwork,bool xcpx)
{
  int i,j,inz,knz,jnz,jstart,inext;
  int *xjc;
  double yij,iyij;

  xjc = iwork + (n+1);
  /* ------------------------------------------------------------
     Let xjc[0:n] point to the columns of x, seen as an n x n matrix.
     xjc[n] is the first entry past the matrix.
     ------------------------------------------------------------ */
  j = 0;
  jstart = startk;
  inz = x.jc[0];
  xjc[0] = inz;
  for( ; (inz < x.jc[1]) && (j < n); inz++ ){
    i = x.ir[inz];
    for( ;(jstart <= i-n) && (j < n); jstart += n)
      xjc[++j] = inz;                          /* Move to column where i is */
  }
  while(j < n)
    xjc[++j] = inz;                            /* Close remaining columns */
  /* ------------------------------------------------------------
     Let iwork[0:n] = xjc[0:n]. While column j is processed, iwork[k]
     (k >= j) points to the first entry of column k with row >= j: the
     entries above the diagonal are consumed, in order, by takexji.
     ------------------------------------------------------------ */
  memcpy(iwork,xjc,(n+1) * sizeof(int));
  /* ------------------------------------------------------------
     Now, we let Y be a lower triangular matrix with tril(X+X',-1)
     as strict lower triangular, and diag(X) on its diagonal.
     ------------------------------------------------------------ */
  jnz = y.jc[0];
  for(j = 0, jstart = startk; j < n; j++, jstart += n){
    i = j;                                    /* Only i=j:n-1 */
    for(inz = iwork[j]; inz < xjc[j+1]; inz++){
      inext = x.ir[inz] - jstart;
      /* ------------------------------------------------------------
         We're on x(inext,j), but first let y(i:inext-1,j) = x(j,i:inext-1)'.
         ------------------------------------------------------------ */
      for(; i < inext; i++)
        if((knz = takexji(x,xjc,iwork,n,startk,i,j)) >= 0){
          y.pr[jnz] = x.pr[knz];
          if(xcpx)
            y.pi[jnz] = -x.pi[knz];           /* y(i,j) = conj(x(j,i)) */
          y.ir[jnz++] = i+jstart;
        }
      /* ------------------------------------------------------------
         Let y(i,j) = RE x(i,j)              if i == j
             y(i,j) = x(i,j) + conj(x(j,i))  if i > j
         iyij is reset per entry, so a diagonal entry never inherits the
         imaginary part of a previous off-diagonal entry.
         ------------------------------------------------------------ */
      yij = x.pr[inz];
      iyij = 0.0;
      if(i > j){
        if(xcpx)
          iyij = x.pi[inz];
        if((knz = takexji(x,xjc,iwork,n,startk,i,j)) >= 0){
          yij += x.pr[knz];
          if(xcpx)
            iyij -= x.pi[knz];                /* y += conj(x(j,i)) */
        }
      }
      if(yij != 0.0 || iyij != 0.0){
        y.pr[jnz] = yij;
        if(xcpx)
          y.pi[jnz] = iyij;                   /* 0.0 on the diagonal */
        y.ir[jnz++] = i+jstart;
      }
      i++;
    }
    /* ------------------------------------------------------------
       We've handled x(:,j); i is now one past the last row visited, so
       we still have to add x(j,i:n-1)'. (Must NOT pre-increment i here,
       or x(j,i) would be skipped and left unconsumed in iwork[i].)
       ------------------------------------------------------------ */
    for(; i < n; i++)
      if((knz = takexji(x,xjc,iwork,n,startk,i,j)) >= 0){
        y.pr[jnz] = x.pr[knz];
        if(xcpx)
          y.pi[jnz] = -x.pi[knz];             /* y(i,j) = conj(x(j,i)) */
        y.ir[jnz++] = i+jstart;
      }
  }
  /* ------------------------------------------------------------
     Let *ynnz be the number of entries we put in y, and return
     the number of entries that we handled from x.
     ------------------------------------------------------------ */
  *ynnz = jnz - y.jc[0];
  return (xjc[n] - x.jc[0]);
}

/* ============================================================
   MAIN: MEXFUNCTION
   ============================================================ */
/* ************************************************************
   PROCEDURE mexFunction - Entry for Matlab
     y = vectril(x,K)

   The LP and Lorentz nonzeros of x (rows below K.l + sum(K.q)) are copied
   to y unchanged. For each PSD block Xk of order K.s(k), y receives
   Yk = diag(diag(Xk)) + tril(Xk+Xk',-1) at the same vectorized positions.
   Complex numbers are stored in Matlab's format (pr,pi).
   NB: x and y are sparse column vectors; nnz(y) <= nnz(x).
   ************************************************************ */
void mexFunction(const int nlhs, mxArray *plhs[],
  const int nrhs, const mxArray *prhs[])
{
 const mxArray *fld;
 const double *lorNL, *sdpNL;
 int lpN, lorN, qDim, sdpN, sDim, lqDim, lenfull;
 int m, nk, nlq, nzx, yknnz, startk, sMaxn;
 int xkjc[2];
 jcirprpi x, y, xk;
 bool xcpx;
 int *iwork;

 /* ---- Argument count ---- */
 if(nrhs < 2)
   mexErrMsgTxt("vectril requires 2 input arguments.");
 if(nlhs > 1)
   mexErrMsgTxt("vectril generates 1 output argument.");
 /* ---- Cone structure K: LP, Lorentz and PSD dimensions ---- */
 if(!mxIsStruct(K_IN))
   mexErrMsgTxt("Parameter `K' should be a structure.");
 fld = mxGetField(K_IN,0,"l");                          /* K.l */
 lpN = (fld != NULL) ? mxGetScalar(fld) : 0;
 qDim = 0;                                              /* K.q */
 fld = mxGetField(K_IN,0,"q");
 if(fld != NULL){
   lorN = mxGetM(fld) * mxGetN(fld);
   lorNL = mxGetPr(fld);
   for(m = 0; m < lorN; m++)
     qDim += lorNL[m];
 }
 sdpN = 0;                                              /* K.s */
 sDim = 0;
 fld = mxGetField(K_IN,0,"s");
 if(fld != NULL){
   sdpN = mxGetM(fld) * mxGetN(fld);
   sdpNL = mxGetPr(fld);
   for(m = 0; m < sdpN; m++)
     sDim += SQR(sdpNL[m]);
 }
 lqDim = lpN + qDim;                     /* rows [0,lqDim) are LP/Lorentz */
 lenfull = lqDim + sDim;
 /* ---- Input vector x ---- */
 if(mxGetM(X_IN) != lenfull || mxGetN(X_IN) != 1)
   mexErrMsgTxt("Parameter `x' size mismatch.");
 if(!mxIsSparse(X_IN))
   mexErrMsgTxt("Parameter `x' should be sparse.");
 xcpx = mxIsComplex(X_IN);
 x.pr = mxGetPr(X_IN);
 x.pi = xcpx ? mxGetPi(X_IN) : NULL;
 x.jc = mxGetJc(X_IN);
 x.ir = mxGetIr(X_IN);
 /* ---- Output y = sparse([],[],[],length(x),1,nnz(x)) ---- */
 Y_OUT = mxCreateSparse(lenfull, 1, x.jc[1], xcpx ? mxCOMPLEX : mxREAL);
 y.pr = mxGetPr(Y_OUT);
 y.pi = xcpx ? mxGetPi(Y_OUT) : NULL;
 y.jc = mxGetJc(Y_OUT);
 y.ir = mxGetIr(Y_OUT);
 y.jc[0] = 0;
 nzx = x.jc[1];
 if(nzx == x.jc[0]){                     /* x == []: nothing to do */
   y.jc[1] = 0;
   return;
 }
 /* ---- Every nonzero is LP/Lorentz: y is a plain copy of x ---- */
 if(x.ir[nzx-1] < lqDim){
   memcpy(y.pr,x.pr,nzx * sizeof(double));
   memcpy(y.ir,x.ir,nzx * sizeof(int));
   if(xcpx)
     memcpy(y.pi,x.pi,nzx * sizeof(double));
   y.jc[1] = nzx;
   return;
 }
 /* ---- Copy the leading LP/Lorentz nonzeros (there is a PSD one) ---- */
 for(nlq = 0; x.ir[nlq] < lqDim; nlq++)
   ;
 if(nlq > 0){
   memcpy(y.pr,x.pr,nlq * sizeof(double));
   memcpy(y.ir,x.ir,nlq * sizeof(int));
   if(xcpx)
     memcpy(y.pi,x.pi,nlq * sizeof(double));
 }
 /* ---- Working array for sptotril: 2*(max K.s + 1) ints ---- */
 sMaxn = 0;
 for(m = 0; m < sdpN; m++)
   sMaxn = MAX(sMaxn, sdpNL[m]);
 iwork = (int *) mxCalloc(2 * (sMaxn + 1), sizeof(int));
 /* ---- PSD blocks: xk.jc[0] / y.jc[0] act as read / write cursors ---- */
 xk = x;
 xk.jc = xkjc;
 xk.jc[0] = nlq;
 xk.jc[1] = nzx;
 y.jc[0] = nlq;
 startk = lqDim;
 for(m = 0; m < sdpN; m++){
   nk = sdpNL[m];
   xk.jc[0] += sptotril(y,&yknnz,xk,nk,startk,iwork,xcpx);
   y.jc[0] += yknnz;
   startk += SQR(nk);
 }
 y.jc[1] = y.jc[0];                      /* total nonzeros written */
 y.jc[0] = 0;
 mxFree(iwork);
}
