#include "files.h"
#include "family.h"
#include "haplodist.h"
#include "fmtout.h"
#include "options.h"
#include "map.h"
#include "utils.h"
#include "singlelocus.h"

// Random number generator defined elsewhere in the project. Used below for
// the initial random phase bits and for sampling flips when pick_mle is false
// (presumably uniform on [0, 1) -- confirm against its definition).
extern Double runif();

// Graph shared by all Haplodist objects (static member). ivout uses it for
// allele assignment; it starts out null, so it must be set before ivout runs.
Graph *Haplodist::graph = 0;

// Names the four output files. setname's second argument appears to be the
// fallback name ("haplo.out", "inher.out", "founder.out") used when the caller
// gives none -- confirm against the file class. The imputed-haplotype file
// falls back to "i" followed by the haplotype file's resulting name.
Haplodist::Haplodist(const string &haplo, const string &inher,
                     const string &ihaplo, const string &founder,
                     const string &p) :
    Viterbidist(p), colwidth(0) {
  haplotypefile.setname(haplo, "haplo.out");
  inherfile.setname(inher, "inher.out");
  // Must be derived only after haplotypefile has received its name.
  const string idefault = "i" + haplotypefile.name;
  ihaplotypefile.setname(ihaplo, idefault);
  founderallelefile.setname(founder, "founder.out");
}

// Factory for the HAPLOTYPE option. Calls optassert to reject the line when
// some existing distribution already reports getpt() == "hpt", i.e. a second
// HAPLOTYPE line; otherwise returns a freshly allocated Haplodist.
Haplodist *Haplodist::gethaplodist(const string &haplo, const string &inher,
                                   const string &ihaplo,
                                   const string &founder) {
  Uint i = 0;
  while (i < distributions.size()) {
    const bool duplicate = distributions[i]->getpt() == "hpt";
    optassert(!duplicate, "More than one HAPLOTYPE lines.");
    i++;
  }
  return new Haplodist(haplo, inher, ihaplo, founder);
}


// Reports this distribution through message() as a "HAPLOTYPE" line that lists
// the haplotype, imputed-haplotype, founder-allele and inheritance file names.
// NOTE(review): this order differs from the constructor's argument order
// (haplo, inher, ihaplo, founder); confirm it matches what the option parser
// expects if the line is meant to round-trip.
void Haplodist::print() const {
  string line = "HAPLOTYPE " + haplotypefile.name;
  line += " " + ihaplotypefile.name;
  line += " " + founderallelefile.name;
  line += " " + inherfile.name;
  message(line);
}

void Haplodist::printheaders() {
  if (options->printoriginalalleles) {
    int minallele = 0, maxallele = 0;
    Uint from = options->addendmarkers ? 1 : 0;
    Uint to = options->addendmarkers ? map->num - 1 : map->num;
    for (Uint gam = from; gam < to; gam++) {
      IntVec orig = map->origalleles[gam];
      for (Uint a = 0; a < map->numallele[gam]; a++) {
        minallele = min_(minallele, orig[a]);
        maxallele = max_(maxallele, orig[a]);
      }
    }
    colwidth = max_(options->unknownrepeatstring.length(),
                    max_(1 + Uint(log10(Double(-minallele) + .1)),
                         Uint(log10(Double(maxallele) + .1)))) + 2 ;
  }
  else {
    Uint maxallelecount = 1;
    for (Uint gam = 0; gam < map->num; gam++)
      maxallelecount = max_(maxallelecount, map->numallele[gam]);
    colwidth = Uint(log10(Double(maxallelecount))) + 2;
  }
  colwidth = max_(colwidth, 3);
  Uint maxmarkernamelen = 0;
  for (Uint gam = 0; gam < map->num; gam++)
    maxmarkernamelen = max_(maxmarkernamelen, map->markername[gam].length());
  markerheader(inherfile, maxmarkernamelen);
  markerheader(haplotypefile, maxmarkernamelen);
  markerheader(ihaplotypefile, maxmarkernamelen);
  markerheader(founderallelefile, maxmarkernamelen);
}

