#include "files.h"
#include "asmmodel.h"
#include "asmexp.h"
#include "options.h"
#include "maximize.h"
#include "scoredistfull.h"
#include "utils.h"
#include "convolve.h"
#include <iostream>
#include <iomanip>
#include <math.h>
#include "modelweight.h"

// Tolerance below which null standard deviations and several computed
// statistics in this file (likelihood ratios, information denominators,
// expected NPL) are treated as exactly zero.
#define TOL (1.0e-8)

/////////////////////////////////

void ASMexp::compute_wnpl() {
  for (Uint pos = 0; pos < distribution->nnulldist(); pos++)
    for (Uint fam = 0; fam < nfam(); fam++)
      for (Uint j = 0; j < nscores[fam]; j++)
        if (nullsd[pos][fam] < TOL) zscore[pos][fam][j] = 0;
        else
          zscore[pos][fam][j] = (value[fam][j] -
                                 nullmean[pos][fam])/nullsd[pos][fam];
  
  for (Uint fam = 0; fam < nfam(); fam++)
    for (Uint pos = 0; pos < npos(); pos++) {
      wnpl[fam][pos] = 0.0;
      for (Uint j = 0; j < nscores[fam]; j++)
        wnpl[fam][pos] += zscore[pos][fam][j]*prob[fam][pos][j];
      wnpl[fam][pos] = weight[pos][fam]*wnpl[fam][pos];
    }
}

void ASMexp::compute_fam_info() const {
  for (Uint fam = 0; fam < nfam(); fam++) {
    for (Uint pos = 0; pos < npos(); pos++) {
      Double mean = 0.0, EZ2 = 0.0;
      for (Uint j = 0; j < nscores[fam]; j++) {
        mean += zscore[pos][fam][j]*prob[fam][pos][j];
        EZ2 += zscore[pos][fam][j]*zscore[pos][fam][j]* prob[fam][pos][j];
      }
      Double var = EZ2 - mean*mean;
      faminfo[fam][pos] = max_<Double, Double>(-99.0, 1 - var);
    }
  }
}

/////////////////////////////////

// Natural logarithm of 10. Dividing a natural-log likelihood ratio by this
// converts it to a base-10 lod score (as done in allegroexponential() and
// f_delta()).
const Double ALLEGROLOG10 = log(10.0);

// Constructs the "exp" model for the given pt string. It then obtains the
// matching score distribution from Scoredist::getscoredist(pt, sc).
ASMexp::ASMexp(Calcscore *sc, const string &pt)
  : ASMmodel(pt, "exp") {
  distribution = Scoredist::getscoredist(pt, sc);
}

// Model instance read by the free callbacks f_exp() and f_delta(), which
// receive no object argument. It is set to `this` at the end of
// ASMexp::initialize().
// NOTE(review): being a single global, only the most recently initialized
// ASMexp is visible to those callbacks.
ASMexp const *expcaller;

