#include "files.h"
#include "scoredistmoments.h"
#include "family.h"
#include "calcscore.h"

// Constructs a moment-based score distribution. Both arguments are handed
// straight to the Scoredist base class; no per-family state exists until
// nextfam() is called.
Scoredistmoments::Scoredistmoments(const string &p, Uint deg)
  : Scoredist(p, deg)
{
}

// Frees every Familydata record this object allocated in nextfam().
// Records are released newest-first, and each one is removed from the
// container right after it is deleted.
Scoredistmoments::~Scoredistmoments() {
  for (; !familydata.empty(); familydata.pop_back())
    delete familydata.back();
}

// Computes the null-distribution summary of the current family's score for
// null slot `pos`, and on pos == 0 first allocates a new Familydata record.
//
// The score vector calcscore->vec has numiv * POW2[nasymmetricfoundercouples()]
// entries: one block of numiv entries per flip of the asymmetric founder
// couples (the same layout set() reads).
//
//   p0  - per-inheritance-vector null weights, indexed v % numiv, so each
//         weight is reused once per block. May be 0. In that case every
//         entry gets weight 1. The uniform (isnullconstant()) case divides
//         by the entry count.
//
// Results stored in familydata.back():
//   nullmean[pos], nullsd[pos] - mean and standard deviation under the null.
//                                A variance below 1e-8 (including negative
//                                round-off) is treated as sd == 0.
//   nullvalue[pos][k]          - (k+1)-th moment of the standardized score.
//                                k = 0 and k = 1 are fixed at 0 and 1;
//                                k = 2 .. degree-1 are accumulated, or zeroed
//                                when sd == 0.
//   minvalue, maxvalue         - range of the raw scores, computed only on
//                                pos == 0. The score vector is read the same
//                                way for every pos.
//
// NOTE(review): the accumulators for nullmean/nullsd/nullvalue are assumed to
// start at zero in a freshly constructed Familydata; verify in Familydata.
void Scoredistmoments::nextfam(Uint pos, DoubleVec p0) {
  if (pos == 0) familydata.push_back(new Familydata(this));
  Familydata *famdat = familydata.back();

  const IV numiv = curfamily()->numiv;
  const Uint nasym = curfamily()->nasymmetricfoundercouples();
  const IV numscores = numiv*POW2[nasym];

  // Total weight accumulated by the loops below.
  // - Uniform null: one unit of weight per score entry.
  // - p0-weighted null: each p0 weight is reused once per founder-couple
  //   block, giving POW2[nasym] in total (assuming p0 sums to 1). Dividing by
  //   it averages over the blocks, as set() does. It is an exact ÷1 when
  //   there are no asymmetric couples.
  const Double norm = (isnullconstant() ? Double(numscores)
                                        : Double(POW2[nasym]));

  // The range is only computed on pos == 0. Resetting it on later positions
  // would overwrite it with the sentinels.
  if (pos == 0) {
    famdat->minvalue = 1e300;
    famdat->maxvalue = -1e300;
  }
  for (IV v = 0; v < numscores; v++) {
    const Double val = calcscore->vec[v];
    Double valp = val;

    if (p0 != 0)
      valp *= p0[v % numiv];
    famdat->nullmean[pos] += valp;
    famdat->nullsd[pos] += valp*val;

    if (pos == 0) {
      if (val < famdat->minvalue) famdat->minvalue = val;
      if (val > famdat->maxvalue) famdat->maxvalue = val;
    }
  }
  famdat->nullmean[pos] /= norm;
  famdat->nullsd[pos] /= norm;

  // Var = E[X^2] - E[X]^2. Tiny or negative values (round-off) are clamped
  // to a zero standard deviation.
  famdat->nullsd[pos] = (famdat->nullsd[pos] -
                         famdat->nullmean[pos]*famdat->nullmean[pos]);
  famdat->nullsd[pos] = (famdat->nullsd[pos] < 1e-8 ? 0 :
                         sqrt(famdat->nullsd[pos]));

  // A standardized variable has first moment 0 and second moment 1.
  famdat->nullvalue[pos][0] = 0;
  famdat->nullvalue[pos][1] = 1;
  if (famdat->nullsd[pos] == 0) {
    // Degenerate null: the score cannot be standardized.
    for (Uint deg = 2; deg < degree; deg++)
      famdat->nullvalue[pos][deg] = 0;
    return;
  }

  for (IV v = 0; v < numscores; v++) {
    const Double val = (calcscore->vec[v] -
                        famdat->nullmean[pos])/famdat->nullsd[pos];
    Double valp = val*val*val;  // index 2 holds the third moment

    if (p0 != 0)
      valp *= p0[v % numiv];
    for (Uint deg = 2; deg < degree; deg++) {
      famdat->nullvalue[pos][deg] += valp;
      valp *= val;
    }
  }
  for (Uint deg = 2; deg < degree; deg++)
    famdat->nullvalue[pos][deg] /= norm;
}

