#include "files.h"
#include "family.h"
#include "distribution.h"
#include "viterbidist.h"
#include "map.h"
#include "options.h"
#include "utils.h"
#include "vecutil.h"
#include "calcviterbi.h"

// Builds the Viterbi distribution named p.
// - calc is obtained from Calcviterbi::getcalcviterbi(p).
// - path stays null until reset() allocates it.
// - The scratch buffers start with room for 2^10 entries each. nextfam()
//   enlarges them when a family needs more.
Viterbidist::Viterbidist(const string &p) :
    Distribution(p), path(0), ntmpfc(10), ntmp(10) {
  calc = Calcviterbi::getcalcviterbi(p);
  // Founder scratch space (used by dofounderloop), 2^ntmp entries.
  NEWVEC(Float, tmpfloat, POW2[ntmp]);
  NEWVEC(IV, tmpiv, POW2[ntmp]);
  // Founder-couple scratch space (used by step), 2^ntmpfc entries.
  NEWVEC(Float, tmpfcfloat, POW2[ntmpfc]);
  NEWVEC(IV, tmpfciv, POW2[ntmpfc]);
}

// Per-family initialisation; only pos == 0 is accepted.
// It tells the traceback store calc->psi about the new family's size.
// It then grows (never shrinks) the scratch buffers used by step():
// - tmpfloat/tmpiv must hold 2^ntmp entries;
// - tmpfcfloat/tmpfciv must hold 2^ntmpfc entries.
void Viterbidist::nextfam(Uint pos, DoubleVec /*p0*/) {
  assertinternal(pos == 0);
  Family *fam = curfamily();
  calc->psi->nextfam(map->num*fam->numiv, fam->numiv);
  // Founder couples need 2*(children - 1) + grandchildren - 1 bits.
  // NOTE(review): this underflows (Uint) for a couple with one child and no
  // grandchildren; presumably founder couples always have grandchildren.
  Uint fcbits = 0;
  for (Foundercouple *fc = fam->firstfoundercouple; fc != 0; fc = fc->next) {
    Uint grandchildren = 0;
    for (Plist *c = fc->wife->children; c != 0; c = c->next)
      if (c->p->children != 0) grandchildren += c->p->children->length();
    fcbits = max_(fcbits,
                  2*(fc->wife->children->length() - 1) + grandchildren - 1);
  }
  // Founders need children - 1 bits.
  Uint founderbits = 0;
  for (Person *p = fam->first; p != fam->firstdescendant; p = p->next)
    if (p->children != 0)
      founderbits = max_(founderbits, p->children->length() - 1);
  // Reallocate only when the current capacity is insufficient.
  if (founderbits > ntmp) {
    ntmp = founderbits;
    DELETEVEC(tmpfloat);
    DELETEVEC(tmpiv);
    NEWVEC(Float, tmpfloat, POW2[ntmp]);
    NEWVEC(IV, tmpiv, POW2[ntmp]);
  }
  if (fcbits > ntmpfc) {
    ntmpfc = fcbits;
    DELETEVEC(tmpfcfloat);
    DELETEVEC(tmpfciv);
    NEWVEC(Float, tmpfcfloat, POW2[ntmpfc]);
    NEWVEC(IV, tmpfciv, POW2[ntmpfc]);
  }
}

// Releases the four scratch buffers and the traceback path.
Viterbidist::~Viterbidist() {
  DELETEVEC(tmpfciv);
  DELETEVEC(tmpiv);
  DELETEVEC(tmpfcfloat);
  DELETEVEC(tmpfloat);
  DELETEVEC(path);
}

// Prepares for a new run over np positions.
// It records npos and replaces path with a fresh array of np inheritance
// vectors, which set() fills when it reaches position 0.
void Viterbidist::reset(Uint np) {
  npos = np;
  DELETEVEC(path);
  NEWVEC(IV, path, np);
}

// Reads the recombination fraction(s) for interval gam from the map.
// - Sex-averaged map: theta[0] goes to tht, and tht_female is set to -1.
// - options->sexspecific: theta[1] goes to tht and theta[2] to tht_female.
void Viterbidist::gettheta(Uint gam, Float &tht, Float &tht_female) const {
  if (!options->sexspecific) {
    tht = map->theta[0][gam];
    tht_female = -1;
    return;
  }
  tht = map->theta[1][gam];
  tht_female = map->theta[2][gam];
}

