#include "pairibd.h"
#include "family.h"
#include "vecutil.h"
#include "files.h"
#include "options.h"

// Builds the IBD bookkeeping for person p.
//   lastrel - length of the ibd1/ibd2 sharing arrays (stored as numibd).
//   idx     - running counter; an informative person takes the current value
//             as its index and advances it.
//   infm    - whether this person is informative.
// Non-informative persons get index Uint(-1) and null sharing arrays.
IBDperson::IBDperson(Person *p, IBDperson *first, Uint lastrel,
                     Uint &idx, bool infm) :
    Pairwiseperson<Indexfounderallele>(p, first), informative(infm),
    numibd(lastrel) {
  // Start from the non-informative defaults, then fill in if informative.
  ibd1 = 0;
  ibd2 = 0;
  index = Uint(-1);
  if (informative) {
    index = idx++;
    NEWVEC(Double, ibd1, numibd);
    NEWVEC(Double, ibd2, numibd);
  }
  if (per->founder()) {
    // The second node is only configured when two nodes are in use
    // (autosomal data, or a FEMALE under sex-linked analysis).
    const bool usessecondnode = !options->sexlinked || per->sex == FEMALE;
    nod[0]->setmaxnum(numibd);
    if (usessecondnode) nod[1]->setmaxnum(numibd);
  }
}

// Releases the sharing arrays. They are only allocated for informative
// persons; for everyone else both pointers hold 0.
IBDperson::~IBDperson() {
  DELETEVEC(ibd2);
  DELETEVEC(ibd1);
}

// Adds curprob to this person's sharing accumulators for every person that is
// currently registered as a descendant on this person's nodes nod[0]/nod[1]
// (calcpairibd registers earlier persons in the chain before recursing and
// removes them afterwards). For a registered person d, ibd2[d] is incremented
// when d is counted as sharing both alleles, ibd1[d] when sharing one.
// Only called for informative persons, whose ibd1/ibd2 arrays exist.
// NOTE(review): the merge below assumes each node's descendants list is in
// ascending order (presumably guaranteed by calcpairibd pushing persons in
// chain order) -- confirm.
void IBDperson::updateibd(Double curprob) {
  Uint i0 = 0, i1 = 0;
  if (!options->sexlinked || per->sex == FEMALE) {
    if (nod[0] != nod[1]) {
      // Distinct nodes: merge the two descendant lists. A person present on
      // both lists shares two alleles (ibd2); one present on only one list
      // shares a single allele (ibd1).
      while (i0 < nod[0]->numdescendants && i1 < nod[1]->numdescendants) {
        Uint d0 = nod[0]->descendants[i0];
        Uint d1 = nod[1]->descendants[i1];
        if (d0 < d1) {
          ibd1[d0] += curprob;
          i0++;
        }
        else if (d1 < d0) {
          ibd1[d1] +=curprob;
          i1++;
        }
        else {
          ibd2[d0] += curprob;
          i0++;
          i1++;
        }
      }
      // Leftover entries of nod[0] appear only on nod[0]'s list. Leftover
      // entries of nod[1] are handled by the final loop of this function.
      while (i0 < nod[0]->numdescendants) {
        Uint d0 = nod[0]->descendants[i0];
        ibd1[d0] += curprob;
        i0++;
      }
    }
    else {
      // Both nodes are the same: the descendanthomozygous flag recorded
      // when the person was pushed selects ibd2 versus ibd1. This consumes
      // nod[1]'s entire list, so the final loop below does nothing here.
      while (i1 < nod[1]->numdescendants) {
        Uint d1 = nod[1]->descendants[i1];
        if (nod[1]->descendanthomozygous[i1])
          ibd2[d1] += curprob;
        else
          ibd1[d1] += curprob;
        i1++;
      }
    }
  }
  // Count remaining nod[1] entries as single sharing. After the distinct-node
  // merge this is the tail of nod[1]'s list; when the branch above is skipped
  // (sex-linked and not FEMALE) i1 is still 0, so every descendant on nod[1]
  // is counted in ibd1.
  while (i1 < nod[1]->numdescendants) {
    Uint d1 = nod[1]->descendants[i1];
    ibd1[d1] += curprob;
    i1++;
  }
}

