#include <cmath>
#include <iostream>
#include <fstream>
#include <iomanip>
#include "files.h"
#include "recomb.h"
#include "matrix.h"
#include "options.h"
#include "haldane.h"
#include "fmtout.h"

/// Sets up the crossover-rate output ("xover.dat") and allocates the three
/// factor tables. Each table has 63 entries, and the working pointers
/// (thf, thp, thf_female) point at the middle entry. calc() can then index
/// them with offsets from -31 to 31.
Recomb::Recomb(const string &of, const string &fof) :
    Output("spt"),
    lqhat(0), F(0) {
  setfiles(of, fof, "xover.dat", "crossoverrate output");
  const int tablesize = 63;   // entries per factor table
  const int centre = 31;      // index of the zero-offset entry
  NEWVEC(Float, thfactor, tablesize);
  NEWVEC(Float, thpower, tablesize);
  thf = thfactor + centre;
  thp = thpower + centre;
  NEWVEC(Float, thfactor_female, tablesize);
  thf_female = thfactor_female + centre;
}

/// Releases the per-family data and the three factor tables that the
/// constructor allocates with NEWVEC. thf, thp and thf_female point into
/// those tables and are therefore not freed separately.
Recomb::~Recomb() {
  reset();
  DELETEVEC(thfactor);
  DELETEVEC(thpower);
  DELETEVEC(thfactor_female);
}

void Recomb::reset() {
  while (!familydata.empty()) {
    delete familydata.back();
    familydata.pop_back();
  }
}

// Recursive accumulator over assignments of the founders' bits.
//
// For every leaf configuration it adds h*(count + extra) to `sum` and h to
// `C`. The weight h is F[r | v] times `prod` times factors taken from the
// thf/thp tables.
//   r      bits fixed by the founders visited so far
//   prod   product of the (1 + b) factors picked up from those founders
//   count  sum of their d contributions
// maskJ, masknotJ and nJ are read only when fam->numfc > 0. fcaddtosum
// builds them as follows:
//   maskJ     union of fc->mask over founder couples with J[iJ] true
//   masknotJ  the same union for couples with J[iJ] false
//   nJ        number of couples with J[iJ] true
// NOTE(review): wt[] presumably gives the number of set bits of its index;
// confirm.
void Recomb::addtosum(Double& sum, Double& C, Double prod, Double count, IV r,
                      Person *f, IV maskJ, IV masknotJ, Uint nJ) {
  Double b;
  Uint w;
  if (f->founder()) {
    // Enumerate each value v of this founder's bits (steps of f->lastbit up
    // to f->mask), then recurse to the next person.
    Uint s = wt[f->mask] + 1;
    for (IV v = 0; v <= f->mask; v += f->lastbit) {
      w = wt[v & f->mask];
      b = thf[s - 2*w];
      // d is the average of w and s - w, weighted 1 and b respectively.
      Double d = (w + b*(s - w))/(1.0 + b);
      addtosum(sum, C, prod*(1.0 + b), d + count,
               r | v, f->next, maskJ, masknotJ, nJ);
    }
  }
  else {
    // First non-founder reached: sum over the remaining bit values directly.
    Double h;
    if (fam->numfc > 0) {
      // Founder-couple case: the exponents also depend on which couple
      // children bits fall in maskJ vs masknotJ.
      Uint sJ = wt[maskJ] + nJ;
      Uint snotJ = wt[masknotJ] + fam->numfc - nJ;
      for (IV v = 0; v < fam->lastfounderbit; v++) {
        w = wt[r | v];
        Uint wJ = wt[v & maskJ];
        Uint wnotJ = wt[v & masknotJ];
        h = F[r | v]*prod*thp[sJ + snotJ]*thf[sJ - 2*wJ + w];
        sum += h*(count + sJ - wJ + wnotJ + wt[v & fam->mask]);
        C += h;
      }
    }
    else {
      for (IV v = 0; v <= fam->mask; v++) {
        w = wt[r | v];
        h = F[r | v]*prod*thf[w];
        sum += h*(count + wt[v]);
        C += h;
      }
    }
  }
}