// Viterbi recursion for position gam, with q the per-vector probabilities
// at that position.
// - At gam == npos - 1, calc->vec is simply initialised from q.
// - At every other position, step() is applied across interval gam. The
//   chosen predecessors are stored via calc->psi->store(gam), then the
//   scores are multiplied by q and normalised.
// - When gam == 0, the best vector is picked and the most likely path is
//   traced forward through the stored psi rows.
// NOTE(review): this implies calls run from npos - 1 down to 0; confirm
// with the caller.
void Viterbidist::set(FloatVec q, Uint gam) {
  Family *fam = curfamily();
  if (gam != npos - 1) {
    // Each vector starts out as its own predecessor; step() overwrites it.
    IVVec pred = calc->psi->getrow(gam, false);
    for (IV v = 0; v < fam->numiv; v++) pred[v] = v;
    Float tht, tht_female;
    gettheta(gam, tht, tht_female);
    step(calc->vec, pred, tht, tht_female);
    calc->psi->store(gam);
    elemprod(calc->vec, calc->vec, q, fam->numiv);
    normal<Float>(calc->vec, fam->numiv);
    assertinternal(calc->vec[0] >= 0);
  }
  else
    copyvec(calc->vec, q, fam->numiv);

  if (gam != 0) return;

  // Traceback: start from the highest-scoring vector (first one on ties).
  Float best = -1;
  for (IV v = 0; v < fam->numiv; v++) {
    if (calc->vec[v] > best) {
      best = calc->vec[v];
      path[0] = v;
    }
  }
  assertinternal(best != -1);
  for (Uint g = 1; g < npos; g++)
    path[g] = calc->psi->getrow(g - 1, true)[path[g - 1]];
}

// Single-bit transition step over the first N entries of p/h.
// For every bit k = frombit, 2*frombit, ... up to tobit, each index v
// without bit k is paired with v | k, and dobit is called on the pair with
// weight fac. dobit is defined elsewhere; presumably it keeps the better of
// "unchanged" and "bit flipped".
// frombit must be non-zero: a zero bit never advances (0 << 1 == 0).
inline void Viterbidist::dobits(FloatVec p, IVVec h, Float fac,
                               IV frombit, IV tobit, IV N) {
  for (IV bit = frombit; bit <= tobit; bit <<= 1)
    for (IV blk = 0; blk < N; blk += bit << 1)
      for (IV lo = blk; lo < blk + bit; lo++)
        dobit(p[lo], p[lo | bit], h[lo], h[lo | bit], fac);
}

// Founder step for the slice base, base + lastbit, ..., base + mask of p0/h.
// NOTE(review): this presumes the founder's bits form the contiguous run
// lastbit..mask; confirm.
// Two hypotheses are evaluated, and the better one is kept for every entry:
//  1. No recombination in the first child. Ordinary single-bit steps are
//     run on a compact copy of the slice in tmpfloat/tmpiv.
//  2. A recombination in the first child. doinversebit is applied in place
//     on p0/h, and the result is then charged one extra factor fac.
inline void Viterbidist::dofounderloop(FloatVec p0, IVVec h, Float fac, IV base,
                                      IV mask, IV lastbit) {
  // Gather the slice into the scratch buffers.
  IV n = 0;
  for (IV pos = base; pos <= mask + base; pos += lastbit) {
    tmpfloat[n] = p0[pos];
    tmpiv[n] = h[pos];
    n++;
  }
  // Hypothesis 1, on the compact copy.
  Uint N = mask/lastbit;
  dobits(tmpfloat, tmpiv, fac, 1, N, N + 1);
  // Hypothesis 2, in place.
  for (IV bit = lastbit; bit <= mask; bit <<= 1)
    for (IV blk = base; blk <= mask + base; blk += bit << 1)
      for (IV w = blk; w < blk + bit; w += lastbit)
        doinversebit(p0[w], p0[w | bit], h[w], h[w | bit], fac);
  // Merge: hypothesis 2 pays one more fac; hypothesis 1 wins if strictly better.
  n = 0;
  for (IV pos = base; pos <= mask + base; pos += lastbit) {
    p0[pos] *= fac;
    if (tmpfloat[n] > p0[pos]) {
      p0[pos] = tmpfloat[n];
      h[pos] = tmpiv[n];
    }
    n++;
  }
}