// Updates fb, the phase bit of founder couple fc, when moving from the
// previous marker's inheritance vector v0 to the current one v1.
// countfcrecombs yields `same` and `diff`; the weights below treat flipping fb
// as 1 + same recombinations and keeping it as diff recombinations.
// - pick_mle: nothing happens when the couple's bits (fc->mask |
//   fc->childrenmask) agree in v0 and v1. Otherwise fb flips when flipping is
//   strictly cheaper. On an exact tie (diff == same + 1), whether to flip
//   alternates with the tie-break state tb.
// - !pick_mle: the flip is sampled with runif() using recombination fraction
//   tht.
void Haplodist::updatefirstgcbit(Foundercouple *fc, int &fb, int &tb,
                                 IV v0, IV v1, Float tht, bool pick_mle) const {
  const IV relevant = fc->mask | fc->childrenmask;
  if (pick_mle && (relevant & v0) == (relevant & v1)) return;

  Uint nsame = 0;
  Uint ndiff = 0;
  countfcrecombs(fc, v0, v1, nsame, ndiff);

  if (!pick_mle) {
    const Float weightflip =
      pow(tht, double(1 + nsame))*pow(1 - tht, double(ndiff));
    const Float weightkeep =
      pow(tht, double(ndiff))*pow(1 - tht, double(1 + nsame));
    if (runif() <= weightflip/(weightflip + weightkeep)) fb = 1 - fb;
    return;
  }

  if (ndiff <= nsame) return;
  if (ndiff != nsame + 1) {
    fb = 1 - fb;
    return;
  }
  // Tie: flip on every other tie (opposite way to the last tiebreak).
  if (tb != 0) fb = 1 - fb;
  tb = 1 - tb;
}

// For every child c of p, writes the inheritance bit c received from p into
// curbits[c->p->nmrk]. Column 0 is used when p is MALE, otherwise column 1.
// The bit is read from v1 through the child's patmask/matmask.
// If p is a founder that owns a phase slot whose firstfounderbit[j] is
// non-zero, the phase is flipped: the first child gets 1 and each later child
// gets the complement of its v1 bit.
// Afterwards j is advanced past p if p is a founder that owns a slot. Every
// founder owns one, except non-FEMALE founders in sex-linked runs.
void Haplodist::setbits(Person *p, IntVec firstfounderbit,
                        Uint &j, IntMat curbits, IV v1) const {
  const bool male = p->sex == MALE;
  const int col = male ? 0 : 1;
  const bool flipped = p->founder() && !(options->sexlinked && male) &&
    firstfounderbit[j] != 0;

  for (Plist *c = p->children; c != 0; c = c->next) {
    const bool inherited = male ? (v1 & c->p->patmask) != 0
                                : (v1 & c->p->matmask) != 0;
    int bit = inherited ? 1 : 0;
    if (flipped)
      bit = c == p->children ? 1 : 1 - bit;
    curbits[c->p->nmrk][col] = bit;
  }

  if (p->founder() && (!options->sexlinked || p->sex == FEMALE)) j++;
}

// Updates fb, the phase bit of founder p, when moving from the previous
// marker's inheritance vector v0 to the current one v1.
// countfounderrecombs yields `same` and `diff`; the weights below treat
// flipping fb as 1 + same recombinations and keeping it as diff
// recombinations.
// - pick_mle: nothing happens when p's bits (p->mask) agree in v0 and v1.
//   Otherwise fb flips when flipping is strictly cheaper. On an exact tie
//   (diff == same + 1), whether to flip alternates with the tie-break
//   state tb.
// - !pick_mle: the flip is sampled with runif() using recombination fraction
//   tht. tht must be positive whenever diff > 0.
void Haplodist::updatefirstfounderbit(Person *p, int &fb, int &tb,
                                      IV v0, IV v1, Float tht,
                                      bool pick_mle) const {
  if ((v0 & p->mask) != (v1 & p->mask) || !pick_mle) {
    // Zero-initialised, like the counters in updatefirstgcbit: previously
    // these were read uninitialised if countfounderrecombs accumulates into
    // them rather than assigning.
    Uint same = 0, diff = 0;
    countfounderrecombs(p->mask, v0, v1, same, diff);
    if (pick_mle) {
      if (diff > same) {
        if (diff == same + 1) {
          // Need to break the tie (the opposite way to last tiebreak)
          fb = tb == 0 ? fb : 1 - fb;
          tb = 1 - tb;
        }
        else fb = 1 - fb;
      }
    }
    else {
      assertinternal(diff == 0 || tht > 0);
      Float prob_rec = pow(tht, double(1 + same))*pow(1 - tht, double(diff));
      Float prob_norec = pow(tht, double(diff))*pow(1 - tht, double(1 + same));
      if (runif() <= prob_rec/(prob_rec + prob_norec))
        fb = 1 - fb;
    }
  }
}