// Prepares the exponential model for a run. It:
//  - allocates the per-run tables (nscores, nullprob, prob, value and, when
//    requested, the exact-p-value tables);
//  - calls getresults() on the Scoredistfull to fill nscores/value/prob/
//    nullprob;
//  - computes each family's null mean and standard deviation;
//  - allocates the z-score tables;
//  - registers this object as expcaller for the free callbacks f_exp() and
//    f_delta().
// When Distribution::isnullconstant() holds, every nullprob[pos] and
// zscore[pos] aliases the position-0 entry.
void ASMexp::initialize() {
  ASMmodel::initialize();
  dolodexactp = options->lodexactp;
  donplexactp = options->nplexactp;
  NEWVEC(Uint, nscores, nfam());
  nullprob = new DoubleMat[npos()];
  if (Distribution::isnullconstant()) {
    NEWVEC(DoubleVec, nullprob[0], nfam());
    for (Uint pos = 0; pos < npos(); pos++)
      nullprob[pos] = nullprob[0];
  }
  else 
    for (Uint pos = 0; pos < npos(); pos++) {
      NEWVEC(DoubleVec, nullprob[pos], nfam());
    }
  
  prob = new DoubleMat[nfam()];
  NEWVEC(DoubleVec, value, nfam());
  // Tables for the exact lod p-values, filled by calclodps().
  if (dolodexactp) {
    NEWVEC(Double, lodp, ndelta);
    NEWVEC(Double, lodpi, ndelta);
    NEWVEC(Double, nplpi, ndelta);
    NEWVEC(Double, lodexactps, npos());
  }
  if (donplexactp) NEWVEC(Double, nplexactps, npos());
  ((Scoredistfull *)distribution)->getresults(nscores, value, prob, nullprob);

  // Calculate mean and standard deviation of each family's score under the null
  for (Uint pos = 0; pos < distribution->nnulldist(); pos++) {
    for (Uint fam = 0; fam < nfam(); fam++) {
      Double mean = 0.0;
      Double S2 = 0.0;
      for (Uint j = 0; j < nscores[fam]; j++) {
        mean += nullprob[pos][fam][j]*value[fam][j];
        S2 += nullprob[pos][fam][j]*value[fam][j]*value[fam][j];
      }
      nullmean[pos][fam] = mean;
      // E[X^2] - E[X]^2 can come out slightly negative through rounding when
      // the null distribution is (nearly) degenerate. sqrt() of a negative
      // number is NaN, and NaN fails every `< TOL` test (here and in
      // compute_wnpl()), so it would propagate into all z-scores. Clamp the
      // variance at zero instead.
      const Double var = S2 - mean*mean;
      nullsd[pos][fam] = var > 0.0 ? sqrt(var) : 0.0;
      if (nullsd[pos][fam] < TOL) nullsd[pos][fam] = 0;
    }
  }

  zscore = new DoubleMat[npos()];
  if (Distribution::isnullconstant()) {
    NEWVEC(DoubleVec, zscore[0], nfam());
    for (Uint fam = 0; fam < nfam(); fam++)
      NEWVEC(Double, zscore[0][fam], nscores[fam]);
    for (Uint pos = 1; pos < npos(); pos++)
      zscore[pos] = zscore[0];
  }
  else
    for (Uint pos = 0; pos < npos(); pos++) {
      NEWVEC(DoubleVec, zscore[pos], nfam());
      for (Uint fam = 0; fam < nfam(); fam++)
        NEWVEC(Double, zscore[pos][fam], nscores[fam]);
    }
  
  expcaller = this;
}

void ASMexp::cleanup() {
  ASMmodel::cleanup();
  DELETEVEC(nscores);
  delete [] prob;
  if (dolodexactp) {
    DELETEVEC(lodp);
    DELETEVEC(lodpi);
    DELETEVEC(nplpi);
    DELETEVEC(lodexactps);
  }
  if (donplexactp) DELETEVEC(nplexactps);
  if (Distribution::isnullconstant()) {
    for (Uint fam = 0; fam < nfam(); fam++)
      DELETEVEC(zscore[0][fam]);
    DELETEVEC(zscore[0]);
    DELETEVEC(nullprob[0]);
  }
  else {
    for (Uint pos = 0; pos < npos(); pos++) {
      for (Uint fam = 0; fam < nfam(); fam++) {
        DELETEVEC(zscore[pos][fam]);
//        DELETEVEC(nullprob[pos][fam]);
      }
      DELETEVEC(zscore[pos]);
      DELETEVEC(nullprob[pos]);
    }
  }
  delete [] nullprob;
  delete [] zscore;
}