// Sex-linked variant of addtosum (no founder-couple handling).
// Only FEMALE founders contribute a (1 + b) factor to `prod` and a d term
// to `count`. A non-female founder's bits are OR-ed into r with `prod` and
// `count` passed through unchanged.
// NOTE(review): calc() currently asserts false for options->sexlinked
// instead of calling this function.
void Recomb::addtosum_sexlinked(Double& sum, Double& C, Double prod,
                                Double count, IV r, Person *f) {
  Uint w;
  if (f->founder()) {
    Double b;
    Uint s = wt[f->mask] + 1;
    for (IV v = 0; v <= f->mask; v += f->lastbit) {
      if (f->sex == FEMALE) {
        w = wt[v & f->mask];
        b = thf[s - 2*w];
        // Weighted average of w and s - w (weights 1 and b).
        Double d = (w + b*(s - w))/(1.0 + b);
        addtosum_sexlinked(sum, C, prod*(1.0 + b), d + count, r | v, f->next);
      }
      else
        addtosum_sexlinked(sum, C, prod, count, r | v, f->next);
    }
  }
  else {
    // First non-founder: sum over the remaining bit values directly.
    for (IV v = 0; v <= fam->mask; v++) {
      w = wt[r | v];
      Double h = F[r | v]*prod*thf[w];
      sum += h*(count + wt[v]);
      C += h;
    }
  }
}

// Sex-specific variant of addtosum. Male and female contributions are
// accumulated separately into sum_male and sum_female, with one shared
// normaliser C.
// Male founders use the thf table; all other founders use thf_female.
// In the leaf sum, the bits selected by mask_male and mask_female are
// weighted by thf and thf_female respectively. calc() sets both masks
// before calling this.
// Founder couples are not supported here (asserts fam->numfc == 0).
void Recomb::addtosum_sexspecific(Double& sum_male, Double& sum_female,
                                  Double& C, Double prod, Double count_male,
                                  Double count_female, IV r, Person *f) {
  Uint w;
  if (f->founder()) {
    Double b;
    Uint s = wt[f->mask] + 1;
    for (IV v = 0; v <= f->mask; v += f->lastbit) {
      w = wt[v & f->mask];
      if (f->sex == MALE)
        b = thf[s - 2*w];
      else
        b = thf_female[s - 2*w];
      // Weighted average of w and s - w (weights 1 and b); credited to the
      // count of this founder's sex only.
      Double d = (w + b*(s - w))/(1.0 + b);
      Double c_m = count_male;
      Double c_f = count_female;
      if (f->sex == MALE)
        c_m += d;
      else
        c_f += d;
      addtosum_sexspecific(sum_male, sum_female, C, prod*(1.0 + b), c_m, c_f,
                           r | v, f->next);
    }
  }
  else {
    Double h;
    assertinternal(fam->numfc == 0);
    for (IV v = 0; v <= fam->mask; v++) {
      IV u = v | r;
      h = F[r | v]*prod*thf[wt[u & mask_male]]*thf_female[wt[u & mask_female]];
      sum_male += h*(count_male + wt[v & mask_male]);
      sum_female += h*(count_female + wt[v & mask_female]);
      C += h;
    }
  }
}

// Enumerates every subset J of the founder couples, starting from fc; the
// J[iJ] == false branch is taken first at each level.
// For each couple, its fc->mask is added to masknotJ (J false) or to maskJ
// (J true, with nJ incremented).
// At each leaf (fc == 0) the function:
//   1. copies rqhat into F,
//   2. permutes F via calcFpi(J),
//   3. multiplies F element-wise by row `gam` of lqhat,
//   4. applies fft and normal<Float>,
//   5. accumulates into sum/C through addtosum.
// J is modified in place and, on return, its entries from iJ onward are
// left true.
// NOTE(review): calcFpi is guarded by `sum > 0` rather than by the
// contents of J. The first leaf has all J false, where calcFpi is a no-op
// anyway. However, any later leaf reached while sum is still 0 also skips
// the permutation — confirm this is intended.
void Recomb::fcaddtosum(Double& sum, Double& C, Foundercouple *fc, BoolVec J,
                        Uint iJ, IV maskJ, IV masknotJ, Uint nJ, Uint gam) {
  if (fc != 0) {
    J[iJ] = false;
    fcaddtosum(sum, C, fc->next, J, iJ + 1, maskJ,
               masknotJ | fc->mask, nJ, gam);
    J[iJ] = true;
    fcaddtosum(sum, C, fc->next, J, iJ + 1, maskJ | fc->mask,
               masknotJ, nJ + 1, gam);
  }
  else {
    copyvec(F, rqhat, fam->numiv);
    if (sum > 0) calcFpi(J);
    elemprod(F, lqhat->getrow(gam, true), F, fam->numiv);
    fft(F, F, fam->numiv, numbits);
    normal<Float>(F, fam->numiv);
    addtosum(sum, C, 1.0, 0.0, 0, fam->first, maskJ, masknotJ, nJ);
  }
}

