#include <cmath>

#include "files.h"
#include "findnormal.h"
#include "qtlmodel.h"
#include "utils.h"
#include "inherdist.h"
#include "maximize.h"
#include "fmtout.h"
#include "options.h"
#include "bfgs.h"
#include "vecutil.h"

//////////////////////////////////////////////////////////////////////
// QTLmodel

// Releases the buffers allocated by initialize().
// flod is allocated with NEWMAT (a matrix), so it is released with the
// matching DELETEMAT rather than DELETEVEC.
// Shat/Ghat/Ahat/Dhat are only allocated for the components being maximized.
// NOTE(review): this assumes the constructor nulls the others and DELETEVEC
// tolerates null -- confirm.
QTLmodel::~QTLmodel() {
  DELETEVEC(lod);
  DELETEMAT(flod);
  DELETEVEC(Shat);
  DELETEVEC(Ghat);
  DELETEVEC(Ahat);
  DELETEVEC(Dhat);
  DELETEVEC(p);
  DELETEVEC(L);
}

void QTLmodel::calculateL(Family *fam, DoubleVec L, Double S,
                          Double G, Double A, Double D) {
  loglik.calc(fam, L, S, G, A, D);
}

// Returns the total log10-likelihood over all families for the variance
// components (S, G, A, D) at position `gam`. Each family's contribution is
// also stored in fl[ifam].
//
// A family's likelihood is sum_v p[v]*exp(L[v]), where:
//   - L[v] is filled by calculateL for inheritance vector v;
//   - p[v] is that vector's probability at position gam (1 for a single
//     vector).
// With conditionallikelihood the likelihood is divided by the mean of
// exp(L[v]).
//
// The sums are evaluated relative to Lmax = max_v L[v] (log-sum-exp), so
// exp() cannot underflow to 0 (which would give log10(0) = -inf) or
// overflow when the variance components are extreme.
Double QTLmodel::calculate(DoubleVec fl, Double S,
                           Double G, Double A, Double D, Uint gam) {
  Double l = 0.0;
  for (Uint ifam = 0; ifam < nfam(); ifam++) {
    Family *fam = distribution->families[ifam];
    calculateL(fam, L, S, G, A, D);
    if (fam->numiv > 1)
      ((Inherdist *)distribution)->getdist(fam->id, fam->numiv, gam, p);
    else p[0] = 1.0;

    // Largest term, factored out of the sums below. A non-finite maximum
    // (e.g. every term is -inf) is not used as a shift; the unshifted sums
    // are then evaluated exactly as before.
    Double Lmax = 0.0;
    for (IV v = 0; v < fam->numiv; v++)
      if (v == 0 || L[v] > Lmax) Lmax = L[v];
    if (!(Lmax > -HUGE_VAL && Lmax < HUGE_VAL)) Lmax = 0.0;

    Double fL = 0.0;
    if (conditionallikelihood) {
      Double normconst = 0.0;
      for (IV v = 0; v < fam->numiv; v++) {
        Double expL = exp(L[v] - Lmax);
        fL += p[v]*expL;
        normconst += expL;
      }
      normconst /= fam->numiv;
      // The common factor exp(Lmax) cancels in the ratio.
      l += fl[ifam] = log10(fL/normconst);
    }
    else {
      for (IV v = 0; v < fam->numiv; v++) fL += p[v]*exp(L[v] - Lmax);
      // log10(sum p*exp(L)) = log10(sum p*exp(L - Lmax)) + Lmax*log10(e)
      l += fl[ifam] = log10(fL) + Lmax/log(10.0);
    }
  }
  return l;
}

// State shared with the free-standing optimiser callbacks below, which only
// receive the parameter value(s) and a position:
//   maxcaller - the model whose calculate() the callbacks evaluate
//               (set to `this` by QTLmodel::run()).
//   fl        - buffer receiving the per-family log10-likelihoods of each
//               evaluation.
//   Gglob     - shared variance for callbacks that keep it fixed.
QTLmodel *maxcaller(0);
DoubleVec fl(0);