// Objective maximised by allegroexponential(). It returns the natural-log
// likelihood ratio of the exponential model with parameter x at position pos,
// summed over families:
//   log(sum_j prob*e^{x*w*z}) - log(sum_j nullprob*e^{x*w*z}),
// with w = weight[pos][fam] and z = zscore[pos][fam][j].
// Special return values:
//   0.0   if any exponent falls outside [-708, 709] (exp() would leave the
//         double range);
//   -1000 if a family sum is not strictly positive (log undefined);
//   0     when |result| < TOL (the result is snapped to zero).
Double f_exp(Double x, Uint pos) {
  const Double LOG_MAXDOUBLE = 709;
  const Double LOG_MINDOUBLE = -708;

  Double y = 0.0;
  Double constant = 0.0;
  for (Uint fam = 0; fam < expcaller->nfam(); fam++) {
    Double fam_constant = 0.0;
    Double fam_prob = 0.0;
    for (Uint j = 0; j < expcaller->nscores[fam]; j++) {
      const Double exponent =
        x*expcaller->weight[pos][fam]*expcaller->zscore[pos][fam][j];
      if (exponent > LOG_MAXDOUBLE || exponent < LOG_MINDOUBLE) return 0.0;
      const Double dtmp = exp(exponent);
      fam_constant += expcaller->nullprob[pos][fam][j]*dtmp;
      fam_prob += expcaller->prob[fam][pos][j]*dtmp;
    }
    // log() needs a strictly positive argument. A zero null sum would
    // otherwise give +inf (which the maximiser would favour), and a NaN sum
    // would give NaN, so reject both along with negative sums.
    if (!(fam_prob > 0) || !(fam_constant > 0)) return -1000;
    y += log(fam_prob);
    constant += log(fam_constant);
  }
  y -= constant;
  if (fabs(y) < TOL) y = 0.0;
  return y;
}

#define MAXDHAT 99.0

void ASMexp::allegroexponential() {
  Maximizor mz(-5, 5, 100, 11, true);
  for (int i = 0; i < npos(); i++) {
    if (pt == "spt" ||
        map->inbetween[i] || map->shouldfindp[map->leftmarker[i]]) {
      mz.maximize(0, 0, f_exp, i, dhat[i], lod[i]);
      if (fabs(dhat[i]) > MAXDHAT) {
        dhat[i] = dhat[i] > 0 ? MAXDHAT : -MAXDHAT;
        lod[i] = f_exp(dhat[i], i);
      }
      lod[i] /= ALLEGROLOG10;
    }
  }

  // compute_fam_lod
  for (Uint pos = 0; pos < npos(); pos++) {
    for (Uint fam = 0; fam < nfam(); fam++) {
      flod[fam][pos] = 0.0;
      Double fam_constant = 0.0;
      for (Uint j = 0; j < nscores[fam]; j++) {
        Double dtmp = exp(dhat[pos]*weight[pos][fam]*zscore[pos][fam][j]);
        flod[fam][pos] += (prob[fam][pos][j]*dtmp);
        fam_constant += (nullprob[pos][fam][j]*dtmp);
      }
      flod[fam][pos] = log(flod[fam][pos]/fam_constant)/log(10.0);
    }
  }
}

// Information measure "info 3" at position pos. It computes
//   r = sum over families of (dhat * w * E~[z] - log n0),
// where n0 = sum_j nullprob * e^{dhat*w*z} and E~ is the expectation under
// those tilted null weights. The result is the natural-log lod divided by r,
// or 0 when |r| < TOL.
Double ASMexp::RIp(Uint pos) const {
  /* info 3 */
  Double total = 0.0;
  for (Uint fam = 0; fam < nfam(); fam++) {
    Double s0 = 0.0;
    Double s1 = 0.0;
    for (Uint j = 0; j < nscores[fam]; j++) {
      const Double term =
        exp(dhat[pos]*weight[pos][fam]*zscore[pos][fam][j])*
        nullprob[pos][fam][j];
      s0 += term;
      s1 += term*zscore[pos][fam][j];
    }
    s1 *= weight[pos][fam];
    total += (dhat[pos]*s1)/s0 - log(s0);
  }
  if (fabs(total) < TOL) return 0.0;
  return (lod[pos]*ALLEGROLOG10)/total;
}

