#include "files.h"
#include "asmmodel.h"
#include "asmpoly.h"
#include "files.h"
#include <iostream>
#include <iomanip>
#include <math.h>
#include "maximize.h"
#include "scoredistmoments.h"
#include "utils.h"
#include "options.h"
#include "fmtout.h"
#include "modelweight.h"

// Absolute tolerance: RIp() treats a denominator whose magnitude is below
// this value as zero and returns 0 instead of dividing by it.
#define TOL (1.0e-8)

////////////////////

void ASMpoly::compute_fam_npl() {
  for(Uint fam = 0; fam < nfam(); fam++)
    for(Uint pos = 0; pos < npos(); pos++)
      wnpl[fam][pos] = weight[pos][fam]*zvalue[fam][pos][0];
}

//////////////////

// Natural logarithm of 10.  Used to convert between natural-log likelihood
// ratios and base-10 LOD scores (allegrolinear() divides by it, RIp()
// multiplies by it).
const Double ALLEGROLOG10 = log(10.0);

// Construct a polynomial-approximation ASM model.
//   sc  - score calculator, forwarded to Scoredist::getscoredist()
//   pt  - type string forwarded to ASMmodel and getscoredist()
//         (compared against "spt" elsewhere in this class)
//   mo  - model string forwarded to ASMmodel
//   deg - number of polynomial terms / columns kept per moment table
// Both exact-p flags start out disabled.
ASMpoly::ASMpoly(Calcscore *sc, const string &pt, const string &mo, Uint deg) :
    ASMmodel(pt, mo), degree(deg) {
  donplexactp = false;
  dolodexactp = false;
  distribution = Scoredist::getscoredist(pt, sc, degree);
}

void ASMpoly::initialize() {
  ASMmodel::initialize();
  zvalue = new DoubleMat[nfam()];
  znullvalue = new DoubleMat[npos()];
  
  if (Distribution::isnullconstant()) {
    NEWMAT(Double, znullvalue[0], nfam(), degree);
    for (Uint pos = 1; pos < npos(); pos++)
      znullvalue[pos] = znullvalue[0];
  }
  else
    for (Uint pos = 0; pos < npos(); pos++)
      NEWMAT(Double, znullvalue[pos], nfam(), degree);

  for (Uint fam = 0; fam < nfam(); fam++)
    NEWMAT(Double, zvalue[fam], npos(), degree);

  NEWVEC(Double, minvalue, nfam());
  NEWVEC(Double, maxvalue, nfam());
  
  ((Scoredistmoments *)distribution)->getresults(zvalue, znullvalue,
                                                 nullmean, nullsd, degree,
                                                 minvalue, maxvalue);
}
  
void ASMpoly::cleanup() {
  ASMmodel::cleanup();
  if (distribution->isnullconstant())
    DELETEMAT(znullvalue[0]);
  else
    for (Uint pos = 0; pos < npos(); pos++)
      DELETEMAT(znullvalue[pos]);
  
  for (Uint fam = 0; fam < nfam(); fam++)
    DELETEMAT(zvalue[fam]);
  delete [] zvalue;
  delete [] znullvalue;

  DELETEVEC(minvalue);
  DELETEVEC(maxvalue);
}

// Model whose data f_lin() reads.  f_lin() accesses the model through this
// global rather than `this` (presumably because it is handed to Maximizor as
// a plain function); allegrolinear() sets it before maximizing.
ASMpoly const *polycaller;

// Natural-log likelihood ratio summed over families at effect size x for
// position pos.  For each family, with t = x*weight[pos][fam], both the
// alternative (zvalue) and the null (znullvalue) are evaluated as
//   1 + sum_{k=1..degree} t^k / k! * Z[k-1]
// and the family contributes log(alternative) - log(null).
Double ASMpoly::f_lin(Double x, Uint pos) {
  const ASMpoly &caller = *polycaller;
  Double total = 0.0;
  for (Uint f = 0; f < caller.nfam(); f++) {
    const Double t = x*caller.weight[pos][f];
    Double lalt = 1.0;
    Double lnull = 1.0;
    Double term = t;             // t^(k+1) / (k+1)!
    for (Uint k = 0; k < caller.degree; k++) {
      lalt += term*caller.zvalue[f][pos][k];
      lnull += term*caller.znullvalue[pos][f][k];
      term *= t/Double(k + 2);
    }
    total += log(lalt) - log(lnull);
  }
  return total;
}

// Presumably an upper cap on dhat; not referenced in this part of the file.
#define MAXDHAT 99.0

