#include <iostream>
#include <fstream>
#include <iomanip>
#include "files.h"
#include "recombdist.h"
#include "haldane.h"
#include "map.h"
#include "family.h"

// Returns the first entry of `distributions` whose describe() is "recomb"
// (cast to Recombdist*); if none is found, a new Recombdist is created and
// returned. NOTE(review): whether the new object registers itself in
// `distributions` depends on the Distribution base constructor.
Recombdist *Recombdist::getrecombdist() {
  Uint d = 0;
  while (d < distributions.size()) {
    if (distributions[d]->describe() == "recomb")
      return (Recombdist *)distributions[d];
    d++;
  }
  return new Recombdist();
}

// Constructs the distribution with type string "rpt" and allocates three
// 63-entry lookup tables. thf, thp and thf_female point at the middle entry
// of their tables, so they may be indexed with -31..31.
Recombdist::Recombdist() : Distribution("rpt") {
  NEWVEC(Float, thfactor, 63);
  NEWVEC(Float, thpower, 63);
  NEWVEC(Float, thfactor_female, 63);
  // Centre the working pointers on element 31.
  thf = thfactor + 31;
  thp = thpower + 31;
  thf_female = thfactor_female + 31;
}

// Releases all per-family records and the lookup tables allocated in the
// constructor. Previously the three tables (thfactor, thpower,
// thfactor_female) were never freed.
Recombdist::~Recombdist() {
  reset(npos);
  DELETEVEC(thfactor);
  DELETEVEC(thpower);
  DELETEVEC(thfactor_female);
}

// Records the new number of positions and deletes every Familydata object
// held in familydata, newest first, leaving the container empty.
void Recombdist::reset(Uint np) {
  npos = np;
  for (; !familydata.empty(); familydata.pop_back())
    delete familydata.back();
}

// Returns the sum of nbits_male over all stored family records.
Uint Recombdist::bitsum_male() const {
  Uint total = 0;
  for (Uint i = familydata.size(); i > 0; i--)
    total += familydata[i - 1]->nbits_male;
  return total;
}

// Returns the sum of nbits_female over all stored family records.
Uint Recombdist::bitsum_female() const {
  Uint total = 0;
  for (Uint i = familydata.size(); i > 0; i--)
    total += familydata[i - 1]->nbits_female;
  return total;
}

// Returns the sum over families of xover[pos]. The sum is accumulated in
// Double, in family order, and narrowed to Float on return.
Float Recombdist::xoversum(Uint pos) const {
  Double total = 0;
  Uint ifa = 0;
  while (ifa < familydata.size()) {
    total += familydata[ifa]->xover[pos];
    ++ifa;
  }
  return total;
}

// Returns the sum over families of xover_female[pos]. The sum is
// accumulated in Double, in family order, and narrowed to Float on return.
Float Recombdist::xoversum_female(Uint pos) const {
  Double total = 0;
  Uint ifa = 0;
  while (ifa < familydata.size()) {
    total += familydata[ifa]->xover_female[pos];
    ++ifa;
  }
  return total;
}

// Recursive accumulator for the sex-averaged model.
//
// While f is a founder, the founder's bit values v (f->mask, stepped by
// f->lastbit) are enumerated. Each branch multiplies prod by (1 + b), with
// b = thf[s - 2*w], adds the founder term d to count, and recurses on
// f->next with r | v.
//
// At the first non-founder the remaining index bits v are enumerated.
// h = F[r | v] * prod * (theta factors) is added to C, and h * (count + bit
// weight of v) is added to sum. maskJ, masknotJ and nJ describe the
// founder-couple split chosen by fcaddtosum; they are used only when the
// family has founder couples (fam->numfc > 0).
void Recombdist::addtosum(Double& sum, Double& C, Double prod, Double count,
                          IV r, Person *f, IV maskJ, IV masknotJ, Uint nJ) {
  if (f->founder()) {
    Uint s = wt[f->mask] + 1;
    for (IV v = 0; v <= f->mask; v += f->lastbit) {
      Uint w = wt[v & f->mask];
      // s - 2*w is negative whenever w > s/2. Compute it as a signed value so
      // that thf (valid for -31..31) is not indexed with a wrapped unsigned
      // number, which would point far outside the table.
      Double b = thf[static_cast<int>(s) - 2*static_cast<int>(w)];
      Double d = (w + b*(s - w))/(1.0 + b);
      addtosum(sum, C, prod*(1.0 + b), d + count,
               r | v, f->next, maskJ, masknotJ, nJ);
    }
  }
  else {
    Family *fam = curfamily();
    if (fam->numfc > 0) {
      Uint sJ = wt[maskJ] + nJ;
      Uint snotJ = wt[masknotJ] + fam->numfc - nJ;
      for (IV v = 0; v < fam->lastfounderbit; v++) {
        Uint w = wt[r | v];
        Uint wJ = wt[v & maskJ];
        Uint wnotJ = wt[v & masknotJ];
        // sJ - 2*wJ + w >= 0 here (w >= wJ and wJ <= sJ), so the unsigned
        // expression yields the correct non-negative index.
        Double h = F[r | v]*prod*thp[sJ + snotJ]*thf[sJ - 2*wJ + w];
        sum += h*(count + sJ - wJ + wnotJ + wt[v & fam->mask]);
        C += h;
      }
    }
    else {
      for (IV v = 0; v <= fam->mask; v++) {
        Uint w = wt[r | v];
        Double h = F[r | v]*prod*thf[w];
        sum += h*(count + wt[v]);
        C += h;
      }
    }
  }
}

