/*
 mex getada.c spscale.c blkaux.c
*/
/* ************************************************************
   function [ADA,ddota] = getada(A,d,dsqr,detd,udsqr,K)
   Compute ADA(i,j) = (D(d^2)*A.t(:,i))' *A.t(:,j),
   and exploit sparsity as much as possible.

    This file is part of SeDuMi 1.03BETA
    Copyright (C) 1999 Jos F. Sturm
    Dept. Quantitative Economics, Maastricht University, the Netherlands.
    Affiliations up to SeDuMi 1.02 (AUG1998):
      CRL, McMaster University, Canada.
      Supported by the Netherlands Organization for Scientific Research (NWO).
  
    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.
  
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
  
    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

   ************************************************************ */

#include <string.h>
#include "mex.h"
#include "blksdp.h"

#define ADA_OUT myplhs[0]
#define DDOTA_OUT myplhs[1]
#define NPAROUT 2

#define A_IN prhs[0]
#define D_IN  prhs[1]
#define DSQR_IN  prhs[2]
#define DETD_IN  prhs[3]
#define UDSQR_IN prhs[4]
#define K_IN  prhs[5]
#define NPARIN 6

/* ========================= G E N E R A L ========================= */

/* ************************************************************
   PROCEDURE exmerge - merging 2 exclusive, increasing integer arrays.
   INPUT
     x - length nx array, increasing entries.
     y - length ny array, its entries are increasing, and do not occur in x.
     nx,ny - lengths of x and y.
   OUTPUT
     z - length nx+ny vector, the increasing merge of x and y.
   ************************************************************ */
/* ------------------------------------------------------------
   exmerge2 - merge two exclusive, increasing integer arrays where
     the array passed as x is the one that ends first.
   Preconditions (not checked here): nx >= 1, ny >= 1 (x[0] and y[0]
     are read unconditionally), and lastx < y[ny-1].  The caller in
     this file passes lastx = x[nx-1].
   INPUT
     x, nx  - increasing array and its length.
     lastx  - merge bound; interleaving stops once the current y
              entry is no longer below lastx.
     y, ny  - increasing array whose entries do not occur in x.
   OUTPUT
     z - receives nx+ny entries: the interleaved prefix, then the
         rest of x, then the rest of y.
   ------------------------------------------------------------ */
void exmerge2(int *z, const int *x, const int nx, const int lastx,
              const int *y, const int ny)
{
  int ix = 0, iy = 0, iz = 0;
  int xcur = x[0], ycur = y[0];

/* ------------------------------------------------------------
   Interleave as long as the current y entry lies below lastx.
   Inside this loop xcur < ycur < lastx implies ix < nx-1, and
   ycur < xcur <= lastx < y[ny-1] implies iy < ny-1, so the
   pre-increment reads stay inside x and y.
   ------------------------------------------------------------ */
  while(ycur < lastx){
    while(xcur < ycur){                    /* run of x entries */
      z[iz++] = xcur;
      xcur = x[++ix];
    }
    mxAssert(ycur < xcur,"");       /* exclusive: never equal */
    while(ycur < xcur){                    /* run of y entries */
      z[iz++] = ycur;
      ycur = y[++iy];
    }
    mxAssert(xcur < ycur,"");       /* exclusive: never equal */
  }
/* ------------------------------------------------------------
   Here xcur <= lastx < ycur: the tail of x goes first, then
   everything that is left of y.
   ------------------------------------------------------------ */
  memcpy(z + iz, x + ix, (nx - ix) * sizeof(int));
  memcpy(z + iz + (nx - ix), y + iy, (ny - iy) * sizeof(int));
}

/* ------------------------------------------------------------
   exmerge - write to z the merge of the increasing arrays x (nx
     entries) and y (ny entries), whose entries are mutually exclusive.
   An empty operand turns the merge into a plain copy.  Otherwise the
   array with the smaller last entry is handed to exmerge2 as its
   first operand, together with that last entry.
   ------------------------------------------------------------ */
