/*
   [yq,ys] = invsqrtw(w,q,vlab,K)
   Computes w^(-1/2), where [vlab.^2,q] = eigK(w,K).

   NB: for PSD-cone, ys = (Q / diag(sqrt(vlab)))', so that YS'*YS = W^(-1/2).

    This file is part of SeDuMi 1.03BETA
    Copyright (C) 1999 Jos F. Sturm
    Dept. Quantitative Economics, Maastricht University, the Netherlands.
    Affiliations up to SeDuMi 1.02 (AUG1998):
      CRL, McMaster University, Canada.
      Supported by the Netherlands Organization for Scientific Research (NWO).
  
    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.
  
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
  
    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

*/
#include <math.h>
#include <string.h>
#include "mex.h"
#include "triuaux.h"
#include "blksdp.h"

/* Output slots. These index the local staging array `myplhs` inside
   mexFunction (not `plhs` directly); the requested ones are copied
   to `plhs` at the end of mexFunction. */
#define YQ_OUT myplhs[0]
#define YS_OUT myplhs[1]
/* Number of output slots staged in `myplhs`. */
#define NPAROUT 2

/* Input arguments, in the order of the call invsqrtw(w,q,vlab,K). */
#define W_IN prhs[0]
#define Q_IN prhs[1]
#define V_IN prhs[2]
#define K_IN prhs[3]

/* Minimum number of input arguments accepted by mexFunction. */
#define NPARIN 4

/* ************************************************************
   PROCEDURE powminhalf - Lorentz-cone inverse square root,
     y = w^{-1/2}, from the spectral values v = sqrt(eig(w)).
   INPUT
     w - length n input vector; w[0] is the leading entry,
         w[1..n-1] the remaining part.
     v - length 2 vector of spectral values, v = sqrt(eig(w)).
     n - order (length of w and y).
   OUTPUT
     y - length n vector, y = w^{-1/2}.
   NOTE
     With det(v) = v[0]*v[1] and trace(v) = v[0]+v[1]:
       y[0]      = (w[0] + sqrt(2)*det(v)) / (trace(v)*det(v))
       y[1..n-1] = -w[1..n-1]              / (trace(v)*det(v))
     The tail is delegated to scalarmul (declared outside this
     file; presumably r = alpha * x over the given length).
     No guard against v[0]*v[1]*(v[0]+v[1]) == 0.
   ************************************************************ */
void powminhalf(double *y, const double *w, const double *v, const int n)
{
  double detv, shifted, denom;

  detv = v[0] * v[1];                  /* det(v) */
  shifted = w[0] + M_SQRT2 * detv;     /* numerator of the leading entry */
  denom = detv * (v[0] + v[1]);        /* det(v) * trace(v) */

  y[0] = shifted / denom;
  scalarmul(y+1, -1.0 / denom, w+1, n-1);   /* negated, scaled tail */
}

/* ************************************************************
   PROCEDURE qdivv - Computes Y = (Q / diag(sqrt(vlab)))'
   ************************************************************ */
void qdivv(double *y, const double *q,const double *v,const int n)
{
  int i,j,jcol,inz;
  double svi;

  for(inz = 0, i = 0; i < n; i++){
    svi = sqrt(v[i]);
    for(j = 0, jcol = i; j < n; j++, jcol += n)
      y[jcol] = q[inz++] / svi;              /* y(i,j) = q(j,i) / svi */
  }
}

/* ============================================================
   MAIN: MEXFUNCTION
   ============================================================ */
/* ************************************************************
   PROCEDURE mexFunction - Entry for Matlab
     [yq,ys] = invsqrtw(w,q,vlab,K)
   INPUT (prhs)
     w    - cone vector; only the Lorentz part (after the first
            K.lpN entries) is read here.
     q    - length cK.rDim + cK.hDim.
     vlab - length cK.lpN + 2*cK.lorN + cK.rLen + cK.hLen.
     K    - cone structure, disassembled by conepars.
   OUTPUT (plhs)
     yq - length cK.qDim, filled per Lorentz block by powminhalf.
     ys - length cK.rDim + cK.hDim, filled per real PSD block
          by qdivv.
   ERRORS
     Raised through mexErrMsgTxt on wrong argument counts, size
     mismatches of w, vlab or q, and on complex Hermitian blocks.
   ************************************************************ */