// Sex-linked variant of addtosum.
//
// Only FEMALE founders contribute a phase weight (1 + b) and a count term d.
// For other founders the bit values v are enumerated with prod and count
// unchanged. At the first non-founder, h = F[r | v] * prod * thf[wt[r | v]]
// is added to C and h * (count + wt[v]) to sum, for v = 0..curfamily()->mask.
void Recombdist::addtosum_sexlinked(Double& sum, Double& C, Double prod,
                                    Double count, IV r, Person *f) {
  if (f->founder()) {
    Uint s = wt[f->mask] + 1;
    for (IV v = 0; v <= f->mask; v += f->lastbit) {
      if (f->sex == FEMALE) {
        Uint w = wt[v & f->mask];
        // Signed index: s - 2*w is negative when w > s/2.
        Double b = thf[static_cast<int>(s) - 2*static_cast<int>(w)];
        Double d = (w + b*(s - w))/(1.0 + b);
        addtosum_sexlinked(sum, C, prod*(1.0 + b), d + count, r | v, f->next);
      }
      else
        addtosum_sexlinked(sum, C, prod, count, r | v, f->next);
    }
  }
  else {
    for (IV v = 0; v <= curfamily()->mask; v++) {
      Uint w = wt[r | v];
      Double h = F[r | v]*prod*thf[w];
      sum += h*(count + wt[v]);
      C += h;
    }
  }
}

// Sex-specific variant of addtosum.
//
// Male founders use the thf table and add their count term to count_male.
// Other founders use thf_female and add to count_female. At the first
// non-founder (which requires a family without founder couples), each
// index v contributes h = F[r | v] * prod * thf[male weight] *
// thf_female[female weight] to C. It also contributes h * (count +
// weight of v under the sex mask) to sum_male and to sum_female.
// mask_male and mask_female must already be set by the caller.
void Recombdist::addtosum_sexspecific(Double& sum_male, Double& sum_female,
                                      Double& C, Double prod, Double count_male,
                                      Double count_female, IV r, Person *f) {
  if (f->founder()) {
    Uint s = wt[f->mask] + 1;
    for (IV v = 0; v <= f->mask; v += f->lastbit) {
      Uint w = wt[v & f->mask];
      // Signed index: s - 2*w is negative when w > s/2.
      int idx = static_cast<int>(s) - 2*static_cast<int>(w);
      Double b;
      if (f->sex == MALE)
        b = thf[idx];
      else
        b = thf_female[idx];
      Double d = (w + b*(s - w))/(1.0 + b);
      Double c_m = count_male;
      Double c_f = count_female;
      if (f->sex == MALE)
        c_m += d;
      else
        c_f += d;
      addtosum_sexspecific(sum_male, sum_female, C, prod*(1.0 + b), c_m, c_f,
                           r | v, f->next);
    }
  }
  else {
    Family *fam = curfamily();
    assertinternal(fam->numfc == 0);
    for (IV v = 0; v <= fam->mask; v++) {
      IV u = v | r;
      Double h = F[r | v]*prod*thf[wt[u & mask_male]]*thf_female[wt[u & mask_female]];
      sum_male += h*(count_male + wt[v & mask_male]);
      sum_female += h*(count_female + wt[v & mask_female]);
      C += h;
    }
  }
}