// Runs dofounderloop for every slice of the first N entries of p0/h,
// for a founder described by mask/lastbit.
// A lastbit of 0 is treated as lastbit = N.
inline void Viterbidist::dofounder(FloatVec p0, IVVec h, Float fac, IV N,
                                  IV mask, IV lastbit) {
  const IV stride = (lastbit == 0 ? N : lastbit);
  for (IV blk = 0; blk < N; blk += mask + stride)
    for (IV start = blk; start < stride + blk; start++)
      dofounderloop(p0, h, fac, start, mask, stride);
}

// Applies the founder-couple permutation to the entries of p0/hp that
// belong to founder couple fc, offset by base.
// For a husband/wife bit pattern m = h | w with pim = fc->pi(m), the entry
// at m | v is exchanged with the entry at pim | (v ^ fc->mask). That is,
// the spouses' bits are mapped by fc->pi and the grandchild bits
// (fc->mask) are complemented.
// step() calls this before evaluating the "recombination in the founder
// couple grandchild" hypothesis.
void Viterbidist::calculatefcpi(IV base, FloatVec p0, IVVec hp, Foundercouple *fc) {
  Uint n = wt[fc->wife->mask];
  // First loop over bits of husband and wife
  // NOTE(review): starting h at w >> n presumably visits each (m, pi(m))
  // pair only once, so no pair is swapped back; confirm against fc->pi.
  for (IV w = 0; w <= fc->wife->mask; w += fc->wife->lastbit)
    for (IV h = (w >> n) /*+ fc->husband->lastbit*/;
         h <= fc->husband->mask; h += fc->husband->lastbit) {
      IV m = h | w;
      IV pim = fc->pi(m);
      // Then loop over bits of fc's grandchildren
     IV to, inc;
      if (m != pim) {
        // Distinct partners: swap across the whole grandchild range.
        to = fc->mask + fc->lastgcbit;
        inc = fc->lastgcbit;
      }
      else if (wt[fc->mask] == 1) {
        // m is its own partner, and there is a single grandchild bit: only
        // v == base is visited (it is swapped with base ^ fc->mask).
        to = fc->mask;
        inc = fc->mask;
      }
      else {
        // m is its own partner: v and v ^ fc->mask fall in opposite halves,
        // so only the lower half is visited to avoid swapping twice.
        to = (fc->mask + fc->lastgcbit)/2;
        inc = fc->lastgcbit;
      }
      for (IV v = base; v < to + base; v += inc) {
        IV a = m | v;
        IV b = pim | (v^fc->mask);
        // Swap score and predecessor entries a <-> b.
        Float tp = p0[a];
        IV th = hp[a];
        p0[a] = p0[b];
        p0[b] = tp;
        hp[a] = hp[b];
        hp[b] = th;
      }
    }
}

// Founder steps for both spouses of a founder couple, applied directly on
// p0/h at offsets start, start + inc, ... below end.
// - Every wife bit pattern (u .. u + wm, step wb) gets the husband's
//   founder step.
// - Every husband bit pattern (u .. u + hm, step hb) gets the wife's
//   founder step.
inline void Viterbidist::dofcfounders(FloatVec p0, IVVec h, Float fac, Uint wm,
                                     IV wb, IV hm, IV hb,
                                     IV start, IV inc, IV end) {
  IV off = start;
  while (off < end) {
    for (IV ws = off; ws <= off + wm; ws += wb)
      dofounderloop(p0, h, fac, ws, hm, hb);
    for (IV hs = off; hs <= off + hm; hs += hb)
      dofounderloop(p0, h, fac, hs, wm, wb);
    off += inc;
  }
}

// Single-bit steps (dobit with weight fac) for bits frombit, 2*frombit, ...
// up to tobit.
// For each bit, every lower partner w is combined with the offsets
// start + w, start + w + inc, ... below end + w.
// The outer blocks cover [0, frombit + tobit).
inline void Viterbidist::dofc(FloatVec p, IVVec h, Float fac, IV frombit, IV tobit,
                             IV start, IV inc, IV end) {
  const IV N = frombit + tobit;
  for (IV bit = frombit; bit <= tobit; bit <<= 1)
    for (IV blk = 0; blk < N; blk += bit << 1)
      for (IV w = blk; w < blk + bit; w += frombit)
        for (IV v = start + w; v < end + w; v += inc)
          dobit(p[v], p[v | bit], h[v], h[v | bit], fac);
}

