#include "findibd.h"
#include "pairibd.h"
#include "family.h"
#include "vecutil.h"

// Reports whether person `p` should take part in the pairwise IBD
// calculation.
//
// NOTE(review): the function currently answers `true` unconditionally, so
// the recursive descendant check below is never executed. It looks like the
// pruning was switched off on purpose -- confirm before re-enabling, since
// the FindIBD constructor dereferences findperson() results for founder
// couples without a null check.
bool hasinformativedescendants(Person *p) {
  return true;

  // --- unreachable while the early return above is in place ---
  // A person with a known original disease status counts directly...
  if (p->origdstat != UNKNOWN) {
    return true;
  }
  // ...otherwise look for any qualifying person further down the pedigree.
  Plist *child = p->children;
  while (child != 0) {
    if (hasinformativedescendants(child->p)) {
      return true;
    }
    child = child->next;
  }
  return false;
}

// Builds the IBD bookkeeping for family `f`.
//
// `pairs` is stored in `pairtype` and later interpreted by isinformative().
// Every person for whom hasinformativedescendants() holds gets an IBDperson;
// the first one created becomes `firstibdperson`. For everybody else the
// person's patmask/matmask bits are OR-ed into `mask`. Finally, the
// IBDperson indices of each founder couple's husband and wife are recorded
// in `foundercoupleindices`.
FindIBD::FindIBD(Family *f, const string &pairs) : fam(f), pairtype(pairs) {
  foundercoupleindices.clear();
  firstibdperson = 0;
  mask = 0;

  // Handed to every IBDperson constructor.
  // NOTE(review): presumably taken by reference and advanced there -- confirm
  // against the IBDperson declaration.
  Uint nextindex = 0;

  Person *person = fam->first;
  while (person != 0) {
    if (!hasinformativedescendants(person)) {
      // Excluded person: remember the bits of both parental masks.
      mask |= person->patmask | person->matmask;
    }
    else {
      IBDperson *created = new IBDperson(person, firstibdperson, fam->num - 1,
                                         nextindex, isinformative(person));
      // Only the first object is stored here; later ones are passed the
      // current head as constructor argument.
      if (firstibdperson == 0) {
        firstibdperson = created;
      }
    }
    person = person->next;
  }

  Foundercouple *couple = fam->firstfoundercouple;
  while (couple != 0) {
    Uint husbandindex = firstibdperson->findperson(couple->husband)->get_index();
    Uint wifeindex = firstibdperson->findperson(couple->wife)->get_index();
    foundercoupleindices.push_back(Intpair(husbandindex, wifeindex));
    couple = couple->next;
  }
}

// Releases the IBDperson stored in `firstibdperson` (deleting a null pointer
// is a no-op, so an empty family is fine).
// NOTE(review): only the head is deleted here; presumably IBDperson's
// destructor frees the remaining persons -- confirm.
FindIBD::~FindIBD() {
  delete firstibdperson;
}

// Runs the pair-IBD calculation without a probability vector.
//
// Steps: reset the IBDperson state, call calcpairibd(0, mask), then apply
// correctfoundercouple() for every recorded founder couple. If both
// `results1` and `results2` are non-null, the results are gathered through
// collectresults() using fam->numiv / POW2[wt[mask]] as the last argument;
// otherwise nothing is collected (the values can still be queried through
// priorresult()).
void FindIBD::findprior(DoubleVec results1, DoubleVec results2) {
  firstibdperson->reset();
  firstibdperson->calcpairibd(0, mask);

  for (Uint couple = 0; couple < foundercoupleindices.size(); couple++) {
    firstibdperson->correctfoundercouple(foundercoupleindices[couple].first,
                                         foundercoupleindices[couple].second);
  }

  // Without both output buffers there is nothing to fill in.
  if (results1 == 0 || results2 == 0) {
    return;
  }
  firstibdperson->collectresults(firstibdperson, results1, results2,
                                 fam->numiv/POW2[wt[mask]]);
}

// Runs the pair-IBD calculation weighted by `prob`.
//
// Same sequence as findprior(), except that calcpairibd() additionally
// receives `prob`, and the value handed to collectresults() is the sum of
// the first fam->numiv entries of `prob`. Unlike findprior(), the result
// buffers are always passed on, with no null check.
void FindIBD::findposterior(FloatVec prob,
                             DoubleVec results1, DoubleVec results2) {
  firstibdperson->reset();
  firstibdperson->calcpairibd(0, mask, prob);

  for (Uint couple = 0; couple < foundercoupleindices.size(); couple++) {
    firstibdperson->correctfoundercouple(foundercoupleindices[couple].first,
                                         foundercoupleindices[couple].second);
  }

  // Total of the supplied probabilities, forwarded to collectresults().
  const Double probtotal = sum<Double>(prob, fam->numiv);
  firstibdperson->collectresults(firstibdperson, results1, results2,
                                 probtotal);
}

// Forwards to IBDperson::collectpairs() on the head of the IBDperson list,
// passing the head itself together with the two caller-supplied vectors.
// NOTE(review): presumably fills `person1`/`person2` with the identifiers of
// each pair's members -- confirm against IBDperson::collectpairs.
void FindIBD::collectpairs(StringVec person1, StringVec person2) const {
  firstibdperson->collectpairs(firstibdperson, person1, person2);
}

// Decides whether person `p` belongs to the pair class selected by
// `pairtype`:
//   "all"         - always true
//   "genotyped"   - p->genotyped
//   "affected"    - p->dstat == AFFECTED
//   "qtl"         - p->hastraitvalue()
//   "informative" - p->dstat != UNKNOWN
// Any other value is an internal error and is reported via assertinternal().
bool FindIBD::isinformative(Person *p) {
  if (pairtype == "all") return true;
  if (pairtype == "genotyped") return p->genotyped;
  if (pairtype == "affected") return p->dstat == AFFECTED;
  if (pairtype == "qtl") return p->hastraitvalue();
  if (pairtype == "informative") return p->dstat != UNKNOWN;

  // Unknown pair type.
  assertinternal(false);
  // If assertinternal() returns (or is compiled out), the function must not
  // run off its end: that is undefined behaviour for a non-void function.
  return false;
}

// Returns the value of IBDperson::countpairs() invoked on the head of the
// IBDperson list with the head itself as argument.
// NOTE(review): presumably the number of person pairs that will be
// reported -- confirm against IBDperson::countpairs.
Uint FindIBD::countpairs() {
  return firstibdperson->countpairs(firstibdperson);
}

// Looks up the sharing values for the pair (`a`, `b`).
//
// `a` and `b` must be distinct and both must have an IBDperson entry; both
// conditions are enforced with assertinternal(). The values delivered by
// getsharing() are written to `p1` and `p2` and then each divided by
// fam->numiv / POW2[wt[mask]], the same quantity findprior() passes to
// collectresults().
void FindIBD::priorresult(Person *a, Person *b, Float &p1, Float &p2) const {
  assertinternal(a != b);

  IBDperson *first = firstibdperson->findperson(a);
  IBDperson *second = firstibdperson->findperson(b);
  assertinternal(first != 0 && second != 0);

  first->getsharing(second, p1, p2);

  // The divisor is kept inline on purpose: its type is declared elsewhere,
  // and storing it in a local of another type could alter the arithmetic.
  p1 /= fam->numiv/POW2[wt[mask]];
  p2 /= fam->numiv/POW2[wt[mask]];
}