// Enumerates every subset J of the family's founder couples; J[iJ] says
// whether couple iJ is in the subset. Along the way it ORs each couple's
// mask into maskJ (selected) or masknotJ (not selected) and counts the
// selected couples in nJ.
//
// For each complete subset it rebuilds F from rqhat, permutes F for the
// selected couples (calcFpi), multiplies elementwise by lqhat, applies fft
// and normal<Float>, and then adds that subset's contributions to sum and C
// through addtosum. gam is passed through but not used here.
void Recombdist::fcaddtosum(Double& sum, Double& C, Foundercouple *fc, BoolVec J,
                            Uint iJ, IV maskJ, IV masknotJ, Uint nJ, Uint gam) {
  if (fc != 0) {
    // Couple iJ not in the subset.
    J[iJ] = false;
    fcaddtosum(sum, C, fc->next, J, iJ + 1, maskJ,
               masknotJ | fc->mask, nJ, gam);
    // Couple iJ in the subset.
    J[iJ] = true;
    fcaddtosum(sum, C, fc->next, J, iJ + 1, maskJ | fc->mask,
               masknotJ, nJ + 1, gam);
  }
  else {
    Family *fam = curfamily();
    copyvec(F, rqhat, fam->numiv);
    // calcFpi only permutes entries for couples whose J entry is true, so it
    // is already a no-op for the all-false subset. It was previously guarded
    // by `sum > 0`, a proxy for "not the first subset". That guard wrongly
    // skipped the permutation for later subsets whenever the running sum was
    // zero or negative.
    calcFpi(J);
    elemprod(F, lqhat, F, fam->numiv);
    fft(F, F, fam->numiv, numbits);
    normal<Float>(F, fam->numiv);
    addtosum(sum, C, 1.0, 0.0, 0, fam->first, maskJ, masknotJ, nJ);
  }
}

// Recursively builds pairs of F indices (r, pir) and, at the leaves, swaps
// F entries between them through createfjrest. Founder couples are visited
// in list order, and iJ is the index of fc in J.
//  * J[iJ] true: walk the children list c. The recursive call for the next
//    couple passes fc->next->husband->children, so callers presumably start
//    with fc->husband->children. For each child whose patbit is not
//    removed, enumerate the child's (patmask, matmask) bits:
//      - 00 and 11 set the same bits in r and pir;
//      - 01/10 set matbit in r and patbit in pir, i.e. pir is r with the
//        child's paternal and maternal bits exchanged.
//    Children with removepatbit set are passed over with r and pir
//    unchanged.
//  * J[iJ] false: enumerate every v over fc->childrenmask (stepping by
//    fc->lastcbit()) and OR the same v into both r and pir.
// After the last couple, createfjrest(fc->lastcbit(), r, pir) swaps
// F[r | v] with F[pir | v] for v < fc->lastcbit().
void Recombdist::createfj(Plist *c, Foundercouple *fc, BoolVec J,
                          Uint iJ, IV r, IV pir) {
  if (J[iJ]) {
    if (c != 0) {
      if (!c->p->removepatbit) {
        IV patbit = c->p->patmask;
        IV matbit = c->p->matmask;
        // 00
        createfj(c->next, fc, J, iJ, r, pir);
        // 11
        createfj(c->next, fc, J, iJ, r | patbit | matbit,
                    pir | patbit | matbit);
        // 01 and 10
        createfj(c->next, fc, J, iJ, r | matbit, pir | patbit);
        // The mirrored assignment is only enumerated once r and pir already
        // differ on this couple's children bits. NOTE(review): presumably
        // this ensures each unordered pair {r, pir} is swapped once (a
        // second visit would undo the swap); confirm.
        if ((r & fc->childrenmask) != (pir & fc->childrenmask))
          createfj(c->next, fc, J, iJ, r | patbit, pir | matbit);
      }
      else
        createfj(c->next, fc, J, iJ, r, pir);
    }
    else if (fc->next == 0)
      createfjrest(fc->lastcbit(), r, pir);
    else
      createfj(fc->next->husband->children, fc->next, J, iJ + 1, r, pir);
  }
  else {
    // Couple not selected: identical children bits in r and pir.
    if (fc->next == 0)
      for (IV v = 0; v <= fc->childrenmask; v += fc->lastcbit())
        createfjrest(fc->lastcbit(), r | v, pir | v);
    else 
      for (IV v = 0; v <= fc->childrenmask; v += fc->lastcbit())
        createfj(fc->next->husband->children, fc->next, J, iJ + 1,
                 r | v, pir | v);
  }
}