// Information measure "info 1" (Rubin's measure) at position pos. The result
// is (a - b)/a, where:
//   a sums over families the variance of w*z under nullprob tilted by
//     e^{dhat*w*z};
//   b sums the same variance under prob tilted by e^{dhat*w*z}.
// Returns 0 when a < TOL (no informative family), mirroring the denominator
// guard in RIp(); previously this case divided by zero.
Double ASMexp::RIf(Uint pos) const {
  /* info 1 (Rubin's measure) */
  Double a = 0.0, b = 0.0;
  for (Uint fam = 0; fam < nfam(); fam++) {
    Double n0, n1, n2;
    Double c0, c1, c2;
    n0 = n1 = n2 = 0.0;
    c0 = c1 = c2 = 0.0;
    for (Uint j = 0; j < nscores[fam]; j++) {
      Double t1 = exp(dhat[pos]*weight[pos][fam]*zscore[pos][fam][j]);
      Double t2 = t1*prob[fam][pos][j];
      t1 *= nullprob[pos][fam][j];
      n0 += t1;
      c0 += t2;
      t1 *= zscore[pos][fam][j];
      t2 *= zscore[pos][fam][j];
      n1 += t1;
      c1 += t2;
      t1 *= zscore[pos][fam][j];
      t2 *= zscore[pos][fam][j];
      n2 += t1;
      c2 += t2;
    }
    n1 *= weight[pos][fam];
    c1 *= weight[pos][fam];
    Double t = weight[pos][fam]*weight[pos][fam];
    n2 *= t;
    c2 *= t;
    // (k2*k0 - k1^2)/k0^2 is the variance of w*z under the tilted weights.
    a += (n2*n0 - n1*n1)/(n0*n0);
    b += (c2*c0 - c1*c1)/(c0*c0);
  }

  if (fabs(a) < TOL) return 0.0;
  return (a - b)/a;
}

// Fills information[2][pos] for every position that allegroexponential()
// fitted (pt == "spt", an in-between position, or one whose left marker
// needs a p-value).
// Rubin's measure RIf() is used when |npl| is below INFOTHRESH; otherwise
// RIp() is used.
void ASMexp::allegroexpinfo() const {
  const Double INFOTHRESH = .01;
  for (Uint pos = 0; pos < npos(); pos++) {
    const bool wanted = pt == "spt" || map->inbetween[pos] ||
      map->shouldfindp[map->leftmarker[pos]];
    if (wanted) {
      if (fabs(npl[pos]) < INFOTHRESH)
        information[2][pos] = RIf(pos);
      else
        information[2][pos] = RIp(pos);
    }
  }
}

// One analysis pass, in order:
//  1. z-scores and weighted NPL;
//  2. exponential-model fit and family lods;
//  3. final statistics;
//  4. information measures;
//  5. exact p-values, if either kind was requested.
void ASMexp::run() {
  compute_wnpl();
  allegroexponential();
  compute_final_stats();
  allegroexpinfo();
  const bool wantexact = dolodexactp || donplexactp;
  if (wantexact) exactp();
}

// Evaluates the position-0 null score distribution exponentially tilted by
// delta (z-scores already weight-scaled by exactp()).
//   Returns / npl: y = sum over families of E_delta[z], snapped to 0 when
//                  |y| < TOL.
//   lod:           (delta*y - sum log E_0[e^{delta*z}]) / log(10), negated
//                  when delta < 0.
Double f_delta(Double delta, Double &lod, Double &npl) {
  Double ysum = 0.0;
  Double logsum = 0.0;
  for (Uint fam = 0; fam < expcaller->nfam(); fam++) {
    Double denom = 0.0;
    Double numer = 0.0;
    for (Uint j = 0; j < expcaller->nscores[fam]; j++) {
      const Double z = expcaller->zscore[0][fam][j];
      const Double term = expcaller->nullprob[0][fam][j]*exp(delta*z);
      denom += term;
      numer += z*term;
    }
    ysum += numer/denom;
    logsum += log(denom);
  }
  if (fabs(ysum) < TOL) ysum = 0.0;
  npl = ysum;
  const Double signedlod = (delta*ysum - logsum)/ALLEGROLOG10;
  lod = delta < 0 ? -signedlod : signedlod;
  return ysum;
}