void exmerge(int *z, const int *x, const int nx, const int *y, const int ny)
{
  if(nx == 0){                        /* nothing in x: z = y */
    memcpy(z, y, ny*sizeof(int));
    return;
  }
  if(ny == 0){                        /* nothing in y: z = x */
    memcpy(z, x, nx*sizeof(int));
    return;
  }
  {
    const int xmax = x[nx - 1];
    const int ymax = y[ny - 1];
    if(xmax < ymax)
      exmerge2(z, x,nx,xmax, y,ny);   /* x ends first */
    else
      exmerge2(z, y,ny,ymax, x,nx);   /* y ends first */
  }
}

/* ************************************************************
   PROCEDURE spzeros -- Set z = 0, where z is known to have nonzeros
      only in blocks yblk(:).  The structure of the blocks is in y.
   INPUT
     yjc      - block partition of yir; only yjc[0] and yjc[firstPSD]
                are read: they delimit the LP/Lorentz subscripts.
     yir      - subscripts into z.
     yblkir   - block numbers; entries firstPSD..yblknnz-1 are the
                PSD blocks to clear.
     yblknnz  - number of entries in yblkir.
     firstPSD - position in yblkir of the first PSD block
                (== yblknnz if there is none).
     dzjc     - yblknnz-firstPSD+1 entries; dzjc[p] is the position in
                dzir where the subscripts for PSD block yblkir[firstPSD+p]
                start, dzjc[p+1] bounds them from above.
     dzir     - subscripts into z.  Assumed increasing: the clearing of
                a block stops at the first subscript outside it.
     blkstart - blkstart[k+1] is the first subscript outside block k.
   UPDATED
     z - the entries described by y,yblk,dz will be made 0.0.
   ************************************************************ */
void spzeros(double *z, const int *yjc,const int *yir,
             const int *yblkir,const int yblknnz,
             const int firstPSD, const int *dzjc,
             const int *dzir, const int *blkstart)
{
  int inz, jnz, jend, blkend;
/* ------------------------------------------------------------
   LP and LORENTZ: re-initialize z[i]=0 for i listed in yir.
   ------------------------------------------------------------ */
  for(inz = yjc[0]; inz < yjc[firstPSD]; inz++)
    z[yir[inz]] = 0.0;
/* ------------------------------------------------------------
   PSD: only subscripts listed in dzjc, dzir, corresponding to
   blocks listed in yblk.  dzjc is indexed with the offset
   (inz - firstPSD) instead of shifting the pointer, so that no
   pointer in front of the array is ever formed.
   The range dzjc[p]..dzjc[p+1]-1 may run past the end of block k
   (into blocks that are not selected), hence the subscript test.
   Testing jnz < jend first also makes an empty range safe: no
   dzir[-1] is read when dzjc[p+1] == 0.
   ------------------------------------------------------------ */
  for(inz = firstPSD; inz < yblknnz; inz++){
    blkend = blkstart[yblkir[inz] + 1];   /* first subscript outside block */
    jend = dzjc[inz - firstPSD + 1];
    for(jnz = dzjc[inz - firstPSD]; jnz < jend && dzir[jnz] < blkend; jnz++)
      z[dzir[jnz]] = 0.0;
  }
}

/* ************************************************************
   PROCEDURE sptriu2sym -- Mirror the upper triangular part of a
      real sparse square matrix into its lower triangular part.
   INPUT
     m - number of columns in x
   UPDATED
     x - on input, x.jc/x.ir hold the sparsity structure and x.pr the
         values of the upper triangular. On return, x = x + triu(x,1)'.
         The structure is presumably symmetric; this is not checked.
   WORK
     iwork - length m integer array (initial contents irrelevant).
       iwork[i] is a cursor into column i: the next entry of that
       column that has not yet been used as a source.
   ************************************************************ */