// Recursively builds pairs of index bases (r, pir). At each leaf it calls
// createfjrest, which swaps the F entries at r | v and pir | v.
//
// Couple with J[iJ] true: each child c without c->p->removepatbit
// contributes its patmask/matmask bits.
//   - The 00 and 11 combinations go into r and pir identically.
//   - The mixed combinations go in crosswise: matbit in r with patbit in
//     pir, and vice versa.
//   - The second crosswise case is taken only when r and pir already
//     differ on fc->childrenmask (presumably so each swap pair is visited
//     once — TODO confirm).
// Couple with J[iJ] false: every value of its children bits is OR-ed
// unchanged into both r and pir.
//
// NOTE(review): not called from the code visible here.
void Recomb::createfj(Plist *c, Foundercouple *fc, BoolVec J,
                      Uint iJ, IV r, IV pir) {
  if (J[iJ]) {
    if (c != 0) {
      if (!c->p->removepatbit) {
        IV patbit = c->p->patmask;
        IV matbit = c->p->matmask;
        // 00
        createfj(c->next, fc, J, iJ, r, pir);
        // 11
        createfj(c->next, fc, J, iJ, r | patbit | matbit,
                    pir | patbit | matbit);
        // 01 and 10
        createfj(c->next, fc, J, iJ, r | matbit, pir | patbit);
        if ((r & fc->childrenmask) != (pir & fc->childrenmask))
          createfj(c->next, fc, J, iJ, r | patbit, pir | matbit);
      }
      else
        createfj(c->next, fc, J, iJ, r, pir);
    }
    else if (fc->next == 0)
      createfjrest(fc->lastcbit(), r, pir);
    else
      createfj(fc->next->husband->children, fc->next, J, iJ + 1, r, pir);
  }
  else {
    // Couple not in J: its children bits are identical in r and pir.
    if (fc->next == 0)
      for (IV v = 0; v <= fc->childrenmask; v += fc->lastcbit())
        createfjrest(fc->lastcbit(), r | v, pir | v);
    else 
      for (IV v = 0; v <= fc->childrenmask; v += fc->lastcbit())
        createfj(fc->next->husband->children, fc->next, J, iJ + 1,
                 r | v, pir | v);
  }
}

/// Swaps F[r | low] with F[pir | low] for every low in [0, lastfcbit).
/// Does nothing when r and pir are equal.
void Recomb::createfjrest(IV lastfcbit, IV r, IV pir) {
  if (r == pir) return;
  for (IV low = 0; low < lastfcbit; ++low) {
    const Float saved = F[pir | low];
    F[pir | low] = F[r | low];
    F[r | low] = saved;
  }
}

// Permutes F in place for every founder couple fc with J[iJ] true.
// Husband/wife bit combinations m are paired with fc->pi(m), and the
// entries F[u | m] and F[u | pim] are swapped for every u outside the
// couple's children bits.
// Starting the husband loop at (w >> n) + husband->lastbit presumably
// visits each (m, pi(m)) pair once, so a swap is not undone; confirm
// against Foundercouple::pi.
void Recomb::calcFpi(BoolVec J) {
  Uint iJ = 0;
  for (Foundercouple *fc = fam->firstfoundercouple; fc != 0;
       fc = fc->next, iJ++)
    if (J[iJ]) {
      Uint n = wt[fc->wife->mask];
      // First loop over bits of husband and wife
      for (IV w = 0; w <= fc->wife->mask; w += fc->wife->lastbit)
        for (IV h = (w >> n) + fc->husband->lastbit;
             h <= fc->husband->mask; h += fc->husband->lastbit) {
          IV m = h | w;
          IV pim = fc->pi(m);
          // Then loop over bits before fc-children bits
          // (v advances in steps of childrenmask + lastcbit, i.e. it
          // enumerates the values above the couple's children bits)
          for (IV v = 0; v < fam->numiv; v += fc->childrenmask + fc->lastcbit())
            // Finally loop over bits after fc-children bits
            // (u adds every value below lastcbit)
            for (IV u = v; u < fc->lastcbit() + v; u++) {
              Float tmp = F[u | m];
              F[u | m] = F[u | pim];
              F[u | pim] = tmp;
            }
        }
    }
}

