#include "bfgs.h"
#include <math.h>
#include "vecutil.h"
#include "warning.h"
#include "options.h"

// Set up a BFGS minimizer in ndims (2 or 3) dimensions with room for nconstr
// linear constraints.
//   mxit  : iteration limit (maxit)
//   ndims : problem dimension n; must be 2 or 3
//   nconstr: number of constraints that will be added with addconstraint()
//   geps  : finite-difference step used by numgrad() (gradeps)
//   ceps  : convergence tolerance compared against the predicted decrease
//           delta in minimize() (conveps)
//   sig   : stored in sigma; presumably the line-search parameter used by
//           linesearchcond()/bracketcond() -- not visible in this file
//   cent  : true for central, false for forward difference gradients
BFGS::BFGS(Uint mxit, Uint ndims, Uint nconstr, Double geps,
           Double ceps, Double sig, bool cent) :
    maxit(mxit),
    n(ndims),
    nconstraints(nconstr),
    nsetconstraints(0),
    pos(0),
    gradeps(geps),
    conveps(ceps),
    sigma(sig),
    eps(1.0e-10),
    central(cent),
    nfun(0),
    ngrad(0) {
  assertcond(n == 2 || n == 3, "Illegal n in BFGS");
  NEWVEC(Double, b, nconstraints);
  NEWVEC(Uint, iact, nconstraints);
  NEWMAT(Double, B, n, n);
  NEWMAT(Double, A, nconstraints, n);
  NEWMAT(Double, Z, n, n); // rows span null-space of A'
  NEWMAT(Double, Zold, n, n);
  NEWMAT(Double, W, 2*n, n); // workspace; householder() needs nactive + n <= 2n rows
  // active[i] flags constraint i (not coordinate i); it is indexed over all
  // nconstraints constraints, so it must have nconstraints entries.
  NEWVEC(bool, active, nconstraints);
}

// Cross product of two 3-vectors: res = a x b.
// Components are written one at a time, so res must not alias a or b.
void crossprod(DoubleVec res, DoubleVec a, DoubleVec b) {
  // x: a_y*b_z - a_z*b_y
  res[0] = a[1]*b[2] - a[2]*b[1];
  // y: written as the negation of (a_x*b_z - a_z*b_x)
  res[1] = -(a[0]*b[2] - a[2]*b[0]);
  // z: a_x*b_y - a_y*b_x
  res[2] = a[0]*b[1] - a[1]*b[0];
}

void BFGS::addconstraint(DoubleMat c) {
  // Register the next constraint from geometric data in c:
  //   n == 2: c[0], c[1] are the endpoints of a side. The normal is the side
  //           direction c[1]-c[0] rotated by +90 degrees:
  //           A = (c0y - c1y, c1x - c0x).
  //   n == 3: c[0], c[1], c[2] are three points of a plane. The normal is
  //           (c[1]-c[0]) x (c[2]-c[1]).
  // The normal is scaled to unit length and b = A'*c[0], so the boundary
  // passes through c[0].
  //
  // Feasible side is A'*x >= b: maxstep() only limits steps with A'*s < 0 and
  // asserts a positive step (b - A'*x)/(A'*s), which requires A'*x > b.
  // minimize() also accepts nonnegative multipliers as optimal. For n == 2 the
  // feasible side is therefore to the left of the directed side c[0] -> c[1].
  assertinternal(n == 2 || n == 3);
  assertinternal(nsetconstraints < nconstraints);
  if (n == 2) {
    A[nsetconstraints][0] =  (c[0][1] - c[1][1]);
    A[nsetconstraints][1] = -(c[0][0] - c[1][0]);
  }
  else { // n == 3
    Double bminusa[3];
    Double cminusb[3];
    for (Uint i = 0; i < 3; i++) {
      bminusa[i] = c[1][i] - c[0][i];
      cminusb[i] = c[2][i] - c[1][i];
    }
    crossprod(A[nsetconstraints], bminusa, cminusb);
  }
  double nc = sqrt(dotprod(A[nsetconstraints], A[nsetconstraints], n));
  // Coincident/collinear points give a zero normal. Dividing by nc would fill
  // the constraint with NaN/Inf and silently corrupt every later step.
  assertinternal(nc > 0);
  for (Uint i = 0; i < n; i++) A[nsetconstraints][i] /= nc;

  b[nsetconstraints] = dotprod(A[nsetconstraints], c[0], n);
  // constraints are A[i]'*x >= b[i] (see maxstep())
  nsetconstraints++;
}