void sptriu2sym(jcir x,const int m,int *iwork)
{
  int col, pos, colend, row;

  /* ------------------------------------------------------------
     Every cursor starts at the top of its column; x.jc[m] is not
     needed as a cursor and is not copied.
     ------------------------------------------------------------ */
  memcpy(iwork, x.jc, m * sizeof(int));
  /* ------------------------------------------------------------
     Column by column: every entry from the column's own cursor up
     to the end of the column takes its value from the entry under
     the cursor of column `row`, and that cursor moves one down.
     Guard: x.ir[iwork(i)] >= col for all i >= col.
     ------------------------------------------------------------ */
  for(col = 0; col < m; col++){
    colend = x.jc[col+1];
    pos = iwork[col];
    while(pos < colend){
      row = x.ir[pos];
      x.pr[pos] = x.pr[iwork[row]];   /* x(row,col) = x(col,row) */
      iwork[row]++;                   /* that source entry is consumed */
      pos++;
    }
  }
}

/* ************************************************************
   PROCEDURE: getada
     Fills the values of ADA column by column (upper triangle first,
     then mirrored to the lower triangle), and builds ddota on the way.
   INPUT
     ada.{jc,ir} - sparsity structure of ada, after perm.  Every
       nonempty column j is expected to contain the diagonal entry
       (j,j); this is only checked by an mxAssert.
     At - sparse matrix; column perm[j] is the j-th constraint.
     d, dsqr, udsqr, detd - scaling data, passed on to spscaleK.
     Ablkjc, Ablkir - per column of At, the list of nonzero blocks
       (indexed by perm[j]).
     dzstructjc, dzstructir - per column j, the PSD subscripts that are
       added to the accumulated structure dzir (indexed by j itself).
     blkstart - start subscripts of the blocks.
     blkNL - block sizes, passed on to spscaleK.
     denlor, ndenlor - the ndenlor entries of ddotaj listed in denlor
       are wiped before the inner products are formed.
     perm - length m, column ordering of At.
     m - number of columns of ada.
     nblk - number of blocks (size of the aj.jc partition is nblk+1).
     lenfull - length of the full vector daj.
     cK - cone structure; cK.sdpN and cK.lorN are used here.
   OUTPUT
     ada.pr - ada(i,j) = ai'*D(d^2)*aj, where aj = At(:,perm[j]).
     ddota  - sparse lorN * m matrix, ddota(k,j) = d_k'*aj_k, where
              "_k" denotes the k-th Lorentz block.
   WORKING ARRAYS
     fwork - doubles: lenfull (daj) + cK.lorN (ddotaj) + scratch that
       is handed to spscaleK.
       NOTE(review): daj is not cleared on entry, only re-zeroed after
       each column; the caller in this file allocates fwork with
       mxCalloc.
     iwork - ints: cK.sdpN+1 (dzjc) + dzstructjc[m] (dzir) + nblk+1
       (aj.jc) + scratch of max(m, max(nk(PSD)), dzstructjc[m]).
   ************************************************************ */