// Swaps F[r | v] with F[pir | v] for every v in [0, lastfcbit). Does
// nothing when r == pir.
void Recombdist::createfjrest(IV lastfcbit, IV r, IV pir) {
  if (r == pir)
    return;
  for (IV low = 0; low < lastfcbit; ++low) {
    IV a = r | low;
    IV b = pir | low;
    Float saved = F[a];
    F[a] = F[b];
    F[b] = saved;
  }
}

// Permutes F in place for every founder couple selected by J. J[iJ] is
// true for a selected couple, with iJ counting couples from
// fam->firstfoundercouple. For a selected couple it enumerates founder-bit
// combinations m = h | w of the husband's bits h and the wife's bits w. For
// each m it swaps F[u | m] with F[u | fc->pi(m)] for every index u produced
// by the two inner loops. Couples with J[iJ] false are left untouched, so
// an all-false J leaves F unchanged.
void Recombdist::calcFpi(BoolVec J) {
  Uint iJ = 0;
  Family *fam = curfamily();
  for (Foundercouple *fc = fam->firstfoundercouple; fc != 0;
       fc = fc->next, iJ++)
    if (J[iJ]) {
      Uint n = wt[fc->wife->mask];
      // First loop over bits of husband and wife.
      // NOTE(review): h starts just above (w >> n), so only part of the
      // (h, w) grid is visited. Presumably this picks one representative of
      // each {m, pi(m)} pair so no pair is swapped twice; confirm against
      // Foundercouple::pi.
      for (IV w = 0; w <= fc->wife->mask; w += fc->wife->lastbit)
        for (IV h = (w >> n) + fc->husband->lastbit;
             h <= fc->husband->mask; h += fc->husband->lastbit) {
          IV m = h | w;
          IV pim = fc->pi(m);
          // Then loop over bits before fc-children bits.
          // v steps by childrenmask + lastcbit(). If childrenmask is a
          // contiguous run of bits starting at lastcbit(), each v has all of
          // this couple's children bits clear.
          for (IV v = 0; v < fam->numiv; v += fc->childrenmask + fc->lastcbit())
            // Finally loop over bits after fc-children bits
            // (u runs over v .. v + lastcbit() - 1).
            for (IV u = v; u < fc->lastcbit() + v; u++) {
              Float tmp = F[u | m];
              F[u | m] = F[u | pim];
              F[u | pim] = tmp;
            }
        }
    }
}

// Dispatches to the six-argument set() of every registered distribution
// whose getpt() equals pt. pt must be "rpt".
void Recombdist::set(FloatVec Fp, FloatVec rqh, FloatVec lqh, Uint gam,
                     Float theta, Float theta_female, const string &pt) {
  assertinternal(pt == "rpt");
  for (Uint d = 0; d < distributions.size(); d++) {
    if (distributions[d]->getpt() != pt)
      continue;
    Recombdist *target = (Recombdist *)distributions[d];
    target->set(Fp, rqh, lqh, gam, theta, theta_female);
  }
}