// Adds the pv-weighted moments of the current family's standardized score
// into familydata.back()->value[pos].
//
// The score is standardized using the null mean/sd of slot 0 when
// isnullconstant(), otherwise of slot `pos`. Nothing is added if that sd is
// zero. value[pos][k] accumulates the (k+1)-th moment for k = 0 .. degree-1.
// Each of the POW2[nasymmetricfoundercouples()] blocks of calcscore->vec is
// weighted with the same pv[0 .. numiv-1]. When there is more than one block,
// the result is divided by the number of blocks, i.e. averaged over them.
void Scoredistmoments::set(FloatVec pv, Uint pos) {
  Familydata *fd = familydata.back();
  const Uint ref = (isnullconstant() ? 0 : pos);
  if (fd->nullsd[ref] == 0) return;

  const Uint nasym = curfamily()->nasymmetricfoundercouples();
  const IV niv = curfamily()->numiv;

  for (Uint blk = 0; blk < POW2[nasym]; blk++) {
    const DoubleVec scores = calcscore->vec + blk*niv;
    for (IV iv = 0; iv < niv; iv++) {
      const Double z = (scores[iv] - fd->nullmean[ref])/fd->nullsd[ref];
      Double term = z*pv[iv];
      Uint k = 0;
      while (k < degree) {
        fd->value[pos][k] += term;
        term *= z;
        k++;
      }
    }
  }

  if (nasym > 0) {
    const Double scale = 1./Double(POW2[nasym]);
    for (Uint k = 0; k < degree; k++)
      fd->value[pos][k] *= scale;
  }
}

// Frees all accumulated family records (newest first), then sets the number
// of positions reported by getresults() to `np`.
void Scoredistmoments::reset(Uint np) {
  for (; !familydata.empty(); familydata.pop_back())
    delete familydata.back();
  npos = np;
}

// Copies the accumulated results for every family into caller storage.
//   value[fam][pos]     <- first `deg` entries of the family's value[pos],
//                          for pos < npos (family-major indexing)
//   nullvalue[pos][fam] <- first `deg` entries of the family's nullvalue[pos]
//   nullmean[pos][fam], nullsd[pos][fam]
//                       <- the family's null mean / sd, for pos < nnulldist()
//                          (position-major indexing)
//   minvalue[fam], maxvalue[fam] <- the family's raw score range
void Scoredistmoments::getresults(DoubleMat *value, DoubleMat *nullvalue,
                                  DoubleMat nullmean, DoubleMat nullsd,
                                  Uint deg, DoubleVec minvalue,
                                  DoubleVec maxvalue) {
  const Uint nfam = familydata.size();
  for (Uint fam = 0; fam < nfam; fam++) {
    Familydata *fd = familydata[fam];

    for (Uint pos = 0; pos < npos; pos++)
      copyvec(value[fam][pos], fd->value[pos], deg);

    for (Uint pos = 0; pos < nnulldist(); pos++) {
      copyvec(nullvalue[pos][fam], fd->nullvalue[pos], deg);
      nullmean[pos][fam] = fd->nullmean[pos];
      nullsd[pos][fam] = fd->nullsd[pos];
    }

    minvalue[fam] = fd->minvalue;
    maxvalue[fam] = fd->maxvalue;
  }
}

// Deletes and removes the most recently added family record.
// Does nothing when no record exists.
void Scoredistmoments::skipfam() {
  if (familydata.empty()) return;
  delete familydata.back();
  familydata.pop_back();
}
