#include "probability.h"
#include "files.h"
#include "distribution.h"
#include "options.h"
#include <iostream>
#include <fstream>
#include <iomanip>
#include "matrix.h"
#include "map.h"
#include "recombdist.h"
#include "haldane.h"
#include "fmtout.h"
#include "family.h"

void Probability::check(const string &s, int gam, FloatVec d, Float /*tht*/,
                        Float *L) {
  static int checkcount = 10;
  checkcount++;
  if (options->checkcondition > 0 &&
      (checkcount % options->checkcondition) == 0) {
    Float sm = 1/normal<Float>(d, numiv);
    if (s == "lq") {
      l_nc[gam] *= sm;
    }
    if (L != 0) *L += -log(sm);
  }
//  for (IV v = 0; v < numiv; v++)
//    if (d[v] < 0) d[v] = 0;
//    if (tht < .0004) {
//      for (IV v = 0; v < numiv; v++)
//        if (d[v] < 0) d[v] = -d[v]/10000.0;
//    }
#ifdef DEBUG
  for (unsigned int v = 0; v < numiv; v++)
    assertinternal(finite(d[v]) && !isnan(d[v]));
  
  Float sum = 0.0;
  for (Uint i = 0; i < numiv; i++) sum += d[i];
  clog.setf(ios::scientific);
  clog << "check-" << s << "[" << gam << "]: " << sum << endl;
  clog.flush();
#endif // DEBUG
}

// Recursive function to calculate pi(x) for a single founder couple
// Transition step used by Ttrans for families that have founder couples.
//
// This is an iterative pass, not a recursion. For each founder couple in the
// family's list, it walks the wife's bit patterns (stepping by
// fc->wife->lastbit) and, for each, the husband's bit patterns starting at
// (wife bits >> shift). Here shift = wt[fc->wife->mask], presumably the number
// of wife bits — TODO confirm. Each combined index a is handed to fcinnerprod
// together with its partner fc->pi(a). The starting point of the husband loop
// presumably ensures each (a, pi(a)) pair is visited only once.
void Probability::calculatefctrans(FloatVec D, FloatVec S, Float tht) {
  Foundercouple *fc = fam->firstfoundercouple;
  while (fc != 0) {
    const Uint shift = wt[fc->wife->mask];
    for (IV wbits = 0; wbits <= fc->wife->mask; wbits += fc->wife->lastbit) {
      for (IV hbits = (wbits >> shift); hbits <= fc->husband->mask;
           hbits += fc->husband->lastbit) {
        const IV combined = hbits | wbits;
        fcinnerprod(D, S, tht, combined, fc->pi(combined), fc);
      }
    }
    fc = fc->next;
  }
}

// Calculate founder couple transition for a founder couple
// Apply the founder-couple transition for couple fc to the index pair (a, b).
// calculatefctrans calls this with a = (husband bits | wife bits) and
// b = fc->pi(a); a == b is the self-paired case.
//
// For every offset w visited by the loops below, the index a | w (and b | w)
// is updated:
//  - if a == b, the entry is scaled by wfac;
//  - otherwise D[a | w] and D[b | w] are mixed with weight sfac.
// Both weights depend on the parity of wt[w & fc->mask] (wt[] is presumably a
// set-bit count table — TODO confirm).
//
// For the first founder couple the values are read from S and also multiplied
// by factors[wtstar[...]] (the power table filled in by Ttrans). Later couples
// update D in place, accumulating on top of earlier couples' results.
void Probability::fcinnerprod(FloatVec D, FloatVec S, Float tht, IV a, IV b,
                              Foundercouple *fc) {
  Float fac = tht/(1.0 - tht);
  Float sfac, wfac;
  IV inc1, inc2, to2;
  // Loop geometry. If the wife contributes no bits, the outer loop visits
  // every v in [0, numiv) and the inner loop runs once with w = v. Otherwise
  // the outer loop steps by wife->mask + wife->lastbit, and the inner loop
  // visits fc->lastcbit() consecutive offsets from each such start.
  if (wt[fc->wife->mask] == 0) {
    inc1 = 1;
    inc2 = to2 = fam->numiv;
  }
  else {
    inc1 = fc->wife->mask + fc->wife->lastbit;
    inc2 = 1;
    to2 = fc->lastcbit();
  }
  for (IV v = 0; v < fam->numiv; v += inc1) {
    for (IV w = v; w < v + to2; w += inc2) {
      // Odd parity within fc->mask: negative mixing weight and the
      // (1 - 2 tht)/(1 - tht) self weight; even parity: +fac and 1/(1 - tht).
      if (wt[w & fc->mask] % 2) {
        sfac = -fac;
        wfac = (1.0 - 2.0*tht)/(1.0 - tht);
      }
      else {
        sfac = fac;
        wfac = 1/(1.0 - tht);
      }
      if (fc == fam->firstfoundercouple) {
        // First couple: read from S, apply the Ttrans power table as well.
        if (a == b) D[a | w] = S[a | w]*factors[wtstar[a | w]]*wfac;
        else {
          D[a | w] = (S[a | w] + sfac*S[b | w])*factors[wtstar[a | w]];
          D[b | w] = (S[b | w] + sfac*S[a | w])*factors[wtstar[b | w]];
        }
      }
      else {
        // Later couples: update D in place. tmp keeps the old D[a | w] so the
        // D[b | w] update is not affected by the line before it.
        if (a == b) D[a | w] = D[a | w]*wfac;
        else {
          Float tmp = D[a | w];
          D[a | w] = tmp + sfac*D[b | w];
          D[b | w] = D[b | w] + sfac*tmp;
        }
      }
    }
  }
}