// Sums pv over every index obtained from v by independently setting or
// clearing each bit of mask that is at or above bit k (k is a single-bit
// value; callers start with k == 1). Bits of v outside that set are kept.
Double probmasksum(IV v, IV mask, DoubleVec pv, IV k) {
  for (; k <= mask; k <<= 1) {
    if (k & mask) {
      // Branch on this free bit: once left as in v, once forced on.
      const IV nextbit = k << 1;
      return probmasksum(v, mask, pv, nextbit) +
        probmasksum(v | k, mask, pv, nextbit);
    }
  }
  // No free bits remain: v is a complete index.
  return pv[v];
}

// Recursively enumerates the inheritance choices of this person and every
// later person in the `next` chain, and returns the summed probability of all
// completions.
//   v    - inheritance vector built so far; this person's choices add
//          per->patmask / per->matmask to it.
//   mask - passed unchanged to probmasksum at the end of the chain.
//   prob - probability table; when null each completion counts as 1.0.
// For informative persons, the probability of each choice is also added to
// the IBD accumulators through updateibd.
Double IBDperson::calcpairibd(IV v, IV mask, DoubleVec prob) {
  // Number of alternatives beyond 0 for each parental choice: both 0 and 1
  // are tried only when the corresponding mask is nonzero.
  int Kf = per->patmask ? 1 : 0;
  int Km = per->matmask ? 1 : 0;

  Double cumprob = 0.0;
  for (int K1 = 0; K1 <= Km; K1++) {
    // NOTE(review): uses += rather than |=, so this assumes the bit is clear
    // in v on entry (presumably guaranteed by the caller) -- confirm.
    if (K1) v += per->matmask;
    if (mother != 0) nod[1] = mother->nod[K1];
    for (int K0 = 0; K0 <= Kf; K0++) {
      if (K0) v += per->patmask;
      if (father != 0 && (!options->sexlinked || per->sex == FEMALE))
        nod[0] = father->nod[K0];
      // Sex-linked and not FEMALE: nod[0] aliases the maternal node nod[1].
      else if (father != 0) nod[0] = nod[1];
      Double curprob;
      if (informative) {
        if (next == 0) curprob = (prob == 0 ? 1.0 : 
                                  probmasksum(v, mask, prob, 1));
        else {
          // Register this person on the chosen nodes so later persons see it
          // in their updateibd calls. Same node on both sides: push once and
          // flag it homozygous.
          bool homozygote = nod[0] == nod[1];
          nod[0]->pushdescendant(index, homozygote);
          if ((!options->sexlinked || per->sex == FEMALE) && !homozygote)
            nod[1]->pushdescendant(index, false);
          curprob = ((IBDperson *)next)->calcpairibd(v, mask, prob);
          // Unregister before updateibd so only earlier persons are counted.
          nod[0]->popdescendant();
          if ((!options->sexlinked || per->sex == FEMALE) && !homozygote)
            nod[1]->popdescendant();
        }
        if (curprob > 0) updateibd(curprob);
      }
      else {
        if (next == 0) curprob = (prob == 0 ? 1.0 :
                                  probmasksum(v, mask, prob, 1));
        else curprob = ((IBDperson *)next)->calcpairibd(v, mask, prob);
      }        
      cumprob += curprob;
    }
    // Clear the paternal bit before trying the next maternal choice. The
    // maternal bit is only ever set on the last outer iteration, and v is a
    // by-value copy.
    v &= ~per->patmask;
  }
  return cumprob;
}

// Writes the ids of every (this, earlier informative person) pair into
// person1/person2, provided this person is informative. Then continues down
// the chain, with the output pointers advanced past the pairs just written.
// `first` is the head of the chain; the scan of earlier persons stops on
// reaching this person's own entry.
void IBDperson::collectpairs(IBDperson *first,
                             StringVec person1, StringVec person2) const {
  Uint npairs = 0;
  if (informative) {
    for (IBDperson *prev = first; prev->per != per;
         prev = (IBDperson *)prev->next) {
      if (!prev->informative) continue;
      person1[npairs] = per->id;
      person2[npairs] = prev->per->id;
      npairs++;
    }
  }
  IBDperson *rest = (IBDperson *)next;
  if (rest != 0) rest->collectpairs(first, person1 + npairs, person2 + npairs);
}