void getada(jcir ada, jcir ddota, jcir At,
            const double *d, const double *dsqr, const double *udsqr,
            const double *detd, const int *Ablkjc, const int *Ablkir,
            const int *dzstructjc, const int *dzstructir, const int *blkstart,
            const int *blkNL, const int *denlor, const int *perm,
            const int ndenlor,const int m, const int nblk, const int lenfull,
            const coneK cK,
            double *fwork, int *iwork)
{
  int i,j,knz,inz, ddotaNz, blknnzj, dznnz, addnnz,
    permi, permj, firstPSD;
  double *daj, *ddotaj;
  double adaij;
  int *dzjc, *dzir;
  jcir aj;

/* ------------------------------------------------------------
   Partition working arrays
   int: dzjc(cK.sdpN+1), dzir(dzstructjc[m]), aj.jc(nblk+1), iwork[iwsiz],
     with iwsiz = max(m,max(nk(PSD)),dzstructjc[m]).
   double:    daj(lenfull), ddotaj(lorN), fwork[fwsiz],
     with fwsiz = max(nk(REAL-PSD)^2, 2*nk(HERM)^2).
   After this, the local pointers iwork and fwork refer to the
   remaining scratch space only.
   ------------------------------------------------------------ */
  dzjc = iwork;                             /* cK.sdpN+1 */
  dzir = (iwork += cK.sdpN+1);              /* dzstructjc[m] */
  aj.jc = (iwork += dzstructjc[m]);         /* nblk+1 */
  iwork += nblk + 1;
  daj    = fwork;                           /* lenfull */
  ddotaj = (fwork += lenfull);              /* lorN */
  fwork += cK.lorN;
/* ------------------------------------------------------------
   Set ddotaj = zeros(cK.lorN,1)
   ------------------------------------------------------------ */
  fzeros(ddotaj, cK.lorN);
/* ------------------------------------------------------------
   Initialize dznnz = 0, meaning dz=[]. Later we will merge
   columns from dzstruct, with dz, and partition into selected blocks.
   ddotaNz counts the entries written to ddota so far.
   aj shares its ir/pr arrays with At; only aj.jc is private.
   ------------------------------------------------------------ */
  dznnz = 0;
  ddotaNz = 0;
  aj.pr = At.pr;                  /* init to At(:,0) */
  aj.ir = At.ir;
  for(j = 0; j < m; j++){
    ddota.jc[j] = ddotaNz;
    inz = ada.jc[j];
    if(inz < ada.jc[j+1]){           /* if aj is not all zero */
      permj = perm[j];
      mxAssert(At.jc[permj] < At.jc[permj+1],"");
/* ------------------------------------------------------------
   Partition  aj into (l,q,s) blocks, i.e.:
   1) Let blknnzj = # nonzero blocks in aj
   2) ajc[k] = position of these blocks in aj.ir and aj.pr (== At.pr/ir).
   (vec2selblks is defined elsewhere; presumably it fills aj.jc[k]
   with the position of the first subscript of the k-th listed block.)
   ------------------------------------------------------------ */
      aj.jc[0] = At.jc[permj];
      blknnzj = Ablkjc[permj+1] - Ablkjc[permj];
      vec2selblks(aj.jc, blkstart, aj.ir, At.jc[permj], At.jc[permj+1],
                  Ablkir + Ablkjc[permj], blknnzj);
      aj.jc[blknnzj] = At.jc[permj+1];         /* close the block partition */
/* ------------------------------------------------------------
   Let firstPSD be the 1st nonzero PSD block in At(:,permj).
   (is blknnzj if no PSD blocks)
   Block numbers 0..cK.lorN are LP and Lorentz, so the search key is
   1 + cK.lorN.  (intbsearch is defined elsewhere; presumably it
   leaves in firstPSD the position of the first entry >= the key.)
   ------------------------------------------------------------ */
      firstPSD = 0;
      intbsearch(&firstPSD, Ablkir + Ablkjc[permj], blknnzj, 1 + cK.lorN);
/* ------------------------------------------------------------
   Make dzjc,dzir: the PSD-nonzero locations, with pointers
   to the selected PSD blocks. nz-locs = merge(dzir,dzstruct(:,j)).
   The old dzir is first saved in the scratch part of iwork, since
   exmerge cannot merge in place.  dzir keeps growing over the
   columns: it is never reset inside this loop.
   ------------------------------------------------------------ */
      memcpy(iwork, dzir, dznnz * sizeof(int));
      addnnz = dzstructjc[j+1] - dzstructjc[j];
      exmerge(dzir, iwork, dznnz, dzstructir+dzstructjc[j], addnnz);
      dznnz += addnnz;
      vec2selblks(dzjc, blkstart, dzir, 0, dznnz,
                  Ablkir + Ablkjc[permj] + firstPSD, blknnzj - firstPSD);
      dzjc[blknnzj-firstPSD] = dznnz;         /* close the block partition */
/* ------------------------------------------------------------
   Compute daj = D(d^2)*aj.
   spscaleK (defined elsewhere) also writes the new entries of
   column j of ddota at position ddotaNz; its return value is added
   to ddotaNz, i.e. it is used as the number of entries written.
   ------------------------------------------------------------ */
      ddotaNz += spscaleK(daj, dzjc, dzir, aj,
                          Ablkir + Ablkjc[permj],blknnzj,
                          blkstart,
                          d, dsqr, detd, udsqr, cK, blkNL,
                          fwork,iwork,
                          ddota.ir + ddotaNz, ddota.pr+ddotaNz);
/* ------------------------------------------------------------
   Let ddotaj = ddota(:,j) in full
   ------------------------------------------------------------ */
      for(i = ddota.jc[j]; i < ddotaNz; i++)
        ddotaj[ddota.ir[i]] = ddota.pr[i];
/* ------------------------------------------------------------
   Remove dense Lorentz blocks from ddotaj, but keep them in the
   output matrix ddota: they'll be used in mex-function adenscale().
   ------------------------------------------------------------ */
      for(i = 0; i < ndenlor; i++)
        ddotaj[denlor[i]] = 0.0;
/* ------------------------------------------------------------
   For i=1..j-1, let  ada_ij = a_i'*daj + ddota_i'*ddotaj.
   Recall that inz = ada.jc[j].
   The loop walks down column j of ada and stops at the first row
   subscript >= j, so it relies on the diagonal entry being part of
   the structure of column j.
   ------------------------------------------------------------ */
      for(i = ada.ir[inz]; i < j; i = ada.ir[++inz]){     /* ADA(i,j), i<j */
        permi = perm[i];
        for(adaij = 0.0, knz = At.jc[permi]; knz < At.jc[permi+1]; knz++)
          adaij +=  At.pr[knz] * daj[At.ir[knz]];
        for(knz = ddota.jc[i]; knz < ddota.jc[i+1]; knz++)
          adaij +=  ddota.pr[knz] * ddotaj[ddota.ir[knz]];
        ada.pr[inz] = adaij;
      }
/* ------------------------------------------------------------
   Same for the diagonal entry ADA(j,j). We will still use
   ddotaj (instead of merely ddota), since we may need to wipe
   dense Lorentz blocks.
   ddota.jc[j+1] is not set yet at this point, hence the upper
   bound ddotaNz in the second loop.
   ------------------------------------------------------------ */
      mxAssert(i==j,"");                                      /* ADA(j,j) */
      for(adaij = 0.0, knz = At.jc[permj]; knz < At.jc[permj+1]; knz++)
        adaij +=  At.pr[knz] * daj[At.ir[knz]];
      for(knz = ddota.jc[j]; knz < ddotaNz; knz++)
        adaij +=  ddota.pr[knz] * ddotaj[ddota.ir[knz]];
      ada.pr[inz] = adaij;
/* ------------------------------------------------------------
   Re-initialize daj = 0, ddotaj = 0.
   Only the positions that can have been touched for this column
   are cleared, not the whole vectors.
   ------------------------------------------------------------ */
      spzeros(daj, aj.jc, aj.ir, Ablkir+Ablkjc[permj],blknnzj,
              firstPSD, dzjc,dzir, blkstart);
      for(i = ddota.jc[j]; i < ddotaNz; i++)
        ddotaj[ddota.ir[i]] = 0.0;
    } /* aj not all-0 */
  }
  ddota.jc[m] = ddotaNz;
/* ------------------------------------------------------------
   Copy upper triangular of ADA to lower, so that it gets symmetric.
   The scratch part of iwork serves as the length m work array.
   ------------------------------------------------------------ */
  sptriu2sym(ada,m,iwork);
}