void BFGS::addconstraint(DoubleVec a, Double bb) {
  // Register the next constraint directly from its normal a (n entries) and
  // bound bb. Unlike addconstraint(DoubleMat), a is copied as given and is
  // not normalized. maxstep() treats a'*x >= bb as the feasible side.
  assertinternal(nsetconstraints < nconstraints);
  b[nsetconstraints] = bb;
  copyvec(A[nsetconstraints], a, n);
  ++nsetconstraints;
}

BFGS::~BFGS() {
  // Release every buffer allocated in the constructor.
  // NOTE(review): the class owns raw buffers. Unless the header disables
  // copying, a copied BFGS would double-free them -- confirm.
  DELETEVEC(b);
  DELETEVEC(iact);
  DELETEVEC(active); // allocated with NEWVEC in the constructor as well
  DELETEMAT(B);
  DELETEMAT(A);
  DELETEMAT(Z);
  DELETEMAT(Zold);
  DELETEMAT(W);
}

void BFGS::chol(DoubleMat C, DoubleVec x, DoubleVec b, Uint n) { 
  // Solve C*x = b for a symmetric positive-definite n x n matrix C (n <= 3)
  // by Cholesky factorization C = L*L'.
  //   C : only its upper triangle (C[i][j], j >= i) is read. C may be the
  //       member workspace W itself; otherwise it is first copied into W.
  //   x : receives the solution (n entries)
  //   b : right-hand side (n entries; shadows the member b)
  //   n : system size (shadows the member n)
  // The strictly lower triangle of W is overwritten with L (W[j][i], j > i).
  // The diagonal of L is kept in the local array p.
  // assertcond fails if C is not numerically positive definite.
  assertcond(n <= 3, "Choleski dimension too large in BFGS");
  // Fixed-size diagonal buffer: nothing to free, so nothing leaks if
  // assertcond below unwinds.
  Double p[3];
  if (W != C)
    for (Uint i = 0; i < n; i++)
      copyvec(W[i], C[i], n);
  // Signed indices throughout, so "k = i - 1" is well defined for i == 0.
  const int m = static_cast<int>(n);
  // Factorization, row i of L' (= column i of L) at a time.
  for (int i = 0; i < m; i++) {
    for (int j = i; j < m; j++) {
      Double tot = W[i][j];
      for (int k = i - 1; k >= 0; k--)
        tot -= W[i][k]*W[j][k];
      if (i == j) {
        assertcond(tot > 0.0, "Choleski failed in BFGS");
        p[i] = sqrt(tot);
      }
      else W[j][i] = tot/p[i];
    }
  }
  // Forward substitution L*y = b (y stored in x).
  for (int i = 0; i < m; i++) {
    Double tot = b[i];
    for (int k = i - 1; k >= 0; k--) tot -= W[i][k]*x[k];
    x[i] = tot/p[i];
  }
  // Back substitution L'*x = y.
  for (int i = m - 1; i >= 0; i--) {
    Double tot = x[i];
    for (int k = i + 1; k < m; k++) tot -= W[k][i]*x[k];
    x[i] = tot/p[i];
  }
}

void BFGS::findlambda(DoubleVec lambda, DoubleVec g) {
  // Least-squares multiplier estimates for the active constraints. Let the
  // rows of Aa be the active normals A[iact[0..nactive-1]]. Solve the normal
  // equations (Aa*Aa') * lambda = Aa*g with chol(). lambda receives nactive
  // entries. The member workspace W is overwritten.
  DoubleVec Atg;
  NEWVEC(Double, Atg, nconstraints);
  // Gram matrix of the active normals, built in W.
  for (Uint r = 0; r < nactive; r++)
    for (Uint c = 0; c < nactive; c++)
      W[r][c] = dotprod(A[iact[r]], A[iact[c]], n);
  // Right-hand side: g projected on each active normal.
  for (Uint r = 0; r < nactive; r++)
    Atg[r] = dotprod(A[iact[r]], g, n);
  chol(W, lambda, Atg, nactive);
  DELETEVEC(Atg);
}