Double Gglob(0.0);

// Log10-likelihood as a function of the shared variance G alone.
// The residual takes the remaining variance (S = 1 - G). There are no
// additive or dominance components, so the position is not used (0 is
// passed to calculate()).
Double maxfunc_G(Double G, Uint /*pos*/) {
  const Double residual = 1.0 - G;
  return maxcaller->calculate(fl, residual, G, 0, 0, 0);
}

// Log10-likelihood as a function of the additive variance A at position
// pos. The shared variance is fixed at Gglob, there is no dominance
// component, and the residual is the remainder (S = 1 - Gglob - A).
Double maxfunc_A(Double A, Uint pos) {
  const Double residual = 1.0 - Gglob - A;
  return maxcaller->calculate(fl, residual, Gglob, A, 0, pos);
}

// Log10-likelihood as a function of the dominance variance D at position
// pos. The shared variance is fixed at Gglob, there is no additive
// component, and the residual is the remainder (S = 1 - Gglob - D).
Double maxfunc_D(Double D, Uint pos) {
  const Double residual = 1.0 - Gglob - D;
  return maxcaller->calculate(fl, residual, Gglob, 0, D, pos);
}

// Negated log10-likelihood (for minimization) of x = (G, A).
// G is the shared and A the additive variance at position pos. There is no
// dominance component, and the residual is S = 1 - G - A.
Double minfunc_GA(DoubleVec x, Uint pos) {
  const Double G = x[0];
  const Double A = x[1];
  const Double S = 1.0 - G - A;
  return -maxcaller->calculate(fl, S, G, A, 0, pos);
}

// Negated log10-likelihood (for minimization) of x = (G, A, D).
// These are the shared, additive and dominance variances at position pos.
// The residual is S = 1 - G - A - D.
Double minfunc_GAD(DoubleVec x, Uint pos) {
  const Double G = x[0];
  const Double A = x[1];
  const Double D = x[2];
  const Double S = 1.0 - G - A - D;
  return -maxcaller->calculate(fl, S, G, A, D, pos);
}

// Negated log10-likelihood (for minimization) of x = (G, D).
// G is the shared and D the dominance variance at position pos. There is no
// additive component, and the residual is S = 1 - G - D.
Double minfunc_GD(DoubleVec x, Uint pos) {
  const Double G = x[0];
  const Double D = x[1];
  const Double S = 1.0 - G - D;
  return -maxcaller->calculate(fl, S, G, 0, D, pos);
}

// Negated log10-likelihood (for minimization) of x = (A, D).
// A is the additive and D the dominance variance at position pos. The shared
// variance is fixed at Gglob, and the residual is S = 1 - Gglob - A - D.
Double minfunc_AD(DoubleVec x, Uint pos) {
  const Double A = x[0];
  const Double D = x[1];
  const Double S = 1.0 - Gglob - A - D;
  return -maxcaller->calculate(fl, S, Gglob, A, D, pos);
}

// Log10-likelihood as a function of the residual variance S alone.
// The shared, additive and dominance variances are all 0, and the position
// is not used (0 is passed to calculate()).
// NOTE(review): the shared variance is 0 here, whereas minfunc_SA and
// minfunc_SD fix it at Gglob. Confirm this is the intended null model when
// the shared variance is held at a non-zero value.
Double maxfunc_S(Double S, Uint /*pos*/) {
  const Double none = 0;
  return maxcaller->calculate(fl, S, none, none, none, 0);
}

// Negated log10-likelihood (for minimization) of x = (S, G).
// S is the residual and G the shared variance. There are no additive or
// dominance components, so the position is not used.
Double minfunc_SG(DoubleVec x, Uint /*pos*/) {
  const Double S = x[0];
  const Double G = x[1];
  return -maxcaller->calculate(fl, S, G, 0, 0, 0);
}