// Splits the index range of a compact buffer of num entries into two masks.
// - Halving num n times gives the wife's bits: wm collects them, and wb
//   holds the last (lowest) one.
// - The next n halvings give the husband's bits in hm, with hb lowest.
// Outputs are 0 when no bits are produced.
void makefcmask(IV num, Uint n, IV &wm, IV &hm, IV &wb, IV &hb) {
  wm = 0;
  hm = 0;
  wb = 0;
  hb = 0;
  Uint remaining = n;
  while (remaining-- > 0) {
    num /= 2;
    wm += num;
    wb = num;
  }
  for (Uint j = n; j != 0; --j) {
    num /= 2;
    hm += num;
    hb = num;
  }
}

// One Viterbi transition across an interval with recombination fraction tht.
// tht_female is used for female meioses when options->sexspecific.
// - On entry, p0 holds the scores for the previous position. h holds
//   candidate predecessors (set() seeds h[v] = v).
// - Each meiosis bit is handled by dobit/doinversebit, which are defined
//   elsewhere. A recombination is weighted by fac = tht/(1 - tht), and h
//   presumably records the predecessor that gave the kept score.
// - Founder couples, then founders, then the remaining bits are processed.
void Viterbidist::step(FloatVec p0, IVVec h, Float tht, Float tht_female) {
  const Float fac = tht/(1.0 - tht);
  // When !options->sexspecific, tht_female is -1 (see gettheta) and this
  // value is meaningless; it is only read in sex-specific branches below.
  const Float fac_female = tht_female/(1.0 - tht_female);
  Family *fam = curfamily();
  // First take care of founder couple bits
  for (Foundercouple *fc = fam->firstfoundercouple; fc != 0; fc = fc->next) {
    assertinternal(!options->sexspecific);
    for (IV k1 = 0; k1 < fam->numiv; k1 += fc->childrenmask + fc->lastcbit())
      for (IV k2 = k1; k2 < k1 + fc->lastcbit();
           k2 += fc->mask + fc->lastgcbit)
        for (IV k3 = k2; k3 < k2 + fc->lastgcbit; k3++) {
          // Copy to temporary buffers
          IV i = 0;
          // First assume there is not a recombination in the founder
          // couple grandchild
          for (IV v = 0; v <= fc->childrenmask; v += fc->lastcbit())
            // Go through the bits of the children of the fc
            for (IV u = v; u <= v + fc->mask; u += fc->lastgcbit) {
              // Go through the bits of the grandchildren of the fc
              tmpfcfloat[i] = p0[k3 + u];
              tmpfciv[i] = h[k3 + u];
              i++;
            }
          IV wm, hm, wb, hb;
          makefcmask(i, wt[fc->wife->mask], wm, hm, wb, hb);
          dofounder(tmpfcfloat, tmpfciv, fac, i, wm, wb);
          dofounder(tmpfcfloat, tmpfciv, fac, i, hm, hb);
          dobits(tmpfcfloat, tmpfciv, fac, 1,
                 (hb == 0 ? POW2[wt[fc->mask] - 1] : hb - 1), i);
          // Then assume there is a recombination in the founder
          // couple grandchild
          calculatefcpi(k3, p0, h, fc);
          dofcfounders(p0, h, fac, fc->wife->mask, fc->wife->lastbit,
                       fc->husband->mask, fc->husband->lastbit,
                       k3, fc->lastgcbit, fc->mask + fc->lastgcbit);
          dofc(p0, h, fac, fc->lastgcbit, fc->mask,
               k3, fc->husband->lastbit, fc->wife->mask + fc->wife->lastbit);
          // Compare results: the recombinant hypothesis pays one extra fac
          i = 0;
          for (IV v = 0; v <= fc->childrenmask; v += fc->lastcbit())
            // Go through the bits of the children of the fc
            for (IV u = v; u <= v + fc->mask; u += fc->lastgcbit) {
              // Go through the bits of the grandchildren of the fc
              p0[k3 + u] *= fac;
              if (tmpfcfloat[i] > p0[k3 + u]) {
                p0[k3 + u] = tmpfcfloat[i];
                h[k3 + u] = tmpfciv[i];
              }
              i++;
            }
        }
  }
  // Next take care of founder bits
  for (Person *p = fam->first; p != fam->firstdescendant; p = p->next)
    if (p->mask != 0 && p->fc == 0) {
      if (!options->sexspecific || p->sex == MALE) 
        dofounder(p0, h, fac, fam->numiv, p->mask, p->lastbit);
      else
        dofounder(p0, h, fac_female, fam->numiv, p->mask, p->lastbit);
    }
  // Finally take care of the rest of the bits
  if (options->sexspecific) {
    // Rules for sex-specific maps:
    // - A child's paternal bit is stepped only from its father's list, with
    //   the male factor.
    // - A child's maternal bit is stepped only from its mother's list, with
    //   the female factor.
    // - A bit is stepped only when present (mask != 0). A zero mask would
    //   make dobits loop forever, since 0 << 1 == 0.
    for (Person *p = fam->firstdescendant; p != 0; p = p->next) {
      for (Plist *c = p->children; c != 0; c = c->next) {
        if (p->sex == MALE && c->p->patmask != 0)
          dobits(p0, h, fac, c->p->patmask, c->p->patmask, fam->numiv);
        else if (p->sex == FEMALE && c->p->matmask != 0)
          dobits(p0, h, fac_female, c->p->matmask, c->p->matmask, fam->numiv);
      }
    }
  }
  else
    dobits(p0, h, fac, 1, fam->mask, fam->numiv);
}