void BFGS::householder(DoubleMat Z, Uint firstcol, Uint ncol) {
  // Householder QR of the active constraint normals, used to build an
  // orthonormal basis. ncol basis vectors are copied into the rows of Z
  // (the parameter Z shadows the member Z; callers pass the member or a row
  // pointer into it).
  //
  // Workspace (member W, 2*n rows of length n):
  //   rows 0..nactive-1         : active normals A[iact[i]]
  //   rows nactive..nactive+n-1 : the n x n identity
  // Each row W[j] is treated as a column vector of length n, so the
  // reflections act on the n x (nactive+n) matrix [Aa' : I]. The identity
  // part then accumulates Q' of Aa' = Q*R.
  // Result: Z[i][j] = W[nactive + j][firstcol + i], i.e. row i of Z is column
  // (firstcol + i) of Q. For firstcol >= nactive these rows are orthogonal to
  // all active normals (checkintegrity() asserts this).
  // NOTE(review): linearly dependent active normals give norm == 0 and a
  // division by zero in alfa.
  assertinternal(n >= nactive);
  for (Uint i = 0; i < nactive; i++) copyvec(W[i], A[iact[i]], n);
  for (Uint i = 0; i < n; i++) {
    zero(W[nactive + i], n);
    W[nactive + i][i] = 1.0; // W = [A:I]
  }
  for (Uint k = 0; k < nactive; k++) {
    DoubleVec y = W[k] + k;  // points to W_{kk}
    // Sign of the pivot, chosen so y[0] + sgn*norm does not cancel.
    Double sgn = (y[0] < 0 ? -1 : 1);
    Double norm = sqrt(dotprod(y, y, n - k));
    // With v = y + sgn*norm*e1: v'v = 2*norm*(norm + |y0|), so
    // alfa*v*v' == 2*v*v'/(v'v) and H = I - alfa*v*v'.
    Double alfa = 1.0/((norm + fabs(y[0]))*norm);
    y[0] += sgn*norm; // y now holds v (column k itself is not turned into R)
    // Apply H to rows k..n-1 of every later column, identity part included.
    for (Uint j = k + 1; j < nactive + n; j++) {
      Double beta = alfa*dotprod(y, W[j] + k, n - k);
      axpy(W[j] + k, -beta, y, W[j] + k, n - k);
    }
  }
  // Copy columns firstcol..firstcol+ncol-1 of Q into the rows of Z.
  for (Uint i = 0; i < ncol; i++)
    for (Uint j = 0; j < n; j++)
      Z[i][j] = W[nactive + j][i + firstcol];
}

Double BFGS::maxstep(DoubleVec x, DoubleVec s, int &imax) {
  // Longest step t along s from x before x + t*s reaches the boundary
  // A[i]'*y = b[i] of an inactive constraint.
  // Only constraints with A[i]'*s < 0 are approached. For each one the
  // crossing step is (b[i] - A[i]'*x)/(A[i]'*s). The smallest crossing step
  // is returned and its constraint index is stored in imax (the first index
  // wins on ties).
  // Asserts that some constraint is approached and that the step is positive.
  Double maxs = 1.0e99;
  for (Uint i = 0; i < nconstraints; i++) {
    if (active[i]) continue;
    Double rate = dotprod(A[i], s, n);
    if (!(rate < 0)) continue; // not moving toward this side
    Double stp = (b[i] - dotprod(A[i], x, n))/rate;
    if (stp < maxs) {
      maxs = stp;
      imax = i;
    }
  }
  assertinternal(maxs < 0.9e99);
  assertcond(maxs > 0, "Maxstep nonpositive in BFGS");
  return maxs;
}

void BFGS::addconstraint(Uint i) {
  // Make constraint i active. It is appended to iact, the free subspace
  // shrinks by one (nred--), the null-space basis Z is recomputed, and the
  // reduced Hessian approximation B is carried over to the new basis.
  // Save the current basis (nred rows) before householder() overwrites Z.
  for(Uint j = 0; j < nred; j++) copyvec(Zold[j], Z[j], n);
  iact[nactive++] = i;
  active[i] = true;
  nred--;
  householder(Z, nactive, nred); // calculates new Z
  // Transform B from the old (nred+1)-dimensional basis to the new
  // nred-dimensional one, using W as scratch. NOTE(review): the trailing
  // comments give the intended products. The exact argument/aliasing
  // conventions of mattmat/matmat live in vecutil -- confirm there.
  mattmat(W, Zold, Z, nred + 1, n, nred);     // W=Zold'*Z
  mattmat(B, W, B, nred + 1, nred, nred + 1); // B:=W'*B
  matmat(W, B, W, nred, nred + 1, nred);      // W:=W'*B*W
  for (Uint j = 0; j < nred; j++) copyvec(B[j], W[j], nred); // B:=W
  checkintegrity();
}

