#include "files.h"
#include "founderhaplodist.h"
#include "family.h"
#include "pairwise.h"
#include "founderperson.h"
#include "options.h"
#include "vecutil.h"

#include <assert.h>

// Foundercount
class Foundercountallele {
public:
  Foundercountallele() : numdescendants(0) {}

  void pushdescendant(vector<Uint> &cur) {
    if (numdescendants == 0) cur.push_back(1);
    else {
      for (unsigned int i = 0; i < cur.size(); i++)
        if (cur[i] == numdescendants) {
          cur[i]++;
          break;
        }
    }
    numdescendants++;
  }
  
  void popdescendant(vector<Uint> &cur) {
    if (numdescendants == 1) cur.pop_back();
    else {
      assertinternal(!cur.empty());
      for (int i = cur.size() - 1; i >= 0; i--)
        if (cur[i] == numdescendants) {
          cur[i]--;
          break;
        }
    }
    numdescendants--;
  }
  
protected:
  Uint numdescendants;

};

// Person node used for the unordered founder count; the list handling is
// inherited from Founderperson, with Foundercountallele as the allele type.
class Foundercountperson :
  public Founderperson<Foundercountallele, Foundercountperson> {
public:
  // All four arguments are forwarded unchanged to the Founderperson base.
  Foundercountperson(Person *person, Foundercountperson *firstperson,
                     bool patflag, bool matflag) :
      Founderperson<Foundercountallele, Foundercountperson>(person,
                                                            firstperson,
                                                            patflag,
                                                            matflag) {}

  // Walks the person list from this node onwards, maintaining cur, and calls
  // updateres once the end of the list is reached on an affected person.
  void countdesc(IV v, vector<Uint> &cur, vector<vector<vector<Uint> > > &res,
                 vector<vector<Uint> > &num);

  // Adds the configuration in cur to res (or bumps its counter in num if it
  // is already there), bucketed by cur.size() - 1.
  void updateres(vector<Uint> &cur, vector<vector<vector<Uint> > > &res,
                 vector<vector<Uint> > &num);

};

// Recursive enumeration step for one person of the list.
//
// For every maternal choice (0, and 1 when matmask is non-zero and mattrans
// is set) and every paternal choice (likewise with patmask/pattrans) the
// person's allele nodes are pointed at the selected parental nodes.  If the
// person is AFFECTED the selected nodes get a descendant pushed onto cur for
// the duration of the recursion into next.  At the end of the list the
// current cur is handed to updateres.
//
//   v    value passed down the recursion; matmask/patmask are added to it
//        for the "1" choices
//   cur  working vector of per-allele descendant counts (restored on return)
//   res  collected configurations, passed through to updateres
//   num  occurrence counters parallel to res, passed through to updateres
void Foundercountperson::countdesc(IV v, vector<Uint> &cur,
                                   vector<vector<vector<Uint> > > &res,
                                   vector<vector<Uint> > &num) {
  assert(cur.size() < 100);

  const bool affected = per->origdstat == AFFECTED;
  const bool countmat = affected && mattrans;
  const bool countpat = affected && pattrans;
  Foundercountperson *const succ = (Foundercountperson *)next;

  const int lastmat = (per->matmask && mattrans) ? 1 : 0;
  const int lastpat = (per->patmask && pattrans) ? 1 : 0;

  for (int mbit = 0; mbit <= lastmat; mbit++) {
    if (mbit != 0) v += per->matmask;
    if (mother != 0) nod[1] = mother->nod[mbit];
    if (countmat) nod[1]->pushdescendant(cur);

    for (int pbit = 0; pbit <= lastpat; pbit++) {
      if (pbit != 0) v += per->patmask;
      if (father != 0) {
        // With sex linkage, anyone who is not FEMALE reuses the maternal
        // node for slot 0.
        if (!options->sexlinked || per->sex == FEMALE)
          nod[0] = father->nod[pbit];
        else
          nod[0] = nod[1];
      }

      if (countpat) nod[0]->pushdescendant(cur);
      if (affected && succ == 0)
        updateres(cur, res, num);
      else {
        // An unaffected person must never be the last one in the list.
        assertinternal(next != 0);
        succ->countdesc(v, cur, res, num);
      }
      if (countpat) nod[0]->popdescendant(cur);
    }

    if (countmat) nod[1]->popdescendant(cur);
    // Drop the paternal bit again before the next maternal choice.
    v &= ~per->patmask;
  }
}