// Counts the founder bits in mask that can be kept non-recombinant between
// inheritance vectors v0 and v1, allowing for a phase flip.
// Outputs:
// - same: the number of mask bits equal in v0 and v1.
// - diff: the number of mask bits that differ.
// Without a flip the `same` bits are non-recombinant. With a flip the
// `diff` bits are, but one recombination is charged for the flip itself,
// hence diff - 1. This mirrors the extra fac in dofounderloop.
// Callers subtract the result from wt[mask].
Uint Viterbidist::countfounderrecombs(IV mask, IV v0, IV v1,
                                      Uint &same, Uint &diff) const {
  const IV agree = ~(v1 ^ v0) & mask;
  same = wt[agree];
  diff = wt[mask] - same;
  const int noflip = int(same);
  const int flip = int(diff) - 1;
  return max_(noflip, flip);
}

// Founder-couple analogue of countfounderrecombs. It counts the
// non-recombinant bits among the wife, husband and grandchild bits of fc
// between v0 and v1.
// - same: the score assuming no recombination in the founder-couple
//   grandchild.
// - diff: the score assuming there is one. Unchanged-vs-changed grandchild
//   bits swap roles, and the wife's bits (shifted down by
//   wt[fc->wife->mask]) are compared with the husband's, and vice versa.
// Returns max(same, diff - 1).
Uint Viterbidist::countfcrecombs(Foundercouple *fc, IV v0, IV v1,
                                 Uint &same, Uint &diff) const {
  Uint gcsame, gcdiff;     // grandchild-bit agreement counts
  Uint ignored1, ignored2; // outputs not needed here
  same = diff = 0;
  // Hypothesis "no recombination in the grandchild".
  same += countfounderrecombs(fc->wife->mask, v0, v1, ignored1, ignored2);
  same += countfounderrecombs(fc->husband->mask, v0, v1, ignored1, ignored2);
  countfounderrecombs(fc->mask, v0, v1, gcsame, gcdiff);
  same += gcsame;
  // Hypothesis "recombination in the grandchild".
  diff += gcdiff;
  const Uint wifebits = wt[fc->wife->mask];
  diff += countfounderrecombs(fc->husband->mask, v0 >> wifebits, v1,
                              ignored1, ignored2);
  diff += countfounderrecombs(fc->husband->mask, v0, v1 >> wifebits,
                              ignored1, ignored2);
  return max_(int(same), int(diff) - 1);
}
                   
// Minimum number of recombinations needed to go from inheritance vector v0
// to v1 in family f. The total is the sum of:
// - founder couples: wt of their combined masks minus countfcrecombs;
// - other founders with children: wt[mask] minus countfounderrecombs.
//   Under options->sexlinked, only female founders are counted.
// - the remaining bits in f->mask: the number that differ.
Uint Viterbidist::countrecombs(Family *f, IV v0, IV v1) const {
  Uint same, diff;
  Uint total = 0;
  for (Foundercouple *fc = f->firstfoundercouple; fc != 0; fc = fc->next) {
    const Uint fcbits = wt[fc->mask | fc->husband->mask | fc->wife->mask];
    total += fcbits - countfcrecombs(fc, v0, v1, same, diff);
  }
  for (Person *p = f->first; p != 0; p = p->next) {
    if (p->children == 0 || p->fc != 0 || !p->founder()) continue;
    if (options->sexlinked && p->sex != FEMALE) continue;
    total += wt[p->mask] - countfounderrecombs(p->mask, v0, v1, same, diff);
  }
  total += wt[f->mask] - wt[~(v1 ^ v0) & f->mask];
  return total;
}