// Transition step in the transformed domain.
//
// Sex-averaged (default): the member table factors[i] is set to
// (1 - 2*tht)^i for i = 0 .. 2*numbits. Then D[v] = S[v]*factors[wtstar[v]],
// unless the family has founder couples, in which case calculatefctrans
// fills D instead.
//
// options->sexspecific: separate tables are built from tht_male (into
// factors) and tht_female (into factors_female), and
// D[v] = S[v]*factors[wtstar[v]]*factors_female[wtstar_female[v]].
// tht is ignored on this path, and founder couples are not supported
// (asserted).
//
// D and S must be distinct buffers when options->foundercouples is set.
void Probability::Ttrans(FloatVec D, FloatVec S, Float tht,
                         Float tht_male, Float tht_female) {
  // Note D == S was allowed, but not used, earlier
  assertinternal(!options->foundercouples || D != S);

  if (!options->sexspecific) {
    const Float base = 1.0 - 2.0*tht;
    factors[0] = 1.0;
    for (Uint k = 1; k <= 2*numbits; k++)
      factors[k] = factors[k - 1]*base;

    if (fam->numfc != 0) {
      calculatefctrans(D, S, tht);
      return;
    }
    for (IV v = 0; v < numiv; v++)
      D[v] = S[v]*factors[wtstar[v]];
    return;
  }

  // Sex-specific path.
  assertinternal(fam->numfc == 0);
  assertinternal(tht_female >= 0);
  const Float base_male = 1.0 - 2.0*tht_male;
  const Float base_female = 1.0 - 2.0*tht_female;
  factors[0] = 1.0;
  factors_female[0] = 1.0;
  for (Uint k = 1; k <= 2*numbits; k++) {
    factors[k] = factors[k - 1]*base_male;
    factors_female[k] = factors_female[k - 1]*base_female;
  }
  for (IV v = 0; v < numiv; v++)
    D[v] = S[v]*factors[wtstar[v]]*factors_female[wtstar_female[v]];
}

// Write the N entries of x to str, one line per entry, formatted as
// "<gam>\t<index>\t<value>" in scientific notation with 12 digits of
// precision. These format flags are set on str and stay set after the call.
//
// Outside Monte Carlo mode the printed values are divided by the sum of x, so
// the listing sums to one. The division applies only to what is printed: x
// itself is left unmodified. Previously x was normalised in place, so enabling
// a diagnostic output file (e.g. options->lfile in findlq) rescaled the
// vector the recursion goes on to use, without any matching correction of
// the log-likelihood offset.
void Probability::outputvec(ostream& str, int gam, FloatVec x, Uint N) {
  str.setf(ios::scientific);
  str.precision(12);
  Float total = 1.0;  // Monte Carlo mode: print raw values
  if (!options->montecarlo) {
    total = 0.0;
    for (Uint k = 0; k < N; k++) total += x[k];
  }
  for (Uint k = 0; k < N; k++)
    str << gam << "\t" << k << "\t" << x[k]/total << "\n";
}

