#include "files.h"
#include "assocdist.h"
#include "calcassoc.h"
#include "findassoc.h"
#include "control.h"

///////////////////////////////////////////////////////////////////////////
// Assocdist

// Construct an association distribution for analysis type `p`, using
// correction type `ct` and the weighting/frequency provider `aw`
// (`aw` is stored, not copied; ownership is not taken here).
Assocdist::Assocdist(const string &p, CorrectionType ct, Calcassoc *aw) :
    Distribution(p), corr(ct), assocweight(aw) {
  // nextfam() and the destructor both `delete first` unconditionally, so the
  // pointer must start out null rather than indeterminate.
  first = 0;
}

// Release the Calcassocperson list head built by nextfam(). Whether deleting
// the head also frees the rest of the list depends on Calcassocperson's
// destructor (not visible here) -- TODO confirm.
// NOTE(review): familydata entries are not freed here; presumably the base
// class or a final reset() takes care of them -- verify.
Assocdist::~Assocdist()
{
  if (first != 0)
    delete first;
}

// Prepare for a new family.
// - Appends a fresh Familydata record (sized from map->num).
// - Allocates zero-filled ET/ET2/ET3 arrays for every marker; set() fills these in later.
// - Rebuilds the Calcassocperson list for the current family.
// pos must be 0; the second argument is unused.
void Assocdist::nextfam(Uint pos, DoubleVec /*p0*/) {
  assertinternal(pos == 0);
  Familydata *famdat = new Familydata(map->num);
  familydata.push_back(famdat);

  for (Uint gam = 0; gam < famdat->nmrk; gam++) {
    NEWVEC(Double, famdat->ET[gam], map->numallele[gam]);
    zero(famdat->ET[gam], map->numallele[gam]);
    NEWVEC(Double, famdat->ET2[gam], map->numallele[gam]);
    zero(famdat->ET2[gam], map->numallele[gam]);
    NEWVEC(Double, famdat->ET3[gam], map->numallele[gam]);
    zero(famdat->ET3[gam], map->numallele[gam]);
  }

  // Replace the previous family's Calcassocperson list with one that mirrors
  // this family's person list. Nodes after the head are created with `first`
  // as their second constructor argument; presumably the constructor links
  // them into the head's list (the result of `new` is not stored here).
  delete first;
  Person *head = curfamily()->first;
  first = new Calcassocperson(head, 0);
  for (Person *per = head->next; per != 0; per = per->next)
    new Calcassocperson(per, first);
}

// Expected value of the current family's weighted count of allele `a` at
// marker `gam`: the sum over all family members of
// 2 * weight(person) * map->pi[0][gam][a]
// (2 = alleles per person; weights come from assocweight->assign).
Double Assocdist::calcET(Uint gam, Uint a) {
  Double expectation = 0.0;
  Person *member = curfamily()->first;
  while (member != 0) {
    expectation += 2*assocweight->assign(member, gam)*map->pi[0][gam][a];
    member = member->next;
  }
  return expectation;
}