// Negated log10-likelihood (for minimization) of x = (S, A).
// S is the residual and A the additive variance at position pos. The shared
// variance is fixed at Gglob, and there is no dominance component.
Double minfunc_SA(DoubleVec x, Uint pos) {
  const Double S = x[0];
  const Double A = x[1];
  return -maxcaller->calculate(fl, S, Gglob, A, 0, pos);
}

// Negated log10-likelihood (for minimization) of x = (S, D).
// S is the residual and D the dominance variance at position pos. The shared
// variance is fixed at Gglob, and there is no additive component.
Double minfunc_SD(DoubleVec x, Uint pos) {
  const Double S = x[0];
  const Double D = x[1];
  return -maxcaller->calculate(fl, S, Gglob, 0, D, pos);
}

// Negated log10-likelihood (for minimization) of x = (S, G, A).
// These are the residual, shared and additive variances at position pos.
// There is no dominance component.
Double minfunc_SGA(DoubleVec x, Uint pos) {
  const Double S = x[0];
  const Double G = x[1];
  const Double A = x[2];
  return -maxcaller->calculate(fl, S, G, A, 0, pos);
}

// Negated log10-likelihood (for minimization) of x = (S, G, A, D).
// These are the residual, shared, additive and dominance variances at
// position pos. x is the 4-dimensional point from the BFGS in
// quadvariatemaximization, so its last valid index is 3 and the dominance
// component is x[3].
Double minfunc_SGAD(DoubleVec x, Uint pos) {
  return -maxcaller->calculate(fl, x[0], x[1], x[2], x[3], pos);
}

// Negated log10-likelihood (for minimization) of x = (S, G, D).
// These are the residual, shared and dominance variances at position pos.
// There is no additive component.
Double minfunc_SGD(DoubleVec x, Uint pos) {
  const Double S = x[0];
  const Double G = x[1];
  const Double D = x[2];
  return -maxcaller->calculate(fl, S, G, 0, D, pos);
}

// Negated log10-likelihood (for minimization) of x = (S, A, D).
// These are the residual, additive and dominance variances at position pos.
// The shared variance is fixed at Gglob.
Double minfunc_SAD(DoubleVec x, Uint pos) {
  const Double S = x[0];
  const Double A = x[1];
  const Double D = x[2];
  return -maxcaller->calculate(fl, S, Gglob, A, D, pos);
}

