#include "files.h"
#include "distribution.h"
#include "options.h"
#include "map.h"
#include "calc.h"
#include "family.h"
#include "haldane.h"
#include "vecutil.h"
#include "probability.h"
#include "haplodist.h"
#include "montecarlo.h"

////////////////////////////////////////////////////////////////////////
// Distribution

// Registry of every Distribution ever constructed; filled by the constructor
// and emptied by cleanup().
Distribution::Distributionvector Distribution::distributions;
// Selects the constant-null path in nextfam() (presumably exposed through
// isnullconstant() -- confirm in the header).
bool Distribution::nullconstant = true;
// Map shared by all distributions; installed by reset(const Map &).
const Map *Distribution::map = 0;

// Stores the distribution's type tag ("spt", "mpt", ...) and self-registers
// so the static helpers of this class can iterate over every instance.
Distribution::Distribution(const string &p)
  : pt(p)
{
  distributions.push_back(this);
}

// Nothing to release here; note that the instance is NOT removed from the
// registry -- only cleanup() clears it.
Distribution::~Distribution() {
}

// Returns the type tag given at construction (e.g. "spt", "mpt").
const string &Distribution::getpt() {
  const string &tag = pt;
  return tag;
}

void Distribution::reset(const Map &mp) {
  map = &mp;
  Uint nposspt = map->num;
  Uint nposmpt = map->numposition;
  for (Uint d = 0; d < distributions.size(); d++) {
    if (distributions[d]->getpt() == "spt" ||
        distributions[d]->getpt() == "hpt" ||
        distributions[d]->getpt() == "fpt" ||
        distributions[d]->getpt() == "cpt")
      distributions[d]->reset(nposspt);  
    else if (distributions[d]->getpt() == "mpt" ||
             distributions[d]->getpt() == "dpt" ||
             distributions[d]->getpt() == "rpt")
      distributions[d]->reset(nposmpt);
    else assertinternal(false);
    distributions[d]->families.clear();
  }
}

// Forwards the values pv at position pos to every distribution whose tag is
// p and which has not opted out of the current family.
// NOTE(review): "rpt" is not an accepted tag here although reset() knows
// it -- confirm rpt distributions are never fed through this path.
void Distribution::set(FloatVec pv, Uint pos, const string &p) {
  assertinternal(p == "spt" || p == "mpt" || p == "dpt" || p == "hpt" ||
                 p == "fpt" || p == "cpt");
  for (Uint i = 0; i < distributions.size(); i++) {
    Distribution *dist = distributions[i];
    if (dist->skipcurrentfamily || dist->getpt() != p)
      continue;
    dist->set(pv, pos);
  }
}