// Returns the single character of `marker` that belongs in header row
// `charnum` (0-based, top to bottom) when the longest marker name has
// `maxlen` characters. Rows above the name get "".
// Names are bottom-aligned and read upward: the last row (charnum ==
// maxlen - 1) holds marker[0].
string printmarker(Uint maxlen, Uint charnum, const string &marker) {
  const string::size_type blankrows = maxlen - marker.length();
  if (charnum < blankrows)
    return string();
  return string(1, marker[maxlen - charnum - 1]);
}

// Converts one inheritance-vector path for the current family (ivpath[gam]
// for every marker gam) into three sets of per-marker matrices, then writes
// them via printresults:
//   - curbits: inheritance bits, non-founders only;
//   - alleles: alleles assigned by graph;
//   - founderalleles: founder-allele labels.
// pick_mle is forwarded to updatefirstfounderbit / updatefirstgcbit and to
// graph->assignalleles. When it is false, the first two sample their flips
// with runif().
// The founder and founder-couple phase bits are initialised once, then
// carried from marker to marker and updated by those helpers.
void Haplodist::ivout(IVVec ivpath, bool pick_mle) {
  vector<IntMat> curbits(map->num);
  vector<IntMat> alleles(map->num);
  vector<IntMat> founderalleles(map->num);
  // Per-founder and per-founder-couple phase bits, plus the alternating
  // tie-break state used by the update helpers. All are released at the end.
  IntVec firstfounderbit;
  IntVec foundertiebreak;
  IntVec firstgcbit;
  IntVec fctiebreak;
  Family *fam = curfamily();
  if (fam->numf == 0)
    firstfounderbit = foundertiebreak = 0;
  else {
    NEWVEC(int, firstfounderbit, fam->numf);
    NEWVEC(int, foundertiebreak, fam->numf);
  }
  if (fam->numfc > 0) {
    NEWVEC(int, firstgcbit, fam->numfc);
    NEWVEC(int, fctiebreak, fam->numfc);
  }
  else
    firstgcbit = fctiebreak = 0;
  // Matrix row indices: the numnf people from firstdescendant get rows
  // 0..numnf-1; the first num - numnf people from fam->first get the rest.
  Person *p = fam->firstdescendant;
  for (Uint i = 0; i < fam->numnf; i++, p = p->next) p->nmrk = i;
  p = fam->first;
  for (Uint i = fam->numnf; i < fam->num; i++, p = p->next) p->nmrk = i;
  // Initial phase bits: all 0 with options->haplotypefixbits, else random.
  Foundercouple *fc = fam->firstfoundercouple;
  for (Uint ff = 0; fc != 0; fc = fc->next, ff++) {
    firstgcbit[ff] = options->haplotypefixbits || runif() < .5 ? 0 : 1;
    fctiebreak[ff] = 0;
  }
  for (Uint j = 0; j < fam->numf; j++) {
    firstfounderbit[j] = options->haplotypefixbits || runif() < .5 ? 0 : 1;
    foundertiebreak[j] = 0;
  }
  for (Uint gam = 0; gam < map->num; gam++) {
    const Float theta = gam > 0 ?  map->theta[0][gam - 1] : -1;
    const Float theta_male = gam > 0 ? map->theta[1][gam - 1] : -1;
    const Float theta_female = gam > 0 ? map->theta[2][gam - 1] : -1;

    NEWMAT(int, curbits[gam], fam->numnf, 2);
    NEWMAT(Allele, alleles[gam], fam->num, 2);
    NEWMAT(int, founderalleles[gam], fam->num, 2);
    fc = fam->firstfoundercouple;
    // Index into firstfounderbit; setbits advances it past each founder
    // that owns a slot.
    Uint j = 0;
    for (Uint ff = 0; fc != 0; fc = fc->next, ff++) {
      IV v0 = 0;
      if (gam > 0)
        v0 = (firstgcbit[ff] == 0 ? ivpath[gam - 1] :
              fc->pi(ivpath[gam - 1]) & ~fc->mask |
              (~ivpath[gam - 1] & fc->mask));
      IV v1 = firstgcbit[ff] == 0 ? ivpath[gam] :
        fc->pi(ivpath[gam]) & ~fc->mask | (~ivpath[gam] & fc->mask);
      if (gam > 0) {
        if (!options->haplotypefixbits)
          updatefirstgcbit(fc, firstgcbit[ff], fctiebreak[ff], v0, v1,
                           theta, pick_mle);
        // Recompute v1 with the (possibly flipped) couple bit.
        v1 = firstgcbit[ff] == 0 ? ivpath[gam] :
          fc->pi(ivpath[gam]) & ~fc->mask | (~ivpath[gam] & fc->mask);
      }
      if (gam > 0 && !options->haplotypefixbits)
        updatefirstfounderbit(fc->wife, firstfounderbit[j],
                              foundertiebreak[j], v0, v1, theta, pick_mle);
      setbits(fc->wife, firstfounderbit, j, curbits[gam], v1);
      // NOTE(review): unlike the non-couple loop below, this update is not
      // guarded by (!options->sexlinked || sex == FEMALE). setbits does not
      // advance j for a male founder in sex-linked runs, so there this call
      // would touch another founder's slot (or index numf). Confirm founder
      // couples cannot occur in sex-linked analyses.
      if (gam > 0 && !options->haplotypefixbits)
        updatefirstfounderbit(fc->husband, firstfounderbit[j],
                              foundertiebreak[j], v0, v1, theta, pick_mle);
      setbits(fc->husband, firstfounderbit, j, curbits[gam], v1);
      for (Plist *c = fc->wife->children; c != 0; c = c->next)
        if (c->p->children)
          setbits(c->p, firstfounderbit, j, curbits[gam], v1);
      if (fc->wife->children->p->sex == MALE)
        curbits[gam][fc->wife->children->p->children->p->nmrk][0] =
          firstgcbit[ff];
      else
        curbits[gam][fc->wife->children->p->children->p->nmrk][1] =
          firstgcbit[ff];
    }
    graph->haploreset();
    for (p = fam->first; p != 0; p = p->next) {
      if (p->children != 0 && p->fc == 0) {
        if (p->founder() && gam > 0 &&
            (!options->sexlinked || p->sex == FEMALE) &&
            !options->haplotypefixbits)
          updatefirstfounderbit(p, firstfounderbit[j], foundertiebreak[j],
                                ivpath[gam - 1], ivpath[gam],
                                options->sexspecific ?
                                (p->sex == MALE ? theta_male : theta_female) :
                                theta, pick_mle);
        setbits(p, firstfounderbit, j, curbits[gam], ivpath[gam]);
      }
      // Non-founders take their graph nodes from the parent nodes selected
      // by their inheritance bits (column 0 = paternal, column 1 = maternal).
      if (!p->founder()) {
        if (options->sexlinked && p->sex == MALE)
          p->nod[0] = p->nod[1] = p->mother->nod[curbits[gam][p->nmrk][1]];
        else {
          p->nod[0] = p->father->nod[curbits[gam][p->nmrk][0]];
          p->nod[1] = p->mother->nod[curbits[gam][p->nmrk][1]];
        }
      }
      if (p->gen[0] != 0 && p->genotyp[gam])
        assertcond(graph->addgenotype(p->nod[0], p->nod[1],
                                      p->gen[0][gam], p->gen[1][gam]),
                   "No consistent allele assignment at " +
                   map->markername[gam]);
    }
    graph->setallelefreq(map->pi[fam->populationindex][gam]);
    graph->assignalleles(pick_mle);

    for (p = fam->first; p != 0; p = p->next) {
      alleles[gam][p->nmrk][0] = p->nod[0]->allele;
      alleles[gam][p->nmrk][1] = p->nod[1]->allele;
    }
    graph->haploreset();
    // Label each founder node with a distinct number so founderalleles
    // records which founder allele every person carries.
    Uint i = 0;
    for (p = fam->first; p != 0; p = p->next) {
      if (p->founder()) {
        p->nod[0]->allele = i++;
        p->nod[1]->allele = i++;
      }
      founderalleles[gam][p->nmrk][0] = p->nod[0]->allele;
      founderalleles[gam][p->nmrk][1] = p->nod[1]->allele;
    }
    graph->haploreset();
  }

  printresults(curbits, alleles, founderalleles);

  // Cleanup
  for (Uint gam = 0; gam < map->num; gam++) {
    DELETEMAT(curbits[gam]);
    DELETEMAT(alleles[gam]);
    DELETEMAT(founderalleles[gam]);
  }
  // The phase-bit arrays were previously never released, leaking four arrays
  // per call. They are null when the family has no founders / founder
  // couples, and delete [] of a null pointer is a no-op.
  // NOTE(review): assumes NEWVEC allocates with new[]; switch to the matching
  // utils.h release macro if it does not.
  delete [] firstfounderbit;
  delete [] foundertiebreak;
  delete [] firstgcbit;
  delete [] fctiebreak;
}