// Relative information ("info 3") at position pos.  With t = dhat[pos]*weight
// and p0 the null polynomial 1 + sum_{k>=1} t^k/k! * znull[k-1], this sums
//   r = sum_fam ( t * p0'(t) / p0(t) - log p0(t) )
// and returns the LOD converted back to natural log divided by r, or 0 when
// |r| < TOL.
Double ASMpoly::RIp(Uint pos) const {
  /* info 3 */
  Double expected = 0.0;
  for (Uint f = 0; f < nfam(); f++) {
    const Double t = dhat[pos]*weight[pos][f];
    Double value = 1.0;          // p0(t)
    Double slope = 0.0;          // p0'(t)
    Double tk = 1.0;             // t^k / k!
    for (Uint k = 0; k < degree; k++) {
      const Double zk = znullvalue[pos][f][k];
      slope += tk*zk;
      tk *= t/Double(1 + k);
      value += tk*zk;
    }
    expected += t*slope/value - log(value);
  }
  // Keep the "< TOL" form: a NaN r must still fall through to the division.
  if (fabs(expected) < TOL)
    return 0.0;
  return (lod[pos]*ALLEGROLOG10)/expected;
}

////
// Truncated series 1 + sum_{k=1..degree} x^k / k! * Z[k-1].
Double p(Double x, DoubleVec Z, Uint degree) {
  Double sum = 1.0;
  Double term = x;               // x^(k+1) / (k+1)!
  Uint k = 0;
  while (k < degree) {
    sum += term*Z[k];
    ++k;
    term *= x/Double(k + 1);
  }
  return sum;
}

// Derivative of p() with respect to x:
//   sum_{k=0..degree-1} x^k / k! * Z[k].
Double p_m(Double x, DoubleVec Z, Uint degree) {
  Double sum = 0.0;
  Double term = 1.0;             // x^k / k!
  for (Uint k = 0; k != degree; ++k) {
    sum += Z[k]*term;
    term *= x/Double(k + 1);
  }
  return sum;
}

// Second derivative of p() with respect to x: p_m() applied to the
// coefficients shifted by one.
// deg == 0 (p is the constant 1) has zero second derivative and is handled
// explicitly; otherwise deg - 1 would wrap around for the unsigned Uint and
// p_m() would read far past the end of Z.
Double p_mm(Double x, DoubleVec Z, Uint deg) {
  if (deg == 0)
    return 0.0;
  return p_m(x, Z + 1, deg - 1);
}
  
// Curvature-based relative information at position pos.
// For each family, with t = dhat[pos]*weight[pos][fam], it accumulates the
// second derivative of log p(t), (p''p - p'^2) / p^2, once with the null
// coefficients (l0) and once with the alternative coefficients (l1).
// It returns 1 - l1/l0.
// When the null sum is numerically zero (e.g. nfam() == 0), the ratio is
// undefined, so it returns 0, matching the TOL guard in RIp().
Double ASMpoly::RIf(Uint pos) const {
  Double l0 = 0.0;
  Double l1 = 0.0;
  for (Uint fam = 0; fam < nfam(); fam++) {
    // Weighted effect size, shared by the six evaluations below.
    const Double t = dhat[pos]*weight[pos][fam];
    const Double p0 = p(t, znullvalue[pos][fam], degree);
    const Double p0_m = p_m(t, znullvalue[pos][fam], degree);
    const Double p0_mm = p_mm(t, znullvalue[pos][fam], degree);
    const Double p1 = p(t, zvalue[fam][pos], degree);
    const Double p1_m = p_m(t, zvalue[fam][pos], degree);
    const Double p1_mm = p_mm(t, zvalue[fam][pos], degree);
    l0 += (p0_mm*p0 - p0_m*p0_m)/(p0*p0);
    l1 += (p1_mm*p1 - p1_m*p1_m)/(p1*p1);
  }
  if (fabs(l0) < TOL)
    return 0.0;
  return 1 - l1/l0;
}

// Store the relative information in information[2][pos] for every analysed
// position (same position filter as allegrolinear()).
// When |npl[pos]| is below INFOTHRESH and more than two terms are available,
// the curvature-based RIf() is used; otherwise RIp() is used.
void ASMpoly::allegroinfo() const {
  const Double INFOTHRESH = .01;
  for (Uint pos = 0; pos < npos(); pos++) {
    const bool analysed = pt == "spt" || map->inbetween[pos] ||
      map->shouldfindp[map->leftmarker[pos]];
    if (analysed) {
      const bool usecurvature = degree > 2 && fabs(npl[pos]) < INFOTHRESH;
      information[2][pos] = usecurvature ? RIf(pos) : RIp(pos);
    }
  }
}