// Prepares the shared calculation state for family fam and hands the family
// to every distribution that accepts it (usefamily()).
//
// Constant null (isnullconstant()): each accepting distribution gets the
// family registered once and its nextfam() called.
//
// Otherwise a pseudo-null vector p0 (filled by fam->pseudonull) is evolved
// along the merged, increasing sequence of analysis positions
// (map->position) and marker positions (map->markerpos).  Between steps p0
// is convolved with the transition operator for the recombination fraction
// of the covered distance, then normalised and passed to "spt"
// distributions at marker steps and "mpt" distributions at position steps.
//
// p0 is a caller-provided work buffer of fam->numiv entries (overwritten).
void Distribution::nextfam(Probability &prob, Family *fam, DoubleVec p0) {
  Calc::calculate(fam);
  Haplodist::setgraph(prob.graph);
  Montecarlo::setprob(prob);
  if (isnullconstant()) {
    for (Uint d = 0; d < distributions.size(); d++)
      if (distributions[d]->usefamily(fam)) {
        distributions[d]->skipcurrentfamily = false;
        distributions[d]->families.push_back(fam);
        distributions[d]->nextfam();
      }
      else
        distributions[d]->skipcurrentfamily = true;
  }
  else {
    // Register the family exactly once per distribution (previously this
    // happened once per step, duplicating the family in families[]) and
    // keep skipcurrentfamily in sync, mirroring the constant-null branch so
    // that set() and curfamuninformative() see the right state.
    for (Uint d = 0; d < distributions.size(); d++)
      if (distributions[d]->usefamily(fam)) {
        distributions[d]->skipcurrentfamily = false;
        distributions[d]->families.push_back(fam);
      }
      else
        distributions[d]->skipcurrentfamily = true;

    // Construct p0
    fam->pseudonull(p0);

    // Evolve p0.
    // NOTE(review): the loop stops as soon as either sequence is exhausted,
    // so analysis positions beyond the last marker (if any) never reach
    // nextfam(pos, p0) -- confirm this is intended.
    Uint pos = 0;
    Uint gam = 0;
    Float lastposition = 0;
    Float lastposition_female = 0;
    while (pos < map->numposition && gam < map->num) {
      // Classify the step: analysis position, marker, or both (coincident
      // within 1e-10 on row 0 of the map).
      bool dompt = false;
      bool dospt = false;
      if (fabs(map->position[0][pos] - map->markerpos[0][gam]) < 1e-10)
        dompt = dospt = true;
      else if (map->position[0][pos] < map->markerpos[0][gam])
        dompt = true;
      else
        dospt = true;

      // Coordinates come from map->position whenever the step is an
      // analysis position (including the coincident case), otherwise from
      // map->markerpos.  With options->sexspecific, row 1 drives theta and
      // row 2 theta_female (presumably male/female maps); otherwise row 0.
      Float theta = 0;
      Float theta_female = 0;
      if (options->sexspecific) {
        const Float cur = dompt ? map->position[1][pos]
                                : map->markerpos[1][gam];
        const Float cur_female = dompt ? map->position[2][pos]
                                       : map->markerpos[2][gam];
        theta = recombfraccent(cur - lastposition);
        // Female distance is measured from the previous FEMALE coordinate.
        theta_female = recombfraccent(cur_female - lastposition_female);
        lastposition = cur;
        lastposition_female = cur_female;
      }
      else {
        const Float cur = dompt ? map->position[0][pos]
                                : map->markerpos[0][gam];
        theta = recombfraccent(cur - lastposition);
        lastposition = cur;
      }

      // Convolve p0 whenever either sex has a non-zero recombination
      // fraction; with sex-specific maps the male distance may be zero
      // while the female one is not.
      if (theta > 0 || theta_female > 0) {
        fft(p0, p0, fam->numiv, fam->numbits);    // hat(p0_i)
        prob.Ttrans(p0, p0, theta, theta, theta_female); // hat(p0_i)*hat(T)_i
        fft(p0, p0, fam->numiv, fam->numbits);    // p0_(i+1)
      }

      // Pass p0 down
      normal<Double>(p0, fam->numiv);
      for (Uint d = 0; d < distributions.size(); d++) {
        Distribution *dist = distributions[d];
        if (!dist->usefamily(fam))
          continue;
        if (dospt && dist->pt == "spt")
          dist->nextfam(gam, p0);
        else if (dompt && dist->pt == "mpt")
          dist->nextfam(pos, p0);
      }
      if (dompt) pos++;
      if (dospt) gam++;
    }
  }
}

void Distribution::curfamuninformative() {
  for (Uint d = 0; d < distributions.size(); d++)
    if (!distributions[d]->skipcurrentfamily &&
        !distributions[d]->useuninformative()) {
      distributions[d]->families.pop_back();
      distributions[d]->skipfam();
      distributions[d]->skipcurrentfamily = true;
    }
}

void Distribution::setnames(bool readprobfiles, bool writeprobfiles) {
  for (Uint d = 0; d < distributions.size(); d++) {
    if (!distributions[d]->probfilein.assigned()) {
      if (readprobfiles) {
        distributions[d]->probfilein.setname("prob" +
                                             distributions[d]->describe() +
                                             '.' + distributions[d]->pt);
        distributions[d]->probfilein.optcheck(distributions[d]->pt + "probfile");
      }
      if (writeprobfiles) {
        distributions[d]->probfileout.setname("prob" +
                                              distributions[d]->describe() +
                                              '.' + distributions[d]->pt);
        distributions[d]->probfileout.optcheck(distributions[d]->pt + "probfile");
      }
    }
    if (!distributions[d]->nullfilein.assigned()) {
      if (readprobfiles) {
        distributions[d]->nullfilein.setname("null" +
                                             distributions[d]->describe() +
                                             ".dat");
        distributions[d]->nullfilein.optcheck("nullfile");
      }
      if (writeprobfiles) {
        distributions[d]->nullfileout.setname("null" +
                                              distributions[d]->describe() +
                                              ".dat");
        distributions[d]->nullfileout.optcheck("nullfile");
      }
    }
  }
}