// Computes the expected recombination count at marker position gam for the
// current family (the last entry of familydata). The result goes into
// famdat->recomb[gam], and also famdat->recomb_female[gam] in the
// sex-specific model.
//
// Fp becomes the working buffer F. rqh and lqh are combined into it. theta
// (and theta_female when sex-specific) determine the thf/thp lookup tables
// used by the addtosum* helpers. If the previous marker handled for this
// family (lastcalcmarker) lies more than one position above gam, fillin()
// fills in the markers in between.
void Recombdist::set(FloatVec Fp, FloatVec rqh, FloatVec lqh,
                     Uint gam, Float theta, Float theta_female) {
  F = Fp;
  rqhat = rqh;
  lqhat = lqh;

  // Lookup tables for exponents i in -31..31:
  //   thf[i] = (theta/(1 - theta))^i,  thp[i] = (1 - theta)^i,
  // and thf_female from theta_female when sex-specific.
  // The loop index is signed so that thf[-i] etc. are real negative offsets;
  // negating an unsigned index wraps to a huge value.
  thf[0] = 1.0;
  thp[0] = 1.0;
  if (options->sexspecific) thf_female[0] = 1.0;
  for (int i = 1; i < 32; i++) {
    if (theta > 0) {
      thf[i] = thf[i - 1]*theta/(1.0 - theta);
      thf[-i] = 1.0/thf[i];
      thp[i] = thp[i - 1]*(1.0 - theta);
      thp[-i] = 1.0/thp[i];
    }
    else {
      // Non-positive theta: zero odds, and 0 rather than 1/0 for negative
      // powers.
      thf[i] = thf[-i] = 0.0;
      thp[i] = thp[-i] = 1.0;
    }
    if (options->sexspecific) {
      if (theta_female > 0) {
        thf_female[i] = thf_female[i - 1]*theta_female/(1.0 - theta_female);
        thf_female[-i] = 1.0/thf_female[i];
      }
      else
        thf_female[i] = thf_female[-i] = 0.0;
    }
  }

  Double sum = 0.0;
  Double C = 0.0;
  Double sum_female = 0.0;
  Family *fam = curfamily();
  if (fam->numfc == 0) {
    elemprod(F, rqhat, lqhat, fam->numiv);
    fft(F, F, fam->numiv, numbits);
    // Negative entries are replaced by a small positive value.
    // NOTE(review): presumably numerical noise from the transform.
    for (IV v = 0; v < fam->numiv; v++)
      if (F[v] < 0) F[v] = -F[v]/1000.0;
    if (options->sexlinked) {
      addtosum_sexlinked(sum, C, 1.0, 0.0, 0, fam->first);
    }
    else if (options->sexspecific) {
      // mask_male collects the paternal bits of fathers' children;
      // mask_female collects the maternal bits of everyone else's children.
      mask_male = mask_female = 0;
      for (Person *p = fam->first; p != 0; p = p->next)
        for (Plist *c = p->children; c != 0; c = c->next) {
          if (p->sex == MALE)
            mask_male |= c->p->patmask;
          else
            mask_female |= c->p->matmask;
        }
      addtosum_sexspecific(sum, sum_female, C, 1.0, 0.0, 0.0, 0, fam->first);
    }
    else
      addtosum(sum, C, 1.0, 0.0, 0, fam->first, 0, 0, 0);
  }
  else {
    // Families with founder couples: sum over all subsets of couples.
    BoolVec J;
    NEWVEC(bool, J, fam->numfc);
    copyval(J, false, fam->numfc);
    fcaddtosum(sum, C, fam->firstfoundercouple, J, 0, 0, 0, 0, gam);
    DELETEVEC(J);
  }

  Familydata *famdat = familydata.back();
  // Expected count is sum/C. It is 0 when the meiosis count relevant to the
  // active model is zero or when the normaliser C vanishes.
  if ((options->sexlinked && famdat->nbits_female > 0) ||
      (options->sexspecific && famdat->nbits_male > 0) ||
      (!options->sexlinked && !options->sexspecific && famdat->nbits > 0))
    famdat->recomb[gam] = (C == 0 ? 0 : sum/C);
  else
    famdat->recomb[gam] = 0;
  assertinternal(famdat->recomb[gam] >= 0);
  if (options->sexspecific) {
    // Same rule as the male count above. The old expression
    // `C == 0 && nbits_female > 0 ? 0 : sum_female/C` divided by C whenever
    // nbits_female == 0, which produced NaN (and a failed assertion) for
    // C == 0.
    famdat->recomb_female[gam] =
      (famdat->nbits_female > 0 && C != 0 ? sum_female/C : 0);
    assertinternal(famdat->recomb_female[gam] >= 0);
  }
  if (lastcalcmarker != -1 && lastcalcmarker - gam > 1) {
    if (options->sexspecific) {
      fillin(map->theta[1], map->markerpos[1], famdat->recomb,
             famdat->nbits_male + famdat->numsc_male, gam, lastcalcmarker);
      fillin(map->theta[2], map->markerpos[2], famdat->recomb_female,
             famdat->nbits_female + famdat->numsc_female, gam, lastcalcmarker);
    }
    else
      fillin(map->theta[0], map->markerpos[0], famdat->recomb, numbits,
             gam, lastcalcmarker);
  }
  lastcalcmarker = gam;
}