/// Computes the recombination estimate at marker interval `gam` for the
/// current family and stores it in familydata.back()->recomb[gam] (and
/// recomb_female[gam] in sex-specific mode).
///
/// @param Fp            work vector; becomes F
/// @param rqh           vector combined with row `gam` of lqhat
/// @param gam           marker/interval index
/// @param theta         recombination fraction (male/sex-averaged)
/// @param theta_female  female recombination fraction (sex-specific only)
///
/// If markers between gam and the previously computed marker were skipped,
/// their values are filled in with fillin().
void Recomb::calc(FloatVec Fp, FloatVec rqh, Uint gam, Float theta,
                  Float theta_female) {
  F = Fp;
  rqhat = rqh;
  // Power tables indexed -31..31:
  //   thf[i] = (theta/(1 - theta))^i   (thf[-i] = 1/thf[i])
  //   thp[i] = (1 - theta)^i           (thp[-i] = 1/thp[i])
  //   thf_female likewise, using theta_female.
  // For theta == 0: thf[+/-i] = 0 and thp[+/-i] = 1 when i > 0.
  thf[0] = 1.0;
  thp[0] = 1.0;
  if (options->sexspecific) thf_female[0] = 1.0;
  for (Uint i = 1; i < 32; i++) {
    if (theta > 0) {
      thf[i] = thf[i - 1]*theta/(1.0 - theta);
      thf[-i] = 1.0/thf[i];
      thp[i] = thp[i - 1]*(1.0 - theta);
      thp[-i] = 1.0/thp[i];
    }
    else {
      thf[i] = thf[-i] = 0.0;
      thp[i] = thp[-i] = 1.0;
    }
    if (options->sexspecific) {
      if (theta_female > 0) {
        thf_female[i] = thf_female[i - 1]*theta_female/(1.0 - theta_female);
        thf_female[-i] = 1.0/thf_female[i];
      }
      else
        thf_female[i] = thf_female[-i] = 0.0;
    }
  }
  Double sum = 0.0;
  Double C = 0.0;
  Double sum_female = 0.0;
  if (fam->numfc == 0) {
    elemprod(F, rqhat, lqhat->getrow(gam, true), fam->numiv);
    fft(F, F, fam->numiv, numbits);
    // Negative entries are replaced by |F[v]|/1000.
    for (IV v = 0; v < fam->numiv; v++)
      if (F[v] < 0) F[v] = -F[v]/1000.0;
    if (options->sexlinked) {
      // Not supported on this path. Former call:
      //   addtosum_sexlinked(sum, C, 1.0, 0.0, 0, fam->first, FEMALE, thf, 2);
      assertinternal(false);
    }
    else if (options->sexspecific) {
      // mask_male collects the patmask bits of the children of male
      // persons; mask_female collects the matmask bits of the children of
      // everyone else.
      mask_male = mask_female = 0;
      for (Person *p = fam->first; p != 0; p = p->next)
        for (Plist *c = p->children; c != 0; c = c->next)
          if (p->sex == MALE)
            mask_male |= c->p->patmask;
          else
            mask_female |= c->p->matmask;

      addtosum_sexspecific(sum, sum_female, C, 1.0, 0.0, 0.0, 0, fam->first);
    }
    else
      addtosum(sum, C, 1.0, 0.0, 0, fam->first, 0, 0, 0);
  }
  else {
    BoolVec J;
    NEWVEC(bool, J, fam->numfc);
    copyval(J, false, fam->numfc);
    fcaddtosum(sum, C, fam->firstfoundercouple, J, 0, 0, 0, 0, gam);
    DELETEVEC(J);
  }
  Familydata *famdat = familydata.back();
  // Parenthesised explicitly: && binds tighter than ||.
  const bool hasbits =
    (options->sexlinked && famdat->nbits_female > 0) ||
    (options->sexspecific && famdat->nbits_male > 0) ||
    (!options->sexlinked && !options->sexspecific && famdat->nbits > 0);
  famdat->recomb[gam] = (hasbits && C != 0 ? sum/C : 0);
  assertinternal(famdat->recomb[gam] >= 0);
  if (options->sexspecific) {
    // Same rule as above. It must be zero when either nbits_female == 0
    // or C == 0; the old test divided 0/0 (NaN) when both held.
    famdat->recomb_female[gam] =
      (famdat->nbits_female > 0 && C != 0 ? sum_female/C : 0);
    assertinternal(famdat->recomb_female[gam] >= 0);
  }
  if (lastcalcmarker != -1 && lastcalcmarker - gam > 1) {
    if (options->sexspecific) {
      fillin(map->theta[1], map->markerpos[1], famdat->recomb,
             famdat->nbits_male + famdat->numsc_male, gam, lastcalcmarker);
      fillin(map->theta[2], map->markerpos[2], famdat->recomb_female,
             famdat->nbits_female + famdat->numsc_female, gam, lastcalcmarker);
    }
    else
      fillin(map->theta[0], map->markerpos[0], famdat->recomb, numbits,
             gam, lastcalcmarker);
  }
  lastcalcmarker = gam;
}