void Haplodist::printresults(const vector<IntMat> &curbits,
                             vector<IntMat> &alleles, 
                             const vector<IntMat> founderalleles) {
  Family *fam = curfamily();
  // Print inheritance vector file
  outputfile(inherfile, curbits, fam->firstdescendant);
  // Print founder alleles file
  outputfile(founderallelefile, founderalleles, fam->first);
  // Print imputed haplotype file
  outputfile(ihaplotypefile, alleles, fam->first,
             options->printoriginalalleles);
  // Print haplotype file
  for (Person *p = fam->first; p != 0; p = p->next)
    for (Uint gam = 0; gam < map->num; gam++)
      if (p->genotyp == 0 || !p->genotyp[gam])
        alleles[gam][p->nmrk][0] = alleles[gam][p->nmrk][1] = ALLELEUNKNOWN;
  outputfile(haplotypefile, alleles, fam->first,
             options->printoriginalalleles);
}

// Writes a blank line, then one or two lines per person, starting at firstper
// and following the next links: column 0 (paternal) and then column 1
// (maternal) of x.
// In sex-linked runs anyone who is not FEMALE gets only the maternal line,
// since such males carry a single maternally inherited haplotype.
// trans is forwarded to printline and selects original allele labels.
void Haplodist::outputfile(ostream &f, const vector<IntMat> &x,
                           Person *firstper, bool trans) {
  f << "\n";
  Person *p = firstper;
  while (p != 0) {
    const bool haspaternal = !options->sexlinked || p->sex == FEMALE;
    if (haspaternal) printline(f, p, x, 0, trans);
    printline(f, p, x, 1, trans);
    p = p->next;
  }
}