// Records one occurrence of the configuration held in cur.
//
// Configurations are bucketed by length: bucket n = cur.size() - 1 of res
// holds the distinct vectors seen so far, and num[n][i] counts how often
// res[n][i] has been seen.  If cur is already present its counter is
// incremented, otherwise cur is appended with a count of 1.
//
//   cur  configuration to record; must be non-empty
//   res  per-length lists of distinct configurations (must have > n buckets)
//   num  counters parallel to res (same bucket count and bucket sizes)
void Foundercountperson::updateres(vector<Uint> &cur,
                                   vector<vector<vector<Uint> > > &res,
                                   vector<vector<Uint> > &num) {
  assertinternal(!cur.empty());
  const unsigned int n = cur.size() - 1;

  // Check the bucket index here rather than relying on a guard elsewhere:
  // vector::operator[] performs no range check of its own.
  assertinternal(n < res.size() && n < num.size());

  vector<vector<Uint> > &cand = res[n];
  vector<Uint> &tally = num[n];

  // The two containers are parallel; a mismatch would make tally[i] below
  // index out of range.
  assertinternal(cand.size() == tally.size());

  for (unsigned int i = 0; i < cand.size(); i++)
    if (cur == cand[i]) {
      tally[i]++;
      return;
    }

  cand.push_back(cur);
  tally.push_back(1);
}

/////////////////////////////////////////////////////////////////////////////
// Ordered foundercount
class Orderedfounderallele {
public:
  Orderedfounderallele() : idx(Uint(-1)) {}

  void pushdescendant(vector<Uint> &cur) {cur[idx]++;}
  void popdescendant(vector<Uint> &cur) {cur[idx]--;}

  void setidx(Uint i) {idx = i;}
  
protected:
  Uint idx;

};

// Person node used for the ordered founder count; the list handling is
// inherited from Founderperson, with Orderedfounderallele as the allele type.
class Orderedfounderperson :
  public Founderperson<Orderedfounderallele, Orderedfounderperson> {
public:
  // All four arguments are forwarded unchanged to the Founderperson base.
  Orderedfounderperson(Person *person, Orderedfounderperson *firstperson,
                       bool patflag, bool matflag) :
      Founderperson<Orderedfounderallele, Orderedfounderperson>(person,
                                                                firstperson,
                                                                patflag,
                                                                matflag) {}

  // Walks the list starting at first and gives the two allele nodes of every
  // founder consecutive slot numbers (0 and 1 for the first founder met,
  // 2 and 3 for the second, and so on).
  static void setindices(Orderedfounderperson *first) {
    Uint nextidx = 0;
    Orderedfounderperson *cursor = first;
    while (cursor != 0) {
      if (cursor->per->founder()) {
        cursor->nod[0]->setidx(nextidx);
        cursor->nod[1]->setidx(nextidx + 1);
        nextidx += 2;
      }
      cursor = (Orderedfounderperson *)cursor->next;
    }
  }

  // Walks the person list from this node onwards, maintaining cur, and calls
  // updateres once the end of the list is reached on an affected person.
  void countdesc(IV v, vector<Uint> &cur, vector<vector<Uint> > &res,
                 vector<Uint> &num);

  // Adds the configuration in cur to res, or bumps its counter in num if it
  // is already there.
  void updateres(vector<Uint> &cur, vector<vector<Uint> > &res,
                 vector<Uint> &num);
};