void mexFunction(const int nlhs, mxArray *plhs[],
  const int nrhs, const mxArray *prhs[])
{
  mxArray *myplhs[NPAROUT];
  int i,k, nk, nksqr, lenud, lendiag;
  double *yq,*ys;
  const double *q,*v,*w;
  coneK cK;
/* ------------------------------------------------------------
   Check for proper number of arguments
   ------------------------------------------------------------ */
  if(nrhs < NPARIN)
    mexErrMsgTxt("invsqrtw requires 4 input arguments.");
  if(nlhs > NPAROUT)
    mexErrMsgTxt("invsqrtw generates 2 output arguments.");
/* ------------------------------------------------------------
   Disassemble cone K structure
   ------------------------------------------------------------ */
  conepars(K_IN, &cK);
/* ------------------------------------------------------------
   Get statistics of cone K structure
   ------------------------------------------------------------ */
  lenud = cK.rDim + cK.hDim;
  lendiag = cK.lpN + 2 * cK.lorN + cK.rLen + cK.hLen;
/* ------------------------------------------------------------
   Get inputs w, v, q.
   The Lorentz loop below reads nk entries of w per block, starting
   after the LP part, and writes the same number into yq (which has
   cK.qDim entries). Hence w must hold at least lpN + qDim entries;
   without this check a short w is read out of bounds.
   ------------------------------------------------------------ */
  if(mxGetM(W_IN) * mxGetN(W_IN) < cK.lpN + cK.qDim)
    mexErrMsgTxt("w size mismatch");
  w = mxGetPr(W_IN) + cK.lpN;              /* skip LP part */
  if(mxGetM(V_IN) * mxGetN(V_IN) != lendiag)
    mexErrMsgTxt("v size mismatch");
  if(mxGetM(Q_IN) * mxGetN(Q_IN) != lenud)
    mexErrMsgTxt("q size mismatch");
  q = mxGetPr(Q_IN);
  v = mxGetPr(V_IN) + cK.lpN;    /* skip LP*/
/* ------------------------------------------------------------
   Complex Hermitian blocks (those beyond the first rsdpN PSD
   blocks) are not handled. Fail before allocating outputs or
   doing any work.
   ------------------------------------------------------------ */
  if(cK.rsdpN < cK.sdpN)
    mexErrMsgTxt("Complex q not yet supported.");
/* ------------------------------------------------------------
   Allocate outputs Yq, Ys
   ------------------------------------------------------------ */
  YQ_OUT =  mxCreateDoubleMatrix(cK.qDim, 1, mxREAL);
  yq = mxGetPr(YQ_OUT);
  YS_OUT =  mxCreateDoubleMatrix(lenud, 1, mxREAL);
  ys = mxGetPr(YS_OUT);
/* ------------------------------------------------------------
   The actual job is done here: (yq,ys) = w^{-1/2}
   ------------------------------------------------------------ */
  for(k = 0; k < cK.lorN; k++){               /* LORENTZ */
    nk = cK.lorNL[k];
    powminhalf(yq, w,v,nk);
    yq += nk; w += nk; v+=2;                  /* 2 spectral values per block */
  }
  for(k = 0; k < cK.rsdpN; k++){                /* PSD: real symmetric */
    nk = cK.sdpNL[k];
    qdivv(ys, q,v,nk);
    nksqr = SQR(nk);
    ys += nksqr; q += nksqr; v += nk;           /* nk^2 entries of q, nk of v */
  }
/* ------------------------------------------------------------
   Copy requested output parameters (at least 1), release others.
   ------------------------------------------------------------ */
  i = MAX(nlhs, 1);
  memcpy(plhs,myplhs, i * sizeof(mxArray *));
  for(; i < NPAROUT; i++)
    mxDestroyArray(myplhs[i]);
}