/// Finishes the current family by converting its stored recombination
/// values into crossover counts (see calcxovers).
void Recomb::finishfam() {
  // lqhat must have been set before any family is processed.
  assertinternal(lqhat != 0);
  calcxovers();
}

/// Starts a new family f (f may be 0).
/// For non-null f it:
///   - appends a Familydata record,
///   - counts, per sex, the founders with exactly one child,
///   - derives the per-sex meiosis counts.
/// Always resets lastcalcmarker to -1.
void Recomb::nextfam(Family *f) {
  assertinternal(lqhat != 0);
  fam = f;
  if (fam != 0) {
    Familydata *famdat = new Familydata(f, npos());
    familydata.push_back(famdat);
    numbits = fam->numbits;

    // Tally, per sex, the founders (persons before firstdescendant) that
    // have exactly one child.
    famdat->numsc_male = 0;
    famdat->numsc_female = 0;
    Person *p = fam->first;
    while (p != fam->firstdescendant) {
      const bool singlechild = p->children != 0 && p->children->next == 0;
      if (singlechild) {
        if (p->sex == MALE) famdat->numsc_male++;
        else famdat->numsc_female++;
      }
      p = p->next;
    }
    famdat->numsc = famdat->numsc_male + famdat->numsc_female;

    // Both meiosis counts start at the number of non-founders (fam->numnf)
    // and are then reduced by the single-child founders of that sex.
    famdat->nbits_male = fam->numnf;
    famdat->nbits_female = fam->numnf;
    assertinternal(famdat->nbits_male >= famdat->numsc_male &&
                   famdat->nbits_female >= famdat->numsc_female);
    famdat->nbits_male -= famdat->numsc_male;
    famdat->nbits_female -= famdat->numsc_female;
    famdat->nbits = famdat->nbits_male + famdat->nbits_female;
  }
  lastcalcmarker = -1;
}

/// Fills in the uninformative markers strictly between ml and mr, both of
/// which are informative. The count at rec[ml] is split successively into
/// rec[k-1] (interval k-1..k) and rec[k] (interval k..mr). The split uses
/// th1 = theta[k-1] and th2, the recombination fraction from marker k to mr.
///
/// @param theta      per-interval recombination fractions
/// @param markerpos  marker positions passed to recombfraccent
/// @param rec        recombination counts, updated in place
/// @param n          total count used for the double-recombinant term
/// @param ml, mr     bracketing informative markers (ml < mr)
void Recomb::fillin(const vector<Float> &theta, const vector<Float> &markerpos,
                    FloatVec rec, Uint n, Uint ml, Uint mr) {
  for (Uint k = ml + 1; k < mr; k++) {
    Double th1 = theta[k - 1];
    Double th2 = recombfraccent(markerpos[mr] - markerpos[k]);
    Double r = rec[k - 1];
    Double dblr = (n - r)*th1*th2/(th1*th2 + (1 - th1)*(1 - th2));
    Double sden = th1*(1 - th2) + th2*(1 - th1);
    if (sden > 0) {
      rec[k - 1] = r*th1*(1 - th2)/sden + dblr;
      rec[k] = r*th2*(1 - th1)/sden + dblr;
    }
    else {
      // Both sub-intervals have zero length (co-located markers), so the
      // split would be 0/0 = NaN. Keep the count with the remaining
      // interval, which is the th1 -> 0 limit of the formula above.
      rec[k - 1] = dblr;
      rec[k] = r + dblr;
    }
  }
}