// Recursive enumeration step for one person of the list (ordered variant).
//
// A parental side is enumerated over both choices when its transmission flag
// (pattrans / mattrans) is set and either the person's mask for that side is
// non-zero or the corresponding parent exists and is not a founder.  For each
// combination the person's allele nodes are pointed at the selected parental
// nodes; an AFFECTED person pushes a descendant onto those nodes for the
// duration of the recursion into next.  At the end of the list the current
// cur is handed to updateres.
//
//   v    value passed down the recursion; matmask/patmask are added to it
//        for the "1" choices
//   cur  per-slot descendant counts (restored on return)
//   res  collected configurations, passed through to updateres
//   num  occurrence counters parallel to res, passed through to updateres
void Orderedfounderperson::countdesc(IV v, vector<Uint> &cur,
                                     vector<vector<Uint> > &res,
                                     vector<Uint> &num) {
  const bool patvaries =
    per->patmask || (per->father != 0 && !per->father->founder());
  const int lastpat = (pattrans && patvaries) ? 1 : 0;

  const bool matvaries =
    per->matmask || (per->mother != 0 && !per->mother->founder());
  const int lastmat = (mattrans && matvaries) ? 1 : 0;

  const bool affected = per->origdstat == AFFECTED;
  const bool countmat = affected && mattrans;
  const bool countpat = affected && pattrans;
  Orderedfounderperson *const succ = (Orderedfounderperson *)next;

  for (int mbit = 0; mbit <= lastmat; mbit++) {
    if (mbit != 0) v += per->matmask;
    if (mother != 0) nod[1] = mother->nod[mbit];
    if (countmat) nod[1]->pushdescendant(cur);

    for (int pbit = 0; pbit <= lastpat; pbit++) {
      if (pbit != 0) v += per->patmask;
      if (father != 0) {
        // With sex linkage, anyone who is not FEMALE reuses the maternal
        // node for slot 0.
        if (!options->sexlinked || per->sex == FEMALE)
          nod[0] = father->nod[pbit];
        else
          nod[0] = nod[1];
      }

      if (countpat) nod[0]->pushdescendant(cur);
      if (affected && succ == 0)
        updateres(cur, res, num);
      else {
        // An unaffected person must never be the last one in the list.
        assertinternal(next != 0);
        succ->countdesc(v, cur, res, num);
      }
      if (countpat) nod[0]->popdescendant(cur);
    }

    if (countmat) nod[1]->popdescendant(cur);
    // Drop the paternal bit again before the next maternal choice.
    v &= ~per->patmask;
  }
}

// Records one occurrence of the configuration held in cur: if an equal
// vector is already stored in res its counter in num is incremented,
// otherwise cur is appended to res with a count of 1.  res and num are
// parallel and must have the same length on entry.
void Orderedfounderperson::updateres(vector<Uint> &cur,
                                     vector<vector<Uint> > &res,
                                     vector<Uint> &num) {
  assertinternal(res.size() == num.size());

  for (Uint slot = 0; slot < res.size(); ++slot) {
    if (cur == res[slot]) {
      ++num[slot];
      return;
    }
  }

  // Not seen before: start a new entry.
  res.push_back(cur);
  num.push_back(1);
}


// Processes the current family: builds the subfamily person lists, runs the
// descendant count on each of them and writes the resulting configurations
// to fhfile.  Both parameters are unused.
//
// Output per subfamily: one line per distinct configuration consisting of
// the family id, the number of times the configuration occurred and the
// configuration's entries, all tab separated, followed by an empty line.
void Founderhaplodist::nextfam(Uint /*pos*/, DoubleVec /*p0*/) {
  const Family *fam = curfamily();

  // Collect one list per group of persons: everyone before
  // fam->firstdescendant that passes hasaffecteddescendants and is not yet
  // found in an existing list starts a new one.
  vector<Foundercountperson *> subfamilies;
  Person *p = fam->first;
  while (p != fam->firstdescendant) {
    if (Foundercountperson::hasaffecteddescendants(p, false)) {
      bool found = false;
      for (Uint i = 0; i < subfamilies.size(); i++)
        if (Foundercountperson::find(p, subfamilies[i])) {
          found = true;
          break;
        }
      if (!found) {
        subfamilies.push_back(0);
        Foundercountperson::addfounder(p, subfamilies.back());
        Foundercountperson::addleaves(subfamilies.back());
        assertinternal(subfamilies.back() != 0);
      }
    }
    p = p->next;
  }

  for (Uint sf = 0; sf < subfamilies.size(); sf++) {
    // Subfamily printing is disabled:
//     fhfile << fam->id;
//     Foundercountperson::print(fhfile, subfamilies[sf]);
//     fhfile << "\n";

    // 100 length buckets, all initially empty; countdesc fills them in.
    vector<vector<vector<Uint> > > fhcount(100);
    vector<vector<Uint> > num(100);
    vector<Uint> cur;

    subfamilies[sf]->countdesc(0, cur, fhcount, num);

    // Dump every bucket in order.
    for (unsigned int i = 0; i < fhcount.size(); i++) {
      assertinternal(num[i].size() == fhcount[i].size());
      const vector<vector<Uint> > &bucket = fhcount[i];
      for (unsigned int j = 0; j < bucket.size(); j++) {
        const vector<Uint> &config = bucket[j];
        fhfile << curfamily()->id << "\t" << num[i][j];
        for (unsigned int k = 0; k < config.size(); k++)
          fhfile << "\t" << config[k];
        fhfile << "\n";
      }
    }
    fhfile << "\n";
  }
}