// Advance the left-to-right recursion from informative marker lastinformative
// to marker gam.
//
//   lqhatlast  transformed l.q vector of marker lastinformative (input)
//   lqhat      receives the transform of l[gam] . q[gam] (output)
//   lq         receives l[gam] . q[gam]; also used as working space
//   q          q vector of marker gam (input)
//   L          log-scale accumulator, adjusted by check() when it rescales lq
//
// In Monte Carlo mode l_nc[gam] records the sum of l[gam]. l[gam] is written
// to options->lfile when requested. Fails through assertcond if
// l[gam] . q[gam] is identically zero.
void Probability::findlq(FloatVec lqhatlast, FloatVec lqhat, FloatVec lq,
                         FloatVec q, Uint lastinformative, Uint gam, Float &L) {
  // hat(l[lastinformative] . q[lastinformative]) . hat(T) over the interval
  Ttrans(lq, lqhatlast, theta[lastinformative],
         theta[lastinformative], theta_female[lastinformative]);
  fft(lq, lq, numiv, numbits); // l[gam]
  if (options->montecarlo) l_nc[gam] = sum<Float>(lq, numiv);
  if (options->lfile && gam <= options->maxlocusforoutput)
    outputvec(options->lfile, gam, lq, numiv);
  elemprod(lq, lq, q, numiv);  // l[gam] . q[gam]
  // Check whether l[gam] and q[gam] have become incompatible through
  // numerical instability (their product vanishes everywhere).
  bool allzeroes = true;
  for (IV v = 0; v < numiv && allzeroes; v++)
    if (lq[v] != 0) allzeroes = false;
  // theta[lastinformative] is the value used above to carry the recursion
  // from marker lastinformative to marker gam, so those are the two markers
  // to report. The old message named gam and gam + 1, which is the wrong
  // interval, and reads past the end of markername when gam is the last
  // marker.
  assertcond(!allzeroes, "The distance between markers " +
             map->markername[lastinformative] + " and " +
             map->markername[gam] +
             " is only " + Floattostring(theta[lastinformative], 12) +
             ", and is causing numerical inconsistencies");
//     if (theta[lastinformative] < .0004) L += log(normal<Float>(lq, numiv));
  check("lq", gam, lq, (gam > 0 ? theta[gam - 1] : .5), &L);
  fft(lqhat, lq, numiv, numbits); // hat(l[gam] . q[gam])
}

// Carry the transformed vector mrkhat from marker gam to map position pos.
// The offset markerpos - position on maps 0, 1 and 2 is converted with
// recombfraccent and passed to Ttrans as its tht, tht_male and tht_female
// arguments. The result is written to posdist and then fft'd in place.
void Probability::mrktopos(FloatVec posdist, FloatVec mrkhat,
                           Uint gam, Uint pos) {
  const Float tht =
    recombfraccent(map->markerpos[0][gam] - map->position[0][pos]);
  const Float tht_male =
    recombfraccent(map->markerpos[1][gam] - map->position[1][pos]);
  const Float tht_female =
    recombfraccent(map->markerpos[2][gam] - map->position[2][pos]);
  Ttrans(posdist, mrkhat, tht, tht_male, tht_female);
  fft(posdist, posdist, numiv, numbits);
}

// Build the "dpt" distribution for the marker at position pos and hand it to
// Distribution::set.
//
// The result combines two pieces, and does not multiply in q for this
// marker:
//  - the left-hand vector from the nearest informative marker strictly to the
//    left (its lqhat row carried to pos with mrktopos), and
//  - r, the right-hand vector at this marker. findp passes r(k) before it is
//    multiplied by q(k), or 0 when there is no right-hand information.
// With neither piece available the result is uniform (1/numiv). Otherwise it
// is normalised.
//
// The last row of q is used as scratch space for the result.
void Probability::finddpt(Uint pos, FloatVec r) {
  // The next informative marker strictly to the left of this one
  int slm = (pos > 0 ? leftmarker[pos - 1] : NOLEFTMARKER);
  FloatVec d = q->getrow(map->num - 1, false);
  if (slm == NOLEFTMARKER && r == 0)
    copyval(d, 1.0/Double(numiv), numiv);
  else {
    if (slm == NOLEFTMARKER)
      // Only right-hand information: copy it into the scratch row. (Merely
      // re-pointing a local at r would leave d holding stale data.)
      copyvec(d, r, numiv);
    else {
      mrktopos(d, lqhat->getrow(slm, true), slm, pos);
      if (r != 0) elemprod(d, d, r, numiv);
    }
    normal<Float>(d, numiv);
  }
  Distribution::set(q->getrow(map->num - 1, false),
                    map->leftmarker[pos], "dpt");
}