void QTLmodel::run() {
  DoubleVec flnull;
  DoubleVec flalt;
  NEWVEC(Double, flnull, nfam());
  NEWVEC(Double, flalt, nfam());

  maxcaller = this;

  const Double UNCONSTRAINEDMAX = 1e+10;
  
  // Maximize likelihood under H_0
  Double llnull;
  fl = flnull;
  if (maximizeS) 
    if (maximizeG) 
      bivariatemaximization(&Snullhat, UNCONSTRAINEDMAX, &Gnullhat,
                            UNCONSTRAINEDMAX, &llnull, minfunc_SG, 1);
    else
      univariatemaximization(&Snullhat, UNCONSTRAINEDMAX, &llnull, -1e300,
                             maxfunc_S, 1);
  else
    if (maximizeG && (maximizeA || maximizeD))
      univariatemaximization(&Gnullhat, 1, &llnull, -1e300, maxfunc_G, 1);
    else {
      llnull = calculate(flnull, resid, (maximizeG ? 0.0 : shared), 0, 0, 0);
      if (maximizeG) Gnullhat = 0.0;
    }

  // Maximize likelihood under H_A
  fl = flalt;
  if (maximizeS) // Maximize over residual variance
    if (maximizeG) // Maximize over shared variance
      if (maximizeA)  // Maximize over additive variance
        if (maximizeD) // Maximize over dominance variance
          quadvariatemaximization(Shat, Ghat, Ahat, Dhat, lod, minfunc_SGAD,
                                  npos());
        else // Dominance variance fixed
          trivariatemaximization(Shat, UNCONSTRAINEDMAX, Ghat, UNCONSTRAINEDMAX,
                                 Ahat, UNCONSTRAINEDMAX, lod, minfunc_SGA,
                                 npos());
      else // Additive variance fixed
        if (maximizeD) // Maximize over dominant variance
          trivariatemaximization(Shat, UNCONSTRAINEDMAX, Ghat, UNCONSTRAINEDMAX,
                                 Dhat, UNCONSTRAINEDMAX, lod, minfunc_SGD,
                                 npos());
        else // Dominance variance fixed
          assertinternal(false)
    else { // Shared variance fixed
      Gglob = shared;
      if (maximizeA)  // Maximize over additive variance
        if (maximizeD) // Maximize over dominant variance
          trivariatemaximization(Shat, UNCONSTRAINEDMAX, Ahat, UNCONSTRAINEDMAX,
                                 Dhat, UNCONSTRAINEDMAX, lod, minfunc_SAD,
                                 npos());
        else // Dominance variance fixed
          bivariatemaximization(Shat, UNCONSTRAINEDMAX, Ahat, UNCONSTRAINEDMAX,
                                 lod, minfunc_SA, npos());
      else // Additive variance fixed
        if (maximizeD) // Maximize over dominant variance
          bivariatemaximization(Shat, UNCONSTRAINEDMAX, Dhat, UNCONSTRAINEDMAX,
                                lod, minfunc_SD, npos());
        else // Dominance variance fixed
          assertinternal(false);
    }
  else // residual variance fixed
    if (maximizeG) { // Maximize over shared variance
      if (maximizeA) { // Maximize over additive variance
        if (maximizeD) // Maximize over dominant variance
          trivariatemaximization(Ghat, 1, Ahat, 1, Dhat, 1, lod, minfunc_GAD,
                                 npos());
        else // Dominance variance fixed
          bivariatemaximization(Ghat, 1, Ahat, 1, lod, minfunc_GA, npos());
      }
      else { // Additive variance fixed
        if (maximizeD) // Maximize over dominant variance
          bivariatemaximization(Ghat, 1, Dhat, 1, lod, minfunc_GD, npos());
        else // Dominance variance fixed
          univariatemaximization(Ghat, 1, lod, llnull, maxfunc_G, npos());
      }
    }
    else { // Shared variance fixed
      Gglob = shared;
      if (maximizeA)  // Maximize over additive variance
        if (maximizeD) // Maximize over dominant variance
          bivariatemaximization(Ahat, 1 - shared, Dhat, 1 - shared, lod,
                                minfunc_AD, npos());
        else // Dominance variance fixed
          univariatemaximization(Ahat, 1 - shared, lod, llnull,
                                 maxfunc_A, npos());
      else // Additive variance fixed
        if (maximizeD) // Maximize over dominant variance
          univariatemaximization(Dhat, 1 - shared, lod, llnull,
                                 maxfunc_D, npos());
        else // Dominance variance fixed
          assertinternal(false);
    }

  for (Uint pos = 0; pos < npos(); pos++) {
    lod[pos] -= llnull;
    for (Uint fam = 0; fam < nfam(); fam++)
      flod[fam][pos] -= flnull[fam];
  }
  DELETEVEC(flalt);
  DELETEVEC(flnull);
}