// Returns the first entry of distributions whose describe() equals
// "Founderhaplodist"; if there is none, a new Founderhaplodist writing to
// outfile is allocated and returned.
Founderhaplodist *Founderhaplodist::getfounderhaplodist(const string &outfile) {
  Uint d = 0;
  while (d < distributions.size()) {
    if (distributions[d]->describe() == "Founderhaplodist")
      return (Founderhaplodist *)distributions[d];
    d++;
  }
  return new Founderhaplodist(outfile);
}

// Sets up the output file: its name is set from outfile together with the
// default name "foundercount.out", then the file is checked and opened.
Founderhaplodist::Founderhaplodist(const string &outfile) :
    Distribution("spt") {
  string fallback("foundercount.out");
  fhfile.setname(outfile, fallback);
  fhfile.optcheck("Founder count file");
  fhfile.open();
}

void Founderhaplodist::print() const {
  message("FOUNDERCOUNT " + fhfile.name);
}

// Processes the current family: builds the subfamily person lists, assigns
// allele slots, runs the ordered descendant count on each list and writes
// the results to fhfile.  Both parameters are unused.
//
// Output per subfamily: a line with the family id followed by whatever
// Orderedfounderperson::print writes, then one line per distinct
// configuration holding the family id, its occurrence count and its entries,
// all tab separated.
void Orderedfounderhaplodist::nextfam(Uint /*pos*/, DoubleVec /*p0*/) {
  const Family *fam = curfamily();

  // Collect one list per group of persons: everyone before
  // fam->firstdescendant that passes hasaffecteddescendants and is not yet
  // found in an existing list starts a new one.
  vector<Orderedfounderperson *> subfamilies;
  Person *p = fam->first;
  while (p != fam->firstdescendant) {
    if (Orderedfounderperson::hasaffecteddescendants(p, false)) {
      bool found = false;
      for (Uint i = 0; i < subfamilies.size(); i++)
        if (Orderedfounderperson::find(p, subfamilies[i])) {
          found = true;
          break;
        }
      if (!found) {
        subfamilies.push_back(0);
        Orderedfounderperson::addfounder(p, subfamilies.back());
        Orderedfounderperson::addleaves(subfamilies.back());
        Orderedfounderperson::setindices(subfamilies.back());
        assertinternal(subfamilies.back() != 0);
      }
    }
    p = p->next;
  }

  for (Uint sf = 0; sf < subfamilies.size(); sf++) {
    Orderedfounderperson *head = subfamilies[sf];

    // Header line describing the subfamily.
    fhfile << fam->id;
    Orderedfounderperson::print(fhfile, head);
    fhfile << "\n";

    // Two slots per founder, all starting at zero.
    vector<vector<Uint> > fhcount;
    vector<Uint> num;
    vector<Uint> cur(2*Orderedfounderperson::numfounders(head));
    zero(cur, cur.size());

    head->countdesc(0, cur, fhcount, num);

    // Every push must have been undone by a matching pop.
    for (Uint i = 0; i < cur.size(); i++)
      assertinternal(cur[i] == 0);

    // One output line per distinct configuration.
    assertinternal(num.size() == fhcount.size());
    for (unsigned int j = 0; j < fhcount.size(); j++) {
      const vector<Uint> &config = fhcount[j];
      fhfile << curfamily()->id << "\t" << num[j];
      for (unsigned int k = 0; k < config.size(); k++)
        fhfile << "\t" << config[k];
      fhfile << "\n";
    }
  }
}

// Returns the first entry of distributions whose describe() equals
// "Orderedfounderhaplodist"; if there is none, a new Orderedfounderhaplodist
// writing to outfile is allocated and returned.
Orderedfounderhaplodist *
Orderedfounderhaplodist::getorderedfounderhaplodist(const string &outfile) {
  Uint d = 0;
  while (d < distributions.size()) {
    if (distributions[d]->describe() == "Orderedfounderhaplodist")
      return (Orderedfounderhaplodist *)distributions[d];
    d++;
  }
  return new Orderedfounderhaplodist(outfile);
}

// Sets up the output file: its name is set from outfile together with the
// default name "orderedfoundercount.out", then the file is checked and
// opened.
Orderedfounderhaplodist::Orderedfounderhaplodist(const string &outfile) :
    Distribution("spt") {
  string fallback("orderedfoundercount.out");
  fhfile.setname(outfile, fallback);
  fhfile.optcheck("Ordered founder count file");
  fhfile.open();
}

void Orderedfounderhaplodist::print() const {
  message("ORDEREDFOUNDERCOUNT " + fhfile.name);
}