// Admissible effect-size interval [dmin, dmax] for the degree-1 model.
// It starts at [-99.9, 99.9] and is narrowed family by family, using the
// null mean/sd and weights at position 0.  A family's linear likelihood
// 1 + d*w*z must stay positive for every standardized score z in
// [zmin, zmax].  With w > 0 this requires:
//   d > -1/(w*zmax) when zmax > 0
//   d < -1/(w*zmin) when zmin < 0
// NOTE(review): assumes weights are non-negative — confirm.
//
// Skipped families and sides:
//   - A null sd of 0 means the family's z is identically 0 (printfaminfo()
//     makes the same assumption), so it imposes no bound.  Computing z would
//     give 0/0 = NaN and poison dmin/dmax.
//   - A zero weight, or a side whose z does not have the relevant sign,
//     imposes no bound either.  Evaluating it would give +/-inf, and -inf
//     would collapse dmax.
void ASMpoly::linearbounds(Double &dmin, Double &dmax) const {
  dmin = -99.9;
  dmax = 99.9;
  for (Uint fam = 0; fam < nfam(); fam++) {
    const Double sd = nullsd[0][fam];
    const Double w = weight[0][fam];
    if (sd == 0 || w == 0)
      continue;
    const Double zmin = (minvalue[fam] - nullmean[0][fam])/sd;
    const Double zmax = (maxvalue[fam] - nullmean[0][fam])/sd;
    if (zmax > 0)
      dmin = max_(dmin, -1.0/zmax/w);
    if (zmin < 0)
      dmax = min_(dmax, -1.0/zmin/w);
  }
}

void ASMpoly::allegrolinear() {
  Double dmin, dmax;
  if (degree == 1) {
    linearbounds(dmin, dmax);
    if (options->deltaboundfile.assigned()) {
      options->deltaboundfile << distribution->describe() << "\t";
      fmtout(options->deltaboundfile, 14, 12, dmin);
      fmtout(options->deltaboundfile, 14, 12, dmax);
      options->deltaboundfile << "\n";
    }
  }
  else {
    dmin = -5;
    dmax = 5;
  }
  Maximizor mz(dmin, dmax, 100, 11, degree != 1);
  polycaller = this;
  for(Uint pos = 0; pos < npos(); pos++) {
    if (pt == "spt" ||
        map->inbetween[pos] || map->shouldfindp[map->leftmarker[pos]]) {
      mz.maximize(0, 0, f_lin, pos, dhat[pos], lod[pos]);
      lod[pos] /= ALLEGROLOG10;
    }
  }
  // compute_fam_lod
  for(Uint pos = 0; pos < npos(); pos++)
    for(Uint fam = 0; fam < nfam(); fam++) {
      Double L = 1.0;
      Double L0 = 1.0;
      Double xx = dhat[pos]*weight[pos][fam];
      for (Uint deg = 0; deg < degree; deg++) {
        L += xx*zvalue[fam][pos][deg];
        L0 += xx*znullvalue[pos][fam][deg];
        xx *= dhat[pos]*weight[pos][fam]/Double(deg + 2);
      }
      flod[fam][pos] = log(L/L0)/log(10.0);
#ifdef DECODE
      assertinternal(finite(flod[fam][pos]));
#endif
    }
}

void ASMpoly::compute_fam_info() const {
  for (Uint fam = 0; fam < nfam(); fam++) {
    for (Uint pos = 0; pos < npos(); pos++) {
      Double mean = zvalue[fam][pos][0];
      Double var = zvalue[fam][pos][1] - mean*mean;
      faminfo[fam][pos] = max_<Double, Double>(-99.0, 1 - var);
    }
  }
}

// Run the full analysis for this model.  The order matters:
//   compute_fam_npl()     - per-family weighted values wnpl from zvalue
//   allegrolinear()       - fits dhat and lod per position, plus per-family
//                           flod
//   compute_final_stats() - defined elsewhere; presumably derives summary
//                           statistics such as npl (TODO confirm)
//   allegroinfo()         - reads dhat, lod and npl, so it must come last
void ASMpoly::run() {
  compute_fam_npl();
  allegrolinear();
  compute_final_stats();
  allegroinfo();
}

void ASMpoly::printfaminfo() const {
  Outfile faminfo;
  string on = outfile.name;
  string::size_type lastslash = on.find_last_of("/");
  string fname;
  if (lastslash == string::npos) fname = "i" + on;
  else fname = (on.substr(0, lastslash + 1) + "i" + on.substr(lastslash + 1));
  faminfo.setname(fname, fname);
  faminfo.open();
  faminfo << modelweight->describe() << "\n";
  faminfo << "Family\tweight\tzmin\tzmax" << "\n";
  for (unsigned int fam = 0; fam < nfam(); fam++) {
    Double zmin = (nullsd[0][fam] == 0 ? 0 :
                   (minvalue[fam] - nullmean[0][fam])/nullsd[0][fam]);
    Double zmax = (nullsd[0][fam] == 0 ? 0 :
                   (maxvalue[fam] - nullmean[0][fam])/nullsd[0][fam]);
    faminfo << getfamid(fam) << "\t"
            << modelweight->assign(getfamid(fam), nullsd[0][fam]) << "\t"
            << zmin << "\t" << zmax << "\n";
  }
  faminfo.close();
}