/// Converts the current family's per-interval recombination values into
/// crossover counts via recombtoxovers. The map and meiosis counts used
/// depend on the mode:
///   - sex-linked:    theta[0] with the female counts
///   - sex-specific:  theta[1] with the male counts and theta[2] with the
///                    female counts
///   - otherwise:     theta[0] with the combined counts
void Recomb::calcxovers() {
  if (familydata.empty()) return;
  Familydata *famdat = familydata.back();   // loop-invariant
  // k + 1 < npos() avoids unsigned wrap-around when npos() is 0.
  for (Uint k = 0; k + 1 < npos(); k++) {
    if (options->sexlinked)
      famdat->xover[k] = recombtoxovers(map->theta[0][k], famdat->recomb[k],
                                        famdat->nbits_female,
                                        famdat->numsc_female);
    else if (options->sexspecific) {
      famdat->xover[k] = recombtoxovers(map->theta[1][k], famdat->recomb[k],
                                        famdat->nbits_male, famdat->numsc_male);
      famdat->xover_female[k] = recombtoxovers(map->theta[2][k],
                                               famdat->recomb_female[k],
                                               famdat->nbits_female,
                                               famdat->numsc_female);
    }
    else
      famdat->xover[k] = recombtoxovers(map->theta[0][k], famdat->recomb[k],
                                        famdat->nbits, famdat->numsc);
  }
}

/// Converts a recombination count into a crossover count for one interval.
///
/// @param theta  recombination fraction of the interval; d is its map
///               length, |centimorgan(theta)| / 100
/// @param rec    recombination count. When negative, the map-based value
///               centimorgan(theta)*nb/100 is returned instead.
/// @param nb     meiosis count
/// @param nsc    single-child founder count; adds to the baseline and is
///               subtracted as d*nsc
/// @return rec*d/tanh(d) + (nb + nsc - rec)*d*tanh(d) - d*nsc for rec > 0,
///         or -d*nsc for rec == 0.
Float Recomb::recombtoxovers(Float theta, Float rec, Uint nb, Uint nsc) {
  if (rec < 0)
    return centimorgan(theta)*nb/100.0;
  Double d = fabs(centimorgan(theta)/100);
  Float Ex = 0;
  if (rec > 0) {
    if (d > 0) {
      Double tanhd = tanh(d);
      Uint baseline = nb + nsc;
      Ex = rec*d/tanhd + (baseline - rec)*d*tanhd;
    }
    else {
      // Zero-length interval: d/tanh(d) -> 1 and d*tanh(d) -> 0 as d -> 0.
      // The general formula would evaluate 0/0 here.
      Ex = rec;
    }
  }
  return Ex - d*nsc;
}

/// Writes one row per reported interval (those whose left marker has
/// map->shouldfindp set), then a "Total:" row.
/// With famflag set, rows describe family ifa (via famline). Otherwise they
/// describe all families combined (via totline).
void Recomb::lines(ostream &f, Uint ifa, bool famflag) {
  // Reset every running total that famline()/totline() accumulate. This
  // includes the female totals used in sex-specific mode; previously they
  // carried over from earlier calls.
  totxover = totdistance = 0.0;
  totxover_female = totdistance_female = 0.0;
  bitsum_male = 0;
  bitsum_female = 0;
  if (!famflag) for (Uint i = 0; i < familydata.size(); i++) {
    bitsum_male += familydata[i]->nbits_male;
    bitsum_female += familydata[i]->nbits_female;
  }
  bitsum = bitsum_male + bitsum_female;

  // pos + 1 < npos() avoids unsigned wrap-around when npos() is 0.
  for (Uint pos = 0; pos + 1 < npos(); pos++) {
    if (map->shouldfindp[map->leftmarker[pos]]) {
      if (famflag) printfamid(f, getfamid(ifa));
      f.setf(ios::right, ios::adjustfield);
      fmtout(f, 7, 3, position(pos), 0);
      if (famflag) famline(f, ifa, pos);
      else totline(f, pos);
      f << markername(pos) << "\n";
    }
  }
  if (famflag) printfamid(f, getfamid(ifa));
  f.setf(ios::right, ios::adjustfield);
  Uint nb = famflag ? familydata[ifa]->nbits : bitsum;
  f << "Total: ";
  if (options->sexspecific) {
    fmtout(f, 9, 3, totdistance);
    fmtout(f, 6 + options->xoverrateprecision, options->xoverrateprecision,
           (bitsum_male > 0 ? totxover/bitsum_male*100.0 : 0));
    f << setw(7) << bitsum_male;
    fmtout(f, 9, 3, totdistance_female);
    fmtout(f, 6 + options->xoverrateprecision, options->xoverrateprecision,
           (bitsum_female > 0 ? totxover_female/bitsum_female*100.0 : 0));
    f << setw(7) << bitsum_female << "  ";
  }
  else {
    fmtout(f, 9, 3, totdistance);
    fmtout(f, 9, 3, nb > 0 ? totxover/nb*100.0 : 0);
    fmtout(f, 9, 2, totxover);
    f << setw(7) << nb;
  }
  f << "  " << markername(map->num - 1) << "\n";
}