void BFGS::checkintegrity() const {
  assertinternal(nactive + nred == n);
  for (Uint i = 0; i < nred; i++) {
    for (Uint j = 0; j < nred; j++) {
      Double dp = dotprod(Z[i], Z[j], n);
      if (i == j) {assertinternal(fabs(dp - 1) < eps);}
      else {assertinternal(fabs(dp) < eps);}
    }
    for (Uint j = 0; j < nactive; j++) {
      Double dp = dotprod(Z[i], A[iact[j]], n);
      assertinternal(fabs(dp) < eps);
    }
  }
}

void BFGS::removeconstraint(Uint q) {
  // Drop the q-th active constraint. q is a position in iact (as returned by
  // vecmin over lambda in minimize()), not a constraint number. The free
  // subspace grows by one: a new basis vector is written to Z[nred-1], and B
  // gets a new identity row/column for it.
  nred++;
  Uint i = iact[q];
  // Move the removed constraint to the end of iact, keeping the others' order.
  for (Uint j = q; j < nactive - 1; j++)
    iact[j] = iact[j + 1];
  iact[nactive - 1] = i;
  // nactive still counts the removed constraint here, so it is the last
  // column of the QR. Column nactive-1 of Q is orthogonal to the remaining
  // active normals and becomes the new row Z[nred-1]; the other rows of Z
  // are not touched.
  householder(&(Z[nred - 1]), nactive - 1, 1);
  // New row/column of B: zeros off the diagonal, 1 on the diagonal.
  zero(B[nred - 1], nred - 1);
  for (Uint j = 0; j < nred; j++) B[j][nred - 1] = 0;
  B[nred - 1][nred - 1] = 1;
  active[i] = false;
  nactive--;
  checkintegrity();
}

void BFGS::numgrad(DoubleVec g, DoubleVec x, Double f) {
  // Finite-difference gradient of fun(., pos) at x, written to g (n entries),
  // with h = gradeps:
  //   central : g[i] = (fun(x + h*e_i) - fun(x - h*e_i)) / (2h)
  //   forward : g[i] = (fun(x + h*e_i) - f) / h, where f is the value at x
  // fun is called directly; each call to numgrad counts once in ngrad.
  ngrad++;
  // Perturbed copies of x. Fixed size 3 because n is 2 or 3.
  Double xp[3] = {0, 0, 0};
  Double xm[3] = {0, 0, 0};
  copyvec(xp, x, n);
  if (!central) {
    for (Uint i = 0; i < n; i++) {
      xp[i] += gradeps;
      g[i] = (fun(xp, pos) - f)/gradeps;
      xp[i] = x[i];
    }
    return;
  }
  copyvec(xm, x, n);
  for (Uint i = 0; i < n; i++) {
    xp[i] += gradeps;
    xm[i] -= gradeps;
    g[i] = (fun(xp, pos) - fun(xm, pos))/(2*gradeps);
    xp[i] = x[i];
    xm[i] = x[i];
  }
}