/* ============================================================
   MEXFUNCTION
   ============================================================ */
/* ************************************************************
   PROCEDURE mexFunction - Entry for Matlab
     [ADA,ddota] = getada(A,d,dsqr,detd,udsqr,K)
   INPUT (prhs)
     A     - structure with fields t (sparse, one column per
             constraint), perm (at least m entries, 1-based) and
             dense.qs (1-based indices, converted to denlor).
     d     - at least K.l + dim(Lorentz) entries.
     dsqr  - at least K.l entries.
     detd  - one entry per Lorentz block.
     udsqr - one entry per PSD matrix entry (real + Hermitian part).
     K     - cone structure, read by conepars, with the additional
             sparse fields ABLK, DZSTRUCT, blkstart and SYMBADA.
   OUTPUT (plhs)
     ADA   - sparse m x m, with the structure of K.SYMBADA.
     ddota - sparse lorN x m; destroyed again if not requested.
   All sparse fields are verified to be sparse before their jc/ir
   arrays are taken, and the column counts of K.ABLK, K.DZSTRUCT and
   the length of A.perm are verified against m, because entries up to
   index m (resp. m-1) of those arrays are read below.
   NOTE(review): the jc/ir arrays are handled as int arrays throughout
   this file; this presumably requires a build in which the MEX index
   type has the size of int -- confirm for 64-bit builds.
   ************************************************************ */