/// Writes the all-families columns for interval `pos` and adds them to the
/// running totals (totxover, totdistance and, in sex-specific mode, their
/// female counterparts).
/// Columns: distance, crossover rate, crossover count (non-sex-specific
/// only), meiosis count.
void Recomb::totline(ostream &f, Uint pos) {
  Double xoversum = 0.0;
  Double xoversum_female = 0.0;
  for (Uint i = 0; i < familydata.size(); i++)
    xoversum += familydata[i]->xover[pos];
  totxover += xoversum;
  if (options->sexspecific) {
    for (Uint i = 0; i < familydata.size(); i++)
      xoversum_female += familydata[i]->xover_female[pos];
    totxover_female += xoversum_female;
    totdistance += centimorgan(map->theta[1][pos]);
    totdistance_female += centimorgan(map->theta[2][pos]);
    fmtout(f, 9, 3, centimorgan(map->theta[1][pos]));
    fmtout(f, 6 + options->xoverrateprecision, options->xoverrateprecision,
           (bitsum_male > 0 ? xoversum/bitsum_male*100.0 : 0));
    f << setw(7) << bitsum_male;
    fmtout(f, 9, 3, centimorgan(map->theta[2][pos]));
    fmtout(f, 6 + options->xoverrateprecision, options->xoverrateprecision,
           (bitsum_female > 0 ? xoversum_female/bitsum_female*100.0 : 0));
    f << setw(7) << bitsum_female << "  ";
  }
  else {
    totdistance += centimorgan(map->theta[0][pos]);
    fmtout(f, 9, 3, centimorgan(map->theta[0][pos]));
    // Guarded like the sex-specific branch: no meioses means rate 0, not NaN.
    fmtout(f, 6 + options->xoverrateprecision, options->xoverrateprecision,
           (bitsum > 0 ? xoversum/bitsum*100.0 : 0));
    fmtout(f, 9, 2, xoversum);
    f << setw(7) << bitsum << "  ";
  }
}

/// Writes the columns for family `ifa` at interval `pos`: distance,
/// crossover rate, crossover count and meiosis count. Also adds to the
/// running totals totxover and totdistance.
void Recomb::famline(ostream &f, Uint ifa, Uint pos) {
  Familydata *famdat = familydata[ifa];
  totxover += famdat->xover[pos];
  totdistance += centimorgan(map->theta[0][pos]);
  f.setf(ios::right, ios::adjustfield);
  fmtout(f, 9, 3, centimorgan(map->theta[0][pos]));
  // The rate is per meiosis, using the same nbits that is printed in the
  // last column and used by the family's Total line (it was nbits_male).
  fmtout(f, 6 + options->xoverrateprecision, options->xoverrateprecision,
         (famdat->nbits > 0 ? famdat->xover[pos]/famdat->nbits*100.0 : 0));
  fmtout(f, 9, 2, famdat->xover[pos]);
  f << setw(7) << famdat->nbits << "  ";
}

/// Writes the column captions used by totline().
void Recomb::totheader(ostream &f) {
  static const char caption[] = " distance xoverrate  nxover   nmei ";
  f << caption;
}

/// Writes the column captions used by famline().
void Recomb::famheader(ostream &f) {
  static const char caption[] = " distance xoverrate nxover  nmei ";
  f << caption;
}

void Recomb::skipcurfam() {
  if (!familydata.empty()) {
    delete familydata.back();
    familydata.pop_back();
  }
}

void Recomb::print() const {
  message("CROSSOVERRATE " + outfile.name + " " + foutfile.name);
}