// Maximizes a four-parameter log10-likelihood with BFGS at each of the
// first numpos positions. `fun` returns the NEGATED log10-likelihood of a
// 4-vector.
//
// Constraint i has coefficient row -e_i and bound zeroeps, presumably
// keeping each component (approximately) non-negative -- confirm against
// BFGS::addconstraint.
//
// On return:
//   - shat/xhat/yhat/zhat[pos] hold the four components of the optimum;
//   - fhat[pos] holds the maximized log10-likelihood;
//   - flod[f][pos] holds the per-family values left in the global fl by the
//     last evaluation.
//
// The problem is 4-dimensional, so the start point and result vectors have
// four elements, all initialized.
void QTLmodel::quadvariatemaximization(DoubleVec shat, DoubleVec xhat,
                                       DoubleVec yhat, DoubleVec zhat,
                                       DoubleVec fhat,
                                       BFGS::Trianglefunction fun,
                                       Uint numpos) {
  BFGS bfgs(100, 4, 4, 1e-7, 1e-5, 0.1, false);

  // Constraint rows: row i is -e_i (minus the i-th unit vector).
  const Double zeroeps = 5e-6;
  DoubleMat c;
  NEWMAT(Double, c, 4, 4);
  zero(c[0], 4*4);
  for (Uint i = 0; i < 4; i++) c[i][i] = -1;

  // Set constraints
  for (Uint i = 0; i < 4; i++) bfgs.addconstraint(c[i], zeroeps);

  // Perform maximization
  DoubleVec x0;
  DoubleVec xstar;
  NEWVEC(Double, x0, 4);
  NEWVEC(Double, xstar, 4);
  for (Uint pos = 0; pos < numpos; pos++) {
    copyval(x0, .1, 4); // restart from (0.1, 0.1, 0.1, 0.1) at each position
    bfgs.minimize(x0, xstar, fhat[pos], pos, fun);
    fhat[pos] = -fhat[pos]; // fun is the negated log10-likelihood
    shat[pos] = xstar[0];
    xhat[pos] = xstar[1];
    yhat[pos] = xstar[2];
    zhat[pos] = xstar[3];

    for (Uint f = 0; f < nfam(); f++) flod[f][pos] = fl[f];
  }
  DELETEVEC(x0);
  DELETEVEC(xstar);
  DELETEMAT(c);
}

// Maximizes a three-parameter log10-likelihood with BFGS at each of the
// first numpos positions. `fun` returns the NEGATED log10-likelihood.
//
// The search region has four corners, each pulled slightly inside:
//   - the origin;
//   - one point on each axis, at (1 - corneps) times xmax, ymax or zmax.
// One constraint is added per triangular face, each given by three corners
// (presumably keeping the search inside this tetrahedron -- confirm against
// BFGS::addconstraint).
//
// On return:
//   - xhat/yhat/zhat[pos] hold the optimum;
//   - fhat[pos] holds the maximized log10-likelihood;
//   - flod[f][pos] holds the per-family values left in the global fl.
void QTLmodel::trivariatemaximization(DoubleVec xhat, Double xmax,
                                      DoubleVec yhat, Double ymax,
                                      DoubleVec zhat, Double zmax,
                                      DoubleVec fhat,
                                      BFGS::Trianglefunction fun,
                                      Uint numpos) {
  BFGS bfgs(100, 3, 4, 1e-7, 1e-5, 0.1, false);

  // Create corners of region
  const Double corneps = 1e-4;
  const Double zeroeps = 5e-6;
  DoubleMat c;
  NEWMAT(Double, c, 4, 3);

  c[0][0] = zeroeps;
  c[0][1] = zeroeps;
  c[0][2] = zeroeps;
  c[1][0] = zeroeps;
  c[1][1] = (1 - corneps)*ymax;
  c[1][2] = zeroeps;
  c[2][0] = (1 - corneps)*xmax;
  c[2][1] = zeroeps;
  c[2][2] = zeroeps;
  c[3][0] = zeroeps;
  c[3][1] = zeroeps;
  c[3][2] = (1 - corneps)*zmax;

  // Set constraints: each face is given by three corners, so d needs three
  // slots (d[0], d[1], d[2]).
  DoubleMat d;
  NEWVEC(DoubleVec, d, 3);

  d[0] = c[0];
  d[1] = c[2];
  d[2] = c[1];
  bfgs.addconstraint(d);

  d[0] = c[0];
  d[1] = c[3];
  d[2] = c[2];
  bfgs.addconstraint(d);

  d[0] = c[0];
  d[1] = c[1];
  d[2] = c[3];
  bfgs.addconstraint(d);

  d[0] = c[2];
  d[1] = c[3];
  d[2] = c[1];
  bfgs.addconstraint(d);

  // Perform maximization
  DoubleVec x0;
  DoubleVec xstar;
  NEWVEC(Double, x0, 3);
  NEWVEC(Double, xstar, 3);
  for (Uint pos = 0; pos < numpos; pos++) {
    copyval(x0, .1, 3); // restart from (0.1, 0.1, 0.1) at each position
    bfgs.minimize(x0, xstar, fhat[pos], pos, fun);
    fhat[pos] = -fhat[pos]; // fun is the negated log10-likelihood
    xhat[pos] = xstar[0];
    yhat[pos] = xstar[1];
    zhat[pos] = xstar[2];

    for (Uint f = 0; f < nfam(); f++) flod[f][pos] = fl[f];
  }
  DELETEVEC(x0);
  DELETEVEC(xstar);
  DELETEMAT(c);
  DELETEVEC(d);
}