Double BFGS::linesearch(DoubleVec s, DoubleVec x, 
                        Double f, DoubleVec g,
                        Double &f1, DoubleVec g1, int &imax) {
  // Inexact line search along direction s from x. f and g are the objective
  // value and gradient at x. The step is capped at
  // alfamax = maxstep(x, s, imax), the step to the nearest inactive
  // constraint.
  // Returns the accepted step alfa. On return f1/g1 hold the value and
  // numerical gradient at x + alfa*s. imax is the constraint the step ends
  // on, or -1 if the step ends short of alfamax (always -1 for a null
  // direction or a step found by bisection).
  // Acceptance uses linesearchcond() and bracketing uses bracketcond(); both
  // are defined elsewhere (presumably Wolfe-type tests with sigma).
  assertinternal(n <= 3); // fixed-size scratch vectors below (n is 2 or 3)
  // Stack scratch vectors: they need no cleanup, so no return path leaks.
  Double xm[3] = {0, 0, 0}; // bisection trial point
  Double gm[3] = {0, 0, 0}; // gradient at xm
  Double x1[3] = {0, 0, 0}; // extrapolation trial point
  Double t0, t1, f0, d1;
  Double d = dotprod(g, s, n); // directional derivative at x
  if (sqrt(dotprod(s, s, n)) < eps) { // null direction: do not move
    f1 = f; copyvec(g1, g, n); imax = -1;
    return 0.0;
  }
  Double alfamax = maxstep(x, s, imax);
  t1 = min(1.0, alfamax);
  axpy(x1, t1, s, x, n);
  f1 = callfun(x1);
  numgrad(g1, x1, f1);
  d1 = dotprod(g1, s, n);
  if (linesearchcond(f, f1, d, d1)) {
    if (t1 < alfamax - eps) imax = -1;
    return t1; // alfa = min(1, alfamax) fulfills linesearch condition
  }
  else if (alfamax > 1) { // find bracket [t0, t1] by doubling t1
    t0 = 0;
    f0 = f;
    while (t1 < alfamax && !bracketcond(f0, f1, d1) &&
           !linesearchcond(f, f1, d, d1)) {
      if (d1 < 0 && f1 < f) { // still descending: raise the lower end
        t0 = t1;
        f0 = f1;
      }
      t1 = min(alfamax, 2*t1);
      axpy(x1, t1, s, x, n);
      f1 = callfun(x1);
      numgrad(g1, x1, f1);
      d1 = dotprod(g1, s, n);
    }
    if (linesearchcond(f, f1, d, d1) ||
        (alfamax <= t1 && !bracketcond(f0, f1, d1))) {
      if (t1 < alfamax - eps) imax = -1;
      return t1;
    }
  }
  else if (!bracketcond(f, f1, d1)) {
    if (t1 < alfamax - eps) imax = -1;
    return t1;
  }
  else {
    t0 = 0;
    f0 = f;
  }
  do { // find alfa in [t0, t1] that fulfills linesearch condition (bisection)
    assertinternal(t0 < t1);
    Double tm = (t0 + t1)/2;
    assertinternal(tm <= alfamax);
    axpy(xm, tm, s, x, n);
    Double fm = callfun(xm);
    numgrad(gm, xm, fm);
    Double dm = dotprod(gm, s, n);
    if (linesearchcond(f, fm, d, dm) || t1 - t0 < eps) {
      f1 = fm;
      copyvec(g1, gm, n);
      imax = -1;
      return tm;
    }
    if (fm >= f0 || dm >= 0) {
      t1 = tm; f1 = fm; copyvec(g1, gm, n);
    }
    else {
      t0 = tm; f0 = fm;
    }
  } while (true); // exits only via the return above (interval shrinks below eps)
  fatal("linesearch failure in BFGS");
  return 0;
}

void BFGS::r1update(DoubleMat A, DoubleVec x, Double a, Uint n) {
  // Rank-one update of the leading n x n block: A := A + a*x*x'.
  // Row r receives a*x[r] times x. (A and n shadow the members of the same
  // name; minimize() applies this to B.)
  for (Uint r = 0; r < n; r++) {
    Double scale = a*x[r];
    axpy(A[r], scale, x, A[r], n);
  }
}

void BFGS::findsearchdir(DoubleVec s, DoubleVec g, DoubleVec p, DoubleVec Ztg) {
  // Quasi-Newton search direction restricted to the free subspace spanned by
  // the rows of Z:
  //   Ztg = reduced gradient (mattvec of Z and g)
  //   p   = -B^{-1} * Ztg   (Cholesky solve with the reduced Hessian B)
  //   s   = full-space direction built from Z and p (matvec)
  // If there is no free direction (nred == 0), or a 1x1/2x2 B has zero
  // determinant, s is zeroed and p/Ztg are left untouched.
  const bool degenerate =
      nred == 0 || ((nred == 1 || nred == 2) && det(B, nred) == 0);
  if (degenerate) {
    zero(s, n);
    return;
  }
  mattvec(Ztg, Z, g, n, nred);
  chol(B, p, Ztg, nred);
  scalvec(p, -1.0, p, nred); // p := -p
  matvec(s, Z, p, n, nred);
}