// Writes, in the same pair order as collectpairs, this person's accumulated
// sharing with each earlier informative person, divided by nc:
// ibd1 into results1 and ibd2 into results2. Then continues down the chain.
void IBDperson::collectresults(IBDperson *first, DoubleVec results1,
                               DoubleVec results2, Double nc) const {
  Uint npairs = 0;
  if (informative) {
    for (IBDperson *prev = first; prev->per != per;
         prev = (IBDperson *)prev->next) {
      if (!prev->informative) continue;
      const Uint j = prev->index;
      results1[npairs] = ibd1[j]/nc;
      results2[npairs] = ibd2[j]/nc;
      npairs++;
    }
  }
  IBDperson *rest = (IBDperson *)next;
  if (rest != 0)
    rest->collectresults(first, results1 + npairs, results2 + npairs, nc);
}

void IBDperson::reset() {
  if (ibd1 != 0) {
    zero(ibd1, numibd);
    zero(ibd2, numibd);
  }
  if (next != 0) ((IBDperson *)next)->reset();
}

// Counts the (informative, earlier informative) pairs from this person to the
// end of the chain. This matches the number of entries collectpairs and
// collectresults write.
Uint IBDperson::countpairs(IBDperson *first) const {
  Uint npairs = 0;
  if (informative) {
    for (IBDperson *prev = first; prev->per != per;
         prev = (IBDperson *)prev->next)
      npairs += prev->informative ? 1 : 0;
  }
  if (next != 0) npairs += ((IBDperson *)next)->countpairs(first);
  return npairs;
}

// Returns the entry for person p, searching from this entry down the `next`
// chain. assertinternal fires if the end of the chain is reached without
// finding p.
IBDperson *IBDperson::findperson(Person *p) {
  IBDperson *cur = this;
  while (cur->per != p) {
    assertinternal(cur->next != 0);
    cur = (IBDperson *)cur->next;
  }
  return cur;
}

// For every informative non-founder from this entry to the end of the chain,
// replaces its sharing with the two persons indexed gfidx and gmidx by the
// average of the two. Both ibd1 and ibd2 are averaged.
// NOTE(review): presumably done because the two founders of a couple are
// interchangeable for this analysis -- confirm against the caller.
//
// An index of Uint(-1) marks a non-informative person, which has no slot in
// the sharing arrays. The averaging reads and writes both slots, so it is only
// performed when BOTH indices are valid. Previously, ibd1 was guarded only on
// gfidx and ibd2 only on gmidx, which indexed the arrays with Uint(-1)
// (out of bounds) whenever exactly one founder was non-informative.
void IBDperson::correctfoundercouple(Uint gfidx, Uint gmidx) {
  const Uint noindex = Uint(-1);
  const bool bothindexed = gfidx != noindex && gmidx != noindex;
  if (!per->founder() && informative && bothindexed) {
    ibd1[gfidx] = ibd1[gmidx] = .5*(ibd1[gfidx] + ibd1[gmidx]);
    ibd2[gfidx] = ibd2[gmidx] = .5*(ibd2[gfidx] + ibd2[gmidx]);
  }
  if (next != 0) ((IBDperson *)next)->correctfoundercouple(gfidx, gmidx);
}

// Retrieves the accumulated sharing between this person and p:
// p1 receives the ibd1 value and p2 the ibd2 value.
// The value lives in the arrays of whichever of the two has the larger index,
// at the slot of the other. If either person is not informative, both
// outputs are 0.
void IBDperson::getsharing(IBDperson *p, Float &p1, Float &p2) {
  if (informative && p->informative) {
    const bool pisearlier = p->index < index;
    IBDperson *holder = pisearlier ? this : p;
    const Uint slot = pisearlier ? p->index : index;
    p1 = holder->ibd1[slot];
    p2 = holder->ibd2[slot];
  }
  else p1 = p2 = 0;
}