// Compute the normalised distribution p at map position pos.
//
// lm / rm are leftmarker[pos] / rightmarker[pos], compared against
// NOLEFTMARKER / NORIGHTMARKER. They are presumably the nearest informative
// markers on each side — TODO confirm.
//
// Case 1: pos lies between markers, or sits at an uninformative marker that
// is flagged shouldfindp. p is formed by carrying the stored left vector
// (lqhat row lm) and/or the current right vector rqhat to pos with mrktopos,
// multiplying them, and normalising.
//
// Case 2: pos sits at an informative marker lm. This performs one
// right-to-left step:
//  - rqhat is advanced from rm to lm (it holds r(k) after the first fft);
//  - p = lq(lm) * r(k) when shouldfindp[lm] is set;
//  - rqhat is left holding hat(r(k) . q(lm)) for the next call.
// At the right-most informative marker the recursion is started via
// startright(lm), and p is a copy of lq row lm.
//
// Optional side outputs: "dpt"/"rpt" distributions, and the rfile and pfile
// dumps.
// NOTE(review): if neither case applies (an uninformative marker with
// shouldfindp false), p is left as passed in but is still written to pfile.
void Probability::findp(Uint pos, FloatVec p) {
  int lm = leftmarker[pos];
  int rm = rightmarker[pos];
  if (map->inbetween[pos] || (!informative[map->leftmarker[pos]] &&
                              map->shouldfindp[map->leftmarker[pos]])) {
    if (lm == NOLEFTMARKER) {
      // Only right-hand information: carry rqhat from rm to pos.
      mrktopos(p, rqhat, rm, pos);
      if (options->pseudoautosomal) {
        // Multiply in the fam->pseudonull vector carried from map position 0
        // to pos; q row 0 serves as scratch space here.
        // NOTE(review): Ttrans(p0, p0, ...) aliases D and S, which Ttrans
        // asserts against when options->foundercouples is set.
        FloatVec p0 = q->getrow(0, false);
        fam->pseudonull(p0);
        fft(p0, p0, numiv, numbits);
        Ttrans(p0, p0, recombfraccent(map->position[0][pos]),
               recombfraccent(map->position[1][pos]),
               recombfraccent(map->position[2][pos]));
        fft(p0, p0, numiv, numbits);
        // NOTE(review): the normalisation happens inside the assertion; it
        // would be skipped if assertinternal were ever compiled out.
        assertinternal(normal<Float>(p0, numiv) > 0);
        elemprod(p, p, p0, numiv);
      }
    }
    else if (rm == NORIGHTMARKER) 
      mrktopos(p, lqhat->getrow(lm, true), lm, pos);
    else { // Inbetween markers or at an uninformative marker
      // Left vector into p, right vector into the scratch last row of q,
      // then multiply the two.
      mrktopos(p, lqhat->getrow(lm, true), lm, pos);
      mrktopos(q->getrow(map->num - 1, false), rqhat, rm, pos);
      elemprod(p, p, q->getrow(map->num - 1, false), numiv);
    }
    normal<Float>(p, numiv);
    if (!map->inbetween[pos] && options->calcdpt)
      Distribution::set(p, map->leftmarker[pos], "dpt");
  }
  else if (informative[map->leftmarker[pos]]) {
    // Uses rqhat before it is advanced below, i.e. still the vector from rm.
    // p is passed along here; presumably it serves as workspace — verify in
    // Recombdist::set.
    if (rm != NORIGHTMARKER && options->calcrpt)
      Recombdist::set(p, rqhat, lqhat->getrow(lm, true), lm, theta[lm],
                      (options->sexspecific ? theta_female[lm] : -1), "rpt");
    if (rm == NORIGHTMARKER) {
      // Right-most informative marker: start the right-hand recursion here.
      startright(lm);
      copyvec(p, lq->getrow(lm, true), numiv);
      if (options->calcdpt) finddpt(pos, 0);
    }
    else {
                                      // rqhat starts as rqhat(k + 1)
      Ttrans(p, rqhat, theta[lm], theta[lm], theta_female[lm]);
                                      // p contains rqhat(k + 1) T(k)hat
      fft(rqhat, p, numiv, numbits);  // rqhat contains r(k)
      check("r", lm, rqhat, theta[lm]);
      if (options->rfile != 0 && lm <= int(options->maxlocusforoutput))
        outputvec(options->rfile, lm, rqhat, numiv);
      if (map->shouldfindp[lm])
        elemprod(p, lq->getrow(lm, true), rqhat, numiv); // p contains p(k)
      if (options->calcdpt) finddpt(pos, rqhat);
      elemprod(rqhat, rqhat, q->getrow(lm, true), numiv);// rqhat contains rq(k)
      fft(rqhat, rqhat, numiv, numbits);  // rqhat contains rqhat(k)
    }
    normal<Float>(p, numiv);
  }
#ifdef DEBUG
  check("p", pos, p);
  for (IV v = 0; v < numiv; v++) assertinternal(p[v] >= 0);
#endif
  if (options->pfile && pos <= options->maxlocusforoutput)
    outputvec(options->pfile, pos, p, fam->numiv);
}

void Probability::startright(int gam) {
  fft(rqhat, q->getrow(gam, true), numiv, numbits);
  check("r", gam, rqhat);
}

void Probability::lqfromlqhat(int gam, FloatVec lqp) {
  fft(lqp, lqhat->getrow(gam, true), numiv, numbits);
  scalvec(lqp, 1.0/float(numiv), lqp, numiv);
}