void Distribution::setprobfilename(const string &/*pts*/, const string &distid,
                                   const string &filename) {
  for (Uint d = 0; d < distributions.size(); d++)
    if (distid == distributions[d]->describe()) {
      distributions[d]->probfilein.setname(filename);
      distributions[d]->probfileout.setname(filename);
      distributions[d]->probfilein.optcheck(distributions[d]->pt + " " +
                                            distid + " probfile");
      distributions[d]->probfileout.optcheck(distributions[d]->pt + " " +
                                             distid + " probfile");
    }
}

void Distribution::setnullfilename(const string &distid,
                                   const string &filename) {
  for (Uint d = 0; d < distributions.size(); d++)
    if (distid == distributions[d]->describe()) {
      distributions[d]->nullfilein.setname(filename);
      distributions[d]->nullfilein.optcheck(distid + " nullfile");
      distributions[d]->nullfileout.setname(filename);
      distributions[d]->nullfileout.optcheck(distid + " nullfile");
    }
}

bool Distribution::probfilesexist() {
  bool filesexist = true;
  for (Uint d = 0; d < distributions.size() && filesexist; d++)
    filesexist &= distributions[d]->nullfilein.exists();
  for (Uint d = 0; d < distributions.size() && filesexist; d++)
    filesexist &= distributions[d]->probfilein.exists();
  return filesexist;
}

void Distribution::printprobfileswritten() {
  message("Writing probfiles and nullprobfiles:");
  for (Uint d = 0; d < distributions.size(); d++)
    message("NULLFILENAME " + distributions[d]->describe() + " " +
            distributions[d]->nullfileout.name);
  for (Uint d = 0; d < distributions.size(); d++) {
    message("PROBFILENAME " + distributions[d]->pt + " " +
            distributions[d]->describe() + " " +
            distributions[d]->probfileout.name);
  }
}

void Distribution::printprobfilesread() {
  message("Reading probfiles and nullprobfiles:");
  for (Uint d = 0; d < distributions.size(); d++)
    message("NULLFILENAME " + distributions[d]->describe() + " " +
            distributions[d]->nullfilein.name);
  for (Uint d = 0; d < distributions.size(); d++)
    message("PROBFILENAME " + distributions[d]->pt + " " +
            distributions[d]->describe() + " " +
            distributions[d]->probfilein.name);
}

void Distribution::writeprobfiles() {
  for (Uint d = 0; d < distributions.size(); d++)
    distributions[d]->writeprob(*map);
}

void Distribution::writenullfiles() {
  for (Uint d = 0; d < distributions.size(); d++)
    distributions[d]->writenull();
}

// Returns the id of the ifam-th family registered with this distribution;
// ifam must be in range (checked by assertinternal).
const string &Distribution::getfamid(Uint ifam) const {
  assertinternal(ifam < families.size());
  const Family *fam = families[ifam];
  return fam->id;
}

// Returns the first registered distribution whose describe() equals desc
// and whose type tag equals pt, or a null pointer when none matches.
Distribution *Distribution::finddistribution(const string &desc,
                                             const string &pt) {
  for (Uint i = 0; i < distributions.size(); i++) {
    Distribution *candidate = distributions[i];
    const bool matches =
      desc == candidate->describe() && pt == candidate->getpt();
    if (matches)
      return candidate;
  }
  return 0;
}

// Default acceptance rule: a family is used only if it has at least one
// inheritance-vector bit (numbits > 0).
bool Distribution::usefamily(Family *fam) const {
  const bool hasbits = 0 < fam->numbits;
  return hasbits;
}

void Distribution::cleanup() {
  for (Uint d = 0; d < distributions.size(); d++)
    delete distributions[d];
  distributions.clear();
}