// Maximizes a two-parameter log10-likelihood with BFGS at each of the first
// numpos positions. `fun` returns the NEGATED log10-likelihood.
//
// The search region has corners (zeroeps, zeroeps), (zeroeps, ymax - corneps)
// and (xmax - corneps, zeroeps), with one edge constraint per side.
//
// On return:
//   - xhat/yhat[pos] hold the optimum;
//   - fhat[pos] holds the maximized log10-likelihood;
//   - flod[f][pos] holds the per-family values left in the global fl.
void QTLmodel::bivariatemaximization(DoubleVec xhat, Double xmax,
                                     DoubleVec yhat, Double ymax,
                                     DoubleVec fhat,
                                     BFGS::Trianglefunction fun, Uint numpos) {
  assertinternal(xmax > 0 && ymax > 0);
  BFGS bfgs(100, 2, 3, 1e-7, 1e-5, 0.1, false);

  const Double corneps = 1e-4;
  const Double zeroeps = 5e-6;
  DoubleMat corner;
  DoubleMat edge;
  NEWMAT(Double, corner, 3, 2);
  NEWVEC(DoubleVec, edge, 2);

  corner[0][0] = zeroeps;         corner[0][1] = zeroeps;        // origin
  corner[1][0] = zeroeps;         corner[1][1] = ymax - corneps; // y axis
  corner[2][0] = xmax - corneps;  corner[2][1] = zeroeps;        // x axis

  // Edge k runs from corner (k+1) mod 3 to corner k. This gives
  // (1,0), (2,1), (0,2).
  for (Uint k = 0; k < 3; k++) {
    edge[0] = corner[(k + 1) % 3];
    edge[1] = corner[k];
    bfgs.addconstraint(edge);
  }

  DELETEMAT(corner);
  DELETEVEC(edge);

  DoubleVec start;
  DoubleVec best;
  NEWVEC(Double, start, 2);
  NEWVEC(Double, best, 2);
  for (Uint pos = 0; pos < numpos; pos++) {
    copyval(start, .1, 2); // restart from (0.1, 0.1) at each position
    bfgs.minimize(start, best, fhat[pos], pos, fun);
    fhat[pos] = -fhat[pos]; // fun is the negated log10-likelihood
    xhat[pos] = best[0];
    yhat[pos] = best[1];

    for (Uint f = 0; f < nfam(); f++) flod[f][pos] = fl[f];
  }
  DELETEVEC(start);
  DELETEVEC(best);
}