void Haplodist::markerheader(ostream &f, Uint maxlen) {
  for (Uint j = 0; j < maxlen; j++) {
    fmtout(f, options->maxfamidlen + 3*options->maxperidlen + 7 + colwidth,
           " ");
    for (Uint gam = 0; gam < map->num; gam++)
      fmtout(f, colwidth, printmarker(maxlen, j, map->markername[gam]));
    f << "\n";
  }
}

void Haplodist::printper(ostream &f, Person *p) {
  fmtout(f, options->maxfamidlen + 1, curfamily()->id);
  fmtout(f, options->maxperidlen + 1, p->id);
  fmtout(f, options->maxperidlen + 1,  p->father == 0 ? "0" : p->father->id);
  fmtout(f, options->maxperidlen + 1,  p->mother == 0 ? "0" : p->mother->id);
  fmtout(f, 1, p->sex + 1);
  fmtout(f, 2, p->origdstat);
  f << " ";
}

// Writes one output line for person p: the printper prefix, then
// x[gam][p->nmrk][sex] for every marker in colwidth-wide cells.
// When both trans and options->printoriginalalleles are set, each value is
// treated as a 1-based allele index and replaced by its original label.
// ALLELEUNKNOWN is printed as options->unknownrepeatstring.
void Haplodist::printline(ostream &f, Person *p, const vector<IntMat> &x,
                          int sex, bool trans) {
  printper(f, p);
  const bool original = trans && options->printoriginalalleles;
  for (Uint gam = 0; gam < map->num; gam++) {
    const int value = x[gam][p->nmrk][sex];
    if (!original)
      fmtout(f, colwidth, value);
    else if (value == ALLELEUNKNOWN)
      fmtout(f, colwidth, options->unknownrepeatstring, false);
    else
      fmtout(f, colwidth, map->origalleles[gam][value - 1]);
  }
  f << '\n';
}

// Forwards pv/pos to Viterbidist::set. At pos 0 it additionally turns the
// current Viterbi path into haplotype output, using deterministic tie-breaking
// (pick_mle = true).
// NOTE(review): presumably pos 0 is the last position filled in, so that
// `path` is complete here -- confirm against Viterbidist.
void Haplodist::set(FloatVec pv, Uint pos) {
  Viterbidist::set(pv, pos);
  if (pos != 0) return;
  ivout(path, true);
}

// (Re)opens the four output files, closing them first if a previous reset left
// them open. Then writes the marker headers (printheaders also computes
// colwidth) and finally forwards np to Viterbidist::reset.
void Haplodist::reset(Uint np) {
  // The four files are always opened and closed together here, so inherfile's
  // state stands for the whole set.
  const bool alreadyopen = inherfile.is_open();
  if (alreadyopen) {
    inherfile.close();
    haplotypefile.close();
    ihaplotypefile.close();
    founderallelefile.close();
  }
  inherfile.open();
  haplotypefile.open();
  ihaplotypefile.open();
  founderallelefile.open();
  printheaders();
  Viterbidist::reset(np);
}