void ASMexp::calclodps(const Convolver &cv,
                       Double leftbound, Double rightbound) {
  Double delta = leftbound;
  Uint ndelta = 10001;
  for (Uint i = 0; i < ndelta; i++) {
    lodp[i] = cv.calcpvalue(f_delta(delta, lodpi[i], nplpi[i]));
    delta += (rightbound - leftbound)/Double(ndelta - 1);
  }
}

// Converts a signed lod score x into an exact p-value. It linearly
// interpolates in the table (lodpi[i], lodp[i]), i < ndelta, built by
// calclodps().
// Values outside the table are clamped:
//   x <= lodpi[0] (or NaN) -> lodp[0];
//   x >  lodpi[ndelta-1]   -> lodp[ndelta-1].
// NOTE(review): assumes lodpi is increasing, as the linear search below
// requires.
Double ASMexp::calclodpvalue(Double x) const {
  Uint hi = 0;
  while (hi < ndelta && lodpi[hi] < x) hi++;
  // hi is the first index with lodpi[hi] >= x, or ndelta if there is none.
  // The previous code decremented an unsigned index here, which wrapped to
  // UINT_MAX (and read out of bounds) when hi was 0.
  if (hi == 0) return lodp[0];
  if (hi == ndelta) return lodp[ndelta - 1];
  const Uint lo = hi - 1;
  return lodp[lo] + (x - lodpi[lo])/(lodpi[hi] - lodpi[lo])*
    (lodp[hi] - lodp[lo]);
}


void ASMexp::exactp() {
  Convolver cv(50000, nfam());
  for (Uint fam = 0; fam < nfam(); fam++)
    for (int j = 0; j < nscores[fam]; j++)
      zscore[0][fam][j] *= weight[0][fam];
  cv.convolve(zscore[0], nullprob[0], nscores);
  if (options->nplpfile) {
    string fn = "p" + outfile.name;
    cv.write(fn);
  }
  if (dolodexactp) {
    Double minlod = 0;
    for (Uint pos = 0; pos < npos(); pos++)
      if (dhat[pos] < 0 && minlod > -lod[pos]) minlod = -lod[pos];
    Double maxlod = vecmax<Double>(lod, npos());
    Double leftbound = -5, rightbound = 5;
    Double mindistlod = minlod + 1;
    Double maxdistlod = maxlod - 1;
    while (mindistlod > minlod || maxdistlod < maxlod) {
      calclodps(cv, leftbound, rightbound);
      mindistlod = lodpi[0];
      maxdistlod = lodpi[ndelta - 1];
      if (mindistlod > minlod) leftbound -=5;
      if (maxdistlod < maxlod) rightbound += 5;
    }
//      ofstream f("lodp.dat");
//      for (int i = 0; i < ndelta; i++)
//        f << lodpi[i] << "\t" << lodp[i] << "\t" << nplpi[i] << "\n";
    for (Uint pos = 0; pos < npos(); pos++)
      lodexactps[pos] = calclodpvalue((dhat[pos] < 0 ? -1 : 1)*lod[pos]);
  }
  if (donplexactp)
    for (Uint pos = 0; pos < npos(); pos++)
      nplexactps[pos] = cv.calcpvalue(npl[pos]*sqrt(double(nfam())));
}

void ASMexp::printfaminfo() const {
  Outfile faminfo;
  string on = outfile.name;
  string::size_type lastslash = on.find_last_of("/");
  string fname;
  if (lastslash == string::npos) fname = "i" + on;
  else fname = (on.substr(0, lastslash + 1) + "i" + on.substr(lastslash + 1));
  faminfo.setname(fname, fname);
  faminfo.open();
  for (unsigned int f = 0; f < nfam(); f++) {
    faminfo << getfamid(f) << " "
            << modelweight->assign(getfamid(f), nullsd[0][f]) << endl;
  }
  faminfo.close();
}