// For the marker behind position `pos`, compute for every allele the first
// three raw moments E[T], E[T^2], E[T^3] of the family's weighted
// allele-count statistic. They are stored in the most recently added
// Familydata (ET, ET2, ET3).
//
// pv  - forwarded to Calcassocperson::calcmoments when the correction type is
//       neither RAW nor FAMILY (presumably per-position inheritance-vector
//       probabilities -- TODO confirm against calcmoments).
// pos - position index. In "mpt" analyses, positions lying between markers
//       are skipped; other positions map to their left marker.
void Assocdist::set(FloatVec pv, Uint pos) {
  Uint gam;
  if (pt == "mpt") {
    if (map->inbetween[pos]) return;
    gam = map->leftmarker[pos];
  }
  else gam = pos;
  Familydata *famdat = familydata.back();
  if (corr == RAW) {
    // RAW: family members are treated as independent, each contributing
    // weight w times a Binomial(2, p) count of allele a. The k-th cumulant of
    // such a weighted sum is sum(w^k) times the binomial k-th cumulant, which
    // gives the central second and third moments below.
    Double sum_w2 = 0.0;
    Double sum_w3 = 0.0;
    Double *pi = assocweight->getfreq(gam);
    for (Person *p = curfamily()->first; p != 0; p = p->next) {
      const Double w = assocweight->assign(p, gam);
      sum_w2 += w*w;
      sum_w3 += w*w*w;
    }
    for (Uint a = 0; a < map->numallele[gam]; a++) {
      Double mu = famdat->ET[gam][a] = calcET(gam, a);
      // Var(Binomial(2,p)) = 2p(1-p); third central moment = 2p(1-p)(1-2p).
      const Double var = sum_w2*2.0*pi[a]*(1.0 - pi[a]);
      const Double mu3 = sum_w3*2.0*pi[a]*(1.0 - pi[a])*(1.0 - 2.0*pi[a]);
      famdat->ET2[gam][a] = var + mu*mu;
      // E[T^3] = mu3 + 3*Var*E[T] + E[T]^3, the same central-to-raw
      // conversion used by the non-RAW branch below.
      famdat->ET3[gam][a] = mu3 + 3.0*var*mu + mu*mu*mu;
    }
  }
  else {
    // Set person weights
    Calcassocperson *cap = first;
    for (Person *p = curfamily()->first; p != 0;
         p = p->next, cap = (Calcassocperson *)cap->next) {
      assertinternal(cap != 0);
      cap->weight = assocweight->assign(p, gam);
    }
    // Set allele frequencies
    Calcassocnode::setfreq(assocweight->getfreq(gam), map->numallele[gam]);
    Calcassocperson::resetETM(map->numallele[gam]);
    // calcmoments leaves central second/third moments in ET2/ET3 (inferred
    // from the conversion to raw moments below). In FAMILY mode no pv is
    // passed and the result is divided by numiv afterwards.
    first->calcmoments(0, (corr == FAMILY ? 0 : pv), famdat->ET2[gam],
                       famdat->ET3[gam], map->numallele[gam], gam);
    for (Uint a = 0; a < map->numallele[gam]; a++) {
      if (corr == FAMILY) {
        famdat->ET2[gam][a] /= Double(curfamily()->numiv);
        famdat->ET3[gam][a] /= Double(curfamily()->numiv);
      }
      Double mu = famdat->ET[gam][a] = calcET(gam, a);
      Double x2 = famdat->ET2[gam][a];
      famdat->ET2[gam][a] += mu*mu;
      famdat->ET3[gam][a] += 3.0*x2*mu + mu*mu*mu;
    }
  }
}
  
// Discard all accumulated per-family records (deleting them newest-first)
// and record the new number of positions `np`.
void Assocdist::reset(Uint np) {
  for (; !familydata.empty(); familydata.pop_back())
    delete familydata.back();
  npos = np;
}

// Drop the most recently added family record, if there is one.
void Assocdist::skipfam() {
  if (familydata.empty())
    return;
  delete familydata.back();
  familydata.pop_back();
}

// Return an existing Assocdist matching all three of:
// - describe() == "assoc" + correctiontypetostring(ct);
// - analysis type `pt`;
// - a weighting object whose describe() equals aw->describe().
// If none exists, construct a new one. The returned pointer is not freed
// here; presumably the Distribution machinery owns it -- TODO confirm.
Assocdist *Assocdist::getassocdist(const string &pt, CorrectionType ct,
                                   Calcassoc *aw) {
  // Loop-invariant: build the wanted description once, not per iteration.
  const string wanted = string("assoc") + correctiontypetostring(ct);
  for (Uint d = 0; d < distributions.size(); d++)
    if (distributions[d]->describe() == wanted &&
        distributions[d]->getpt() == pt) {
      // The description check above identifies this as an Assocdist, so a
      // static downcast is sufficient (and, unlike a C-style cast, cannot
      // silently degrade into a reinterpret_cast).
      Assocdist *ad = static_cast<Assocdist *>(distributions[d]);
      if (ad->assocweight->describe() == aw->describe())
        return ad;
    }
  return new Assocdist(pt, ct, aw);
}

// A family is used only if at least one member has an original disease
// status (origdstat) other than UNKNOWN.
bool Assocdist::usefamily(Family *fam) const {
  for (Person *member = fam->first; member != 0; member = member->next) {
    if (member->origdstat != UNKNOWN)
      return true;
  }
  return false;
}