// Called when moving to the next family. At position 0 it appends a new
// Familydata(npos) record and fills in its meiosis counts:
//   numsc_male / numsc_female: persons before fam->firstdescendant with
//     exactly one child, split by sex;
//   nbits_male / nbits_female: fam->numnf (presumably the number of
//     non-founders) minus the matching single-child count;
//   nbits: their sum.
// It also caches fam->numbits. lastcalcmarker is reset to -1 on every call.
void Recombdist::nextfam(Uint pos, DoubleVec /*p0*/) {
  if (pos == 0) {
    Familydata *famdat = new Familydata(npos);
    familydata.push_back(famdat);
    Family *fam = curfamily();
    numbits = fam->numbits;

    famdat->numsc_male = famdat->numsc_female = 0;
    for (Person *p = fam->first; p != fam->firstdescendant; p = p->next) {
      bool onlychild = p->children != 0 && p->children->next == 0;
      if (!onlychild)
        continue;
      if (p->sex == MALE)
        famdat->numsc_male++;
      else
        famdat->numsc_female++;
    }
    famdat->numsc = famdat->numsc_male + famdat->numsc_female;

    famdat->nbits_male = fam->numnf;
    famdat->nbits_female = fam->numnf;
    assertinternal(famdat->nbits_male >= famdat->numsc_male &&
                   famdat->nbits_female >= famdat->numsc_female);
    famdat->nbits_male -= famdat->numsc_male;
    famdat->nbits_female -= famdat->numsc_female;
    famdat->nbits = famdat->nbits_male + famdat->nbits_female;
  }
  lastcalcmarker = -1;
}

void Recombdist::fillin(const vector<Float> &theta,
                        const vector<Float> &markerpos,
                        FloatVec rec, Uint n, Uint ml, Uint mr) {
  // Fills in uinformative markers between ml and mr which are both
  // informative.
  for (Uint k = ml + 1; k < mr; k++) {
    Double th1 = theta[k - 1];
    Double th2 = recombfraccent(markerpos[mr] - markerpos[k]);
    Double r = rec[k - 1];
    Double dblr = (n - r)*th1*th2/(th1*th2 + (1 - th1)*(1 - th2));
    Double sden = th1*(1 - th2) + th2*(1 - th1);
    rec[k - 1] = r*th1*(1 - th2)/sden + dblr;
    rec[k] = r*th2*(1 - th1)/sden + dblr;
  }
}

// Converts each family's recombination counts into expected crossover
// counts for intervals k = 0 .. npos-2. The map theta table depends on the
// mode:
//   - sex-linked: theta[0], with the female meiosis counts;
//   - sex-specific: theta[1] (male) and theta[2] (female);
//   - otherwise: theta[0], with the combined counts.
void Recombdist::calcxovers() {
  for (Uint ifa = 0; ifa < familydata.size(); ifa++) {
    Familydata *famdat = familydata[ifa];
    // Written as k + 1 < npos so that npos == 0 cannot wrap npos - 1 to the
    // largest Uint and run past the end of the arrays.
    for (Uint k = 0; k + 1 < npos; k++) {
      if (options->sexlinked)
        famdat->xover[k] = recombtoxovers(map->theta[0][k], famdat->recomb[k],
                                          famdat->nbits_female,
                                          famdat->numsc_female);
      else if (options->sexspecific) {
        famdat->xover[k] = recombtoxovers(map->theta[1][k], famdat->recomb[k],
                                          famdat->nbits_male,
                                          famdat->numsc_male);
        famdat->xover_female[k] = recombtoxovers(map->theta[2][k],
                                                 famdat->recomb_female[k],
                                                 famdat->nbits_female,
                                                 famdat->numsc_female);
      }
      else
        famdat->xover[k] = recombtoxovers(map->theta[0][k], famdat->recomb[k],
                                          famdat->nbits, famdat->numsc);
    }
  }
}

// Converts an expected recombination count `rec` over an interval with
// recombination fraction theta into an expected crossover count.
// nb is the meiosis count and nsc the single-child count. With
// d = |centimorgan(theta)/100|:
//   rec < 0 : centimorgan(theta)*nb/100
//   rec > 0 : rec*d/tanh(d) + (nb + nsc - rec)*d*tanh(d) - d*nsc
//   else    : -d*nsc
// For a zero-length interval (tanh(d) == 0), d/tanh(d) is replaced by its
// limit 1. Previously this case evaluated 0/0 and returned NaN.
Float Recombdist::recombtoxovers(Float theta, Float rec, Uint nb, Uint nsc) {
  if (rec < 0)
    return centimorgan(theta)*nb/100.0;
  Double d = fabs(centimorgan(theta)/100);
  Double expd = exp(d);
  Double expmd = 1.0/expd;
  Double tanhd = (expd - expmd)/(expd + expmd);
  Uint baseline = nb + nsc;
  Float Ex = 0;
  if (rec > 0) {
    Double recterm = (tanhd > 0 ? rec*d/tanhd : rec);
    Ex = recterm + (baseline - rec)*d*tanhd;
  }
  return Ex - d*nsc;
}

// Deletes and removes the most recently added family record, if there is
// one.
void Recombdist::skipfam() {
  if (familydata.empty())
    return;
  delete familydata.back();
  familydata.pop_back();
}