// Maximizes a one-parameter log10-likelihood over [0, (1 - 1e-9)*xmax] at
// each of the first numpos positions, using Maximizor.
// fmin is passed straight to Maximizor::maximize; the callers in this file
// supply llnull or -1e300.
// On return:
//   - xhat[pos] and fhat[pos] receive the maximizer's results;
//   - flod[f][pos] holds the per-family values left in the global fl.
void QTLmodel::univariatemaximization(DoubleVec xhat, Double xmax,
                                      DoubleVec fhat, Double fmin,
                                      Maximizor::likelihoodfunction fun,
                                      Uint numpos) {
  // Keep the search strictly below xmax.
  Double lo = 0;
  Double hi = (1 - 1.0e-9)*xmax;
  Maximizor maximizor(lo, hi, 50, 4, false);

  Uint pos = 0;
  while (pos < numpos) {
    maximizor.maximize(fmin, 0, fun, pos, xhat[pos], fhat[pos]);
    for (Uint f = 0; f < nfam(); f++) {
      flod[f][pos] = fl[f];
    }
    pos++;
  }
}

// Allocates the result buffers and the per-family scratch buffers.
//   - Result buffers: lod (per position), flod (families x positions), and
//     one estimate vector for each variance component being maximized.
//   - Scratch buffers: p and L, sized for the largest numiv over all
//     families.
// The conditional allocations are braced so they stay guarded even if
// NEWVEC expands to more than one statement.
void QTLmodel::initialize() {
  NEWVEC(Double, lod, npos());
  NEWMAT(Double, flod, nfam(), npos());
  if (maximizeA) {
    NEWVEC(Double, Ahat, npos());
  }
  if (maximizeD) {
    NEWVEC(Double, Dhat, npos());
  }
  if (maximizeG) {
    NEWVEC(Double, Ghat, npos());
  }
  if (maximizeS) {
    NEWVEC(Double, Shat, npos());
  }

  Uint maxnumiv = 0;
  for (Uint ifam = 0; ifam < nfam(); ifam++)
    maxnumiv = max_(maxnumiv, distribution->families[ifam]->numiv);
  NEWVEC(Double, p, maxnumiv);
  NEWVEC(Double, L, maxnumiv);
}

void QTLmodel::output() {
  initialize();
  run();
  Output::output();
}

// Writes the total-line columns for position pos, in the order of
// totheader():
//   1. the LOD score;
//   2. the estimates of the shared, additive and dominance variances (each
//      only when it is being maximized);
//   3. the null-hypothesis shared variance;
//   4. a three-space separator.
void QTLmodel::totline(ostream &f, Uint pos) {
  fmtout(f, 9, 4, lod[pos]);
  if (maximizeG) {
    fmtout(f, 10, 4, Ghat[pos]);
  }
  if (maximizeA) {
    fmtout(f, 10, 4, Ahat[pos]);
  }
  if (maximizeD) {
    fmtout(f, 10, 4, Dhat[pos]);
  }
  if (maximizeG) {
    fmtout(f, 10, 4, Gnullhat);
  }
  f << "   ";
}

// Writes family ifa's LOD contribution at position pos, followed by a
// three-space separator.
void QTLmodel::famline(ostream &f, Uint ifa, Uint pos) {
  fmtout(f, 9, 4, this->flod[ifa][pos]);
  f << "   ";
}

// Writes the column headings for totline(), in the same order:
//   - LOD;
//   - the estimated variance components (shared, additive, dominance), each
//     only when it is being maximized;
//   - the null shared variance.
void QTLmodel::totheader(ostream &f) {
  f << "    LOD    ";
  if (maximizeG) {
    f << " shared   ";
  }
  if (maximizeA) {
    f << "additive  ";
  }
  if (maximizeD) {
    f << "dominance ";
  }
  if (maximizeG) {
    f << " shrnull  ";
  }
}

// Writes the column heading for famline(): a single LOD column.
void QTLmodel::famheader(ostream &f) {
  static const char heading[] = "    LOD    ";
  f << heading;
}