void BFGS::minimize(DoubleVec x, DoubleVec x1, Double &y1, Uint positn, 
                    Trianglefunction fn) {
  // Active-set BFGS minimization of fn(., positn) subject to the constraints
  // registered with addconstraint() (feasible side A[i]'*x >= b[i], as
  // implied by maxstep()).
  //   x      : initial guess on entry. It is used as the current iterate and
  //            overwritten; on return it equals x1.
  //   x1     : receives the solution
  //   y1     : receives the objective value at the solution
  //   positn : passed through to fn as its second argument (stored in pos)
  //   fn     : objective function
  // NOTE(review): the initial guess presumably must be strictly feasible,
  // since maxstep() asserts a positive step to the nearest side -- confirm.
  // Emits a warning (not an error) when maxit iterations are used up.
  nactive = 0; // No active constraints to begin with
  nred = n - nactive; // dimension of the free subspace (rows of Z)
  for (Uint i = 0; i < nconstraints; i++)
    active[i] = false;
  fun = fn;
  pos = positn;
  DoubleVec p;      // reduced search direction
  DoubleVec Ztg;    // reduced gradient at x
  DoubleVec Ztg1;   // reduced gradient at the new point
  DoubleVec g;      // gradient at x
  DoubleVec s;      // full-space search direction
  DoubleVec g1;     // gradient at the new point
  DoubleVec y;      // change in reduced gradient
  DoubleVec Bp;     // B*p
  DoubleVec lambda; // multiplier estimates for the active constraints
  NEWVEC(Double, p, 3);
  NEWVEC(Double, Ztg, 3);
  NEWVEC(Double, Ztg1, 3);
  NEWVEC(Double, g, 3);
  NEWVEC(Double, s, 3);
  NEWVEC(Double, g1, 3);
  NEWVEC(Double, y, 3);
  NEWVEC(Double, Bp, 3);
  NEWVEC(Double, lambda, 3);
  zero(B[0], n*n);
  for (Uint i = 0; i < n; i++) B[i][i] = 1; // reduced Hessian approx. starts as I
  Double f = callfun(x), f1;
  y1 = f;
  numgrad(g, x, f);
  // x1 is otherwise first written after the convergence test inside the loop.
  // Converging at the start point, or maxit <= 1, must still return the point.
  copyvec(x1, x, n);
  Double flk = f; // f when the active set was last reduced (initially f(x))
  Uint k;
  householder(Z, nactive, n); // no active constraints: Z = identity
  for (k = 1; k < maxit; k++) { 
    findsearchdir(s, g, p, Ztg);
    Double delta = -0.5*dotprod(s, g, n);
    assertcond(delta >= 0, "delta negative in BFGS");
#ifdef BFGS_DEBUG
    cout << "bfgs: " << k << ", f=" << f << ", delta=" << delta 
         << ", x=" << x[0] << " " << x[1] << endl;
#endif // BFGS_DEBUG
    if (delta < flk - f) {
      findlambda(lambda, g);
      Uint q;
      Double lambdamin = nactive > 0 ? vecmin<Double>(lambda, nactive, q) : 1.0e99;
      if (lambdamin >= 0 && delta <= conveps) break;
      else if (lambdamin < 0) {
        removeconstraint(q);
#ifdef BFGS_DEBUG
        cout << "removed constraint " << q << endl;
#endif // BFGS_DEBUG
        flk = f;
        findsearchdir(s, g, p, Ztg);
      }
    }
    int imax;
    Double alfa = linesearch(s, x, f, g, f1, g1, imax);
    axpy(x1, alfa, s, x, n);
    mattvec(Ztg1, Z, g1, n, nred);
    axpy(y, -1.0, Ztg, Ztg1, nred); // y = Ztg1 - Ztg
    mattvec(Bp, B, p, nred, nred);
    Double ytd = alfa*dotprod(y, p, nred);
    if (ytd > eps) { // BFGS update of B only with positive curvature
      Double pBp = dotprod(Bp, p, nred); 
      assertcond(pBp >= eps, "nonpositive pBp in BFGS");
      r1update(B, y, 1/ytd, nred);
      r1update(B, Bp, -1/pBp, nred);
    }
    else if (alfa > 0) assertcond(imax >= 0, "ytd <= eps in BFGS");
    if (imax >= 0) {
     addconstraint(imax);
#ifdef BFGS_DEBUG
     cout << "added constraint " << imax << endl;
#endif // BFGS_DEBUG
    }
    f = f1;
    copyvec(x, x1, n);
    copyvec(g, g1, n);
  }
#ifdef BFGS_DEBUG
  cout << "nfun=" << nfun << "   ngrad=" << ngrad << endl;
#endif // BFGS_DEBUG
  if(k >= maxit) warning("Max step-count exceeded in BFGS");
  assertinternal(y1 + eps > f);
  y1 = f;
  DELETEVEC(p);
  DELETEVEC(Ztg);
  DELETEVEC(Ztg1);
  DELETEVEC(g);
  DELETEVEC(s);
  DELETEVEC(g1);
  DELETEVEC(y);
  DELETEVEC(Bp);
  DELETEVEC(lambda);
}