void mexFunction(const int nlhs, mxArray *plhs[],
                 const int nrhs, const mxArray *prhs[])
{
  mxArray *myplhs[NPAROUT];
  coneK cK;
  const mxArray *K_FIELD, *A_FIELD;
  int lenfull, nblk, lend, lenudsqr, ndenlor;
  const double *d, *dsqr, *udsqr, *detd, *permPr;
  const int *Ablkjc, *Ablkir, *dzstructjc, *dzstructir, *blkstart;
  int m, i, j, inz, knz, fwsiz, iwsiz;
  double *fwork, *denlorPr;
  int *iwork, *blkNL, *denlor, *perm;
  jcir At, ddota, ada;
/* ------------------------------------------------------------
   Check for proper number of arguments
   ------------------------------------------------------------ */
  if(nrhs < NPARIN)
    mexErrMsgTxt("getADA requires more input arguments.");
  if(nlhs > NPAROUT)
    mexErrMsgTxt("getADA produces less output arguments.");
/* ------------------------------------------------------------
   Disassemble cone K structure
   ------------------------------------------------------------ */
  conepars(K_IN, &cK);
/* ------------------------------------------------------------
   Compute some statistics based on cone K structure
   ------------------------------------------------------------ */
  lenfull = cK.lpN +  cK.qDim + cK.rDim + cK.hDim;
  nblk = 1 + cK.lorN + cK.sdpN;
  lend = cK.lpN +  cK.qDim;               /* only Lorentz part needed */
  lenudsqr = cK.rDim + cK.hDim;       /* for PSD */
/* ------------------------------------------------------------
   Get INPUT: d, dsqr, detd, udsqr, denlorPr=A.dense.qs.
   ------------------------------------------------------------ */
  if(mxGetM(D_IN) * mxGetN(D_IN) < lend)                 /* d */
    mexErrMsgTxt("Size d mismatch");
  d = mxGetPr(D_IN);
  if(mxGetM(DSQR_IN) * mxGetN(DSQR_IN) < cK.lpN)         /* dsqr */
    mexErrMsgTxt("Size dsqr mismatch");
  dsqr = mxGetPr(DSQR_IN);
  if(mxGetM(DETD_IN) * mxGetN(DETD_IN) != cK.lorN)       /* detd */
    mexErrMsgTxt("Size detd mismatch");
  detd = mxGetPr(DETD_IN);
  if(mxGetM(UDSQR_IN) * mxGetN(UDSQR_IN) != lenudsqr)    /* udsqr */
    mexErrMsgTxt("udsqr size mismatch.");
  udsqr = mxGetPr(UDSQR_IN);
  if(!mxIsStruct(A_IN))
    mexErrMsgTxt("Parameter `A' should be a structure.");
  if( (A_FIELD = mxGetField(A_IN,0,"dense")) == NULL)      /* A.dense */
    mexErrMsgTxt("Missing A.dense");
  if(!mxIsStruct(A_FIELD))
    mexErrMsgTxt("`A.dense' should be a structure.");
  if( (K_FIELD = mxGetField(A_FIELD,0,"qs")) == NULL)      /* A.dense.qs */
    mexErrMsgTxt("Missing A.dense.qs.");
  ndenlor = mxGetM(K_FIELD) * mxGetN(K_FIELD);
  denlorPr = mxGetPr(K_FIELD);
/* ------------------------------------------------------------
   Get INPUT: A.t, A.perm.  This is done before the K fields, since
   their sizes are verified against m = number of columns of A.t.
   ------------------------------------------------------------ */
  if( (A_FIELD = mxGetField(A_IN,0,"t")) == NULL)          /* A.t */
    mexErrMsgTxt("Missing A.t");
  if(!mxIsSparse(A_FIELD))
    mexErrMsgTxt("A.t must be a sparse matrix.");
  m = mxGetN(A_FIELD);
  At.pr = mxGetPr(A_FIELD);
  At.jc = mxGetJc(A_FIELD);
  At.ir = mxGetIr(A_FIELD);
  if( (A_FIELD = mxGetField(A_IN,0,"perm")) == NULL)          /* A.perm */
    mexErrMsgTxt("Missing A.perm");
  if(mxGetM(A_FIELD) * mxGetN(A_FIELD) < (size_t) m)   /* perm[0:m-1] read */
    mexErrMsgTxt("Size A.perm mismatch");
  permPr = mxGetPr(A_FIELD);
/* ------------------------------------------------------------
   Get INPUT: Ablk, dzstruct, blkstart (from K)
   Ablkjc[m] and dzstructjc[m] are read below, so both matrices need
   at least m columns.
   ------------------------------------------------------------ */
  if( (K_FIELD = mxGetField(K_IN,0,"ABLK")) == NULL)      /* K.ABLK */
    mexErrMsgTxt("Missing K.ABLK.");
  if(!mxIsSparse(K_FIELD))
    mexErrMsgTxt("K.ABLK must be a sparse matrix.");
  if(mxGetN(K_FIELD) < (size_t) m)
    mexErrMsgTxt("Size At mismatches K.ABLK.");
  Ablkjc = mxGetJc(K_FIELD);
  Ablkir = mxGetIr(K_FIELD);
  if( (K_FIELD = mxGetField(K_IN,0,"DZSTRUCT")) == NULL)   /* K.DZSTRUCT */
    mexErrMsgTxt("Missing K.DZSTRUCT.");
  if(!mxIsSparse(K_FIELD))
    mexErrMsgTxt("K.DZSTRUCT must be a sparse matrix.");
  if(mxGetN(K_FIELD) < (size_t) m)
    mexErrMsgTxt("Size At mismatches K.DZSTRUCT.");
  dzstructjc = mxGetJc(K_FIELD);
  dzstructir = mxGetIr(K_FIELD);
  if( (K_FIELD = mxGetField(K_IN,0,"blkstart"))==NULL)      /*K.blkstart*/
    mexErrMsgTxt("Missing K.blkstart.");
  if(!mxIsSparse(K_FIELD))
    mexErrMsgTxt("K.blkstart must be a sparse matrix.");
  blkstart = mxGetIr(K_FIELD);     /* the row subscripts ARE the data */
/* ------------------------------------------------------------
   Allocate output matrix ADA based on K.SYMBADA
   ------------------------------------------------------------ */
  if( (K_FIELD = mxGetField(K_IN,0,"SYMBADA")) == NULL)  /* K.SYMBADA */
    mexErrMsgTxt("Missing K.SYMBADA.");
  if(!mxIsSparse(K_FIELD))
    mexErrMsgTxt("K.SYMBADA must be a sparse matrix.");
  if(mxGetN(K_FIELD) != m)
    mexErrMsgTxt("Size At mismatches K.SYMBADA.");
  ada.jc = mxGetJc(K_FIELD);
  ada.ir = mxGetIr(K_FIELD);
  ADA_OUT = mxCreateSparse(m,m, ada.jc[m],mxREAL);
  ada.pr = mxGetPr(ADA_OUT);
  memcpy(mxGetJc(ADA_OUT), ada.jc, (m+1) * sizeof(int));
  memcpy(mxGetIr(ADA_OUT), ada.ir, ada.jc[m] * sizeof(int));
/* ------------------------------------------------------------
   ALLOCATE integer working arrays:
   blkNL[nblk], denlor(ndenlor), perm(m),
   iwsiz = cK.sdpN+1+dzstructjc[m]+nblk+1 + max(m,max(nk(PSD)),dzstructjc[m]),
   iwork[iwsiz]
   ------------------------------------------------------------ */
  blkNL = (int *) mxCalloc(nblk, sizeof(int));
  denlor = (int *) mxCalloc(ndenlor, sizeof(int));
  perm = (int *) mxCalloc(m, sizeof(int));
  iwsiz = cK.sdpN+1 + dzstructjc[m] + nblk+1
    + MAX(MAX(m,MAX(cK.rMaxn,cK.hMaxn)),dzstructjc[m]);
  iwork = (int *) mxCalloc(iwsiz, sizeof(int));
/* ------------------------------------------------------------
   ALLOCATE float working array:
   fwsiz = lenfull + lorN + max(nk(REAL-PSD)^2, 2*nk(HERM)^2),
   fwork[fwsiz].
   mxCalloc returns zeroed memory: getada relies on the first
   lenfull entries being 0 on entry.
   ------------------------------------------------------------ */
  fwsiz = lenfull + cK.lorN + MAX(SQR(cK.rMaxn),2*SQR(cK.hMaxn));
  fwork  = (double *) mxCalloc(fwsiz, sizeof(double));
/* ------------------------------------------------------------
   denlor and perm to integer C-style (1-based double -> 0-based int).
   The values themselves are not range-checked.
   ------------------------------------------------------------ */
  for(i = 0; i < ndenlor; i++){
    j = denlorPr[i];
    denlor[i] = --j;
  }
  for(i = 0; i < m; i++){
    j = permPr[i];
    perm[i] = --j;
  }
/* ------------------------------------------------------------
   Let blkNL = [K.l, K.q, K.s] in integer.
   ------------------------------------------------------------ */
  blkNL[0] = cK.lpN;
  knz = 1;
  for(inz = 0; inz < cK.lorN; inz++)
    blkNL[knz+inz] = cK.lorNL[inz];
  knz += cK.lorN;
  for(inz = 0; inz < cK.sdpN; inz++)
    blkNL[knz+inz] = cK.sdpNL[inz];
/* ------------------------------------------------------------
   Create OUTPUT  ddota, a sparse(lorN,m) matrix for the nonzero Lorentz blocks
   in At,  to store di'*aij.  (i<lorN, j<m)
   Block numbers 1..lorN in ABLK are the Lorentz blocks; their count
   is the number of nonzeros that ddota needs (at least 1 is allocated).
   ------------------------------------------------------------ */
  for(inz = 0, knz = 0; inz < Ablkjc[m]; inz++)   /* count no. Lorentz */
    knz += ((Ablkir[inz] <= cK.lorN) && (Ablkir[inz] > 0));
  DDOTA_OUT = mxCreateSparse(cK.lorN,m, MAX(1,knz),mxREAL);
  ddota.pr = mxGetPr(DDOTA_OUT);
  ddota.jc = mxGetJc(DDOTA_OUT);
  ddota.ir = mxGetIr(DDOTA_OUT);
/* ------------------------------------------------------------
   ACTUAL COMPUTATION: handle constraint aj=At(:,perm(j)), j=0:m-1.
   ------------------------------------------------------------ */
  getada(ada, ddota, At, d, dsqr, udsqr, detd, Ablkjc, Ablkir,
         dzstructjc, dzstructir, blkstart, blkNL, denlor, perm,
         ndenlor, m, nblk, lenfull, cK, fwork, iwork);
  mxAssert(MAX(1,ddota.jc[m]) == mxGetNzmax(DDOTA_OUT),"");
/* ------------------------------------------------------------
   RELEASE WORKING ARRAYS.
   ------------------------------------------------------------ */
  mxFree(fwork);
  mxFree(iwork);
  mxFree(perm);
  mxFree(denlor);
  mxFree(blkNL);
/* ------------------------------------------------------------
   Copy requested output parameters (at least 1), release others.
   ------------------------------------------------------------ */
  i = MAX(nlhs, 1);
  memcpy(plhs,myplhs, i * sizeof(mxArray *));
  for(; i < NPAROUT; i++)
    mxDestroyArray(myplhs[i]);
}
