#include "calcscore.h"

#include <sstream>

#include "findscore.h"
#include "files.h"
#include "options.h"
#include "vecutil.h"
#include "genpairs.h"
#include "kinship.h"
#include "findibd.h"

///////////////////////////////////////////////////////////////////////////
// Calcscore

// Walks the founder couple list recursively. At the end of the list
// (fc == 0) calc() is run on the part of vec that starts at i*fam->numiv and
// i is advanced. For every asymmetric couple the remainder of the list is
// visited twice: once as is and once between two calls to swapdstat().
void Calcscore::calcscore(Foundercouple *fc, Uint &i, Family *fam) {
  if (fc != 0) {
    calcscore(fc->next, i, fam);
    if (!fc->asymmetric()) return;
    // Second pass with dstat swapped; it is swapped again afterwards
    fc->swapdstat();
    calcscore(fc->next, i, fam);
    fc->swapdstat();
    return;
  }
  // End of the couple list: fill the current slice and move to the next one
  calc(vec + i*fam->numiv, fam);
  i++;
}

// Calculates the scores of fam into vec. vec is first zeroed over
// fam->numiv*POW2[fam->nasymmetricfoundercouples()] entries, then the score
// persons are set up and the recursive calcscore() takes care of the founder
// couple asymmetries. The pair scores are printed afterwards when
// options->printpairscores is set.
void Calcscore::operator()(Family *fam) {
  zero(vec, fam->numiv*POW2[fam->nasymmetricfoundercouples()]);
  setupscorefamily(fam);

  // Running slice index, advanced by calcscore() each time calc() is run
  Uint slice = 0;
  calcscore(fam->firstfoundercouple, slice, fam);

  if (!options->printpairscores) return;
  printpairscores(fam);
}

///////////////////////////////////////////////////////////////////////////
// Calcqualitativescore
// Dispatches on id to the matching Scoreperson routine, which receives score:
//   "pairs"/"homoz": calcspairs() from firstdesc, after resetdstat() on first
//   "all":           calcsall() from first, after the init counts of both
//                    nodes of every person before firstdesc are set to 0
//   "robdom":        calcsrobdom() from first
//   "mnallele":      calcsmnallele() from first, starting at 2*numf
// Any other id is an internal error.
void Calcqualitativescore::calc(DoubleVec score, Family *fam) {
  const bool homoz = id == "homoz";
  if (homoz || id == "pairs") {
    first->resetdstat();
    firstdesc->calcspairs(0, 0.0, score, homoz);
    return;
  }
  if (id == "all") {
    Scoreperson *sp = first;
    while (sp != firstdesc) {
      sp->nod[0]->setinitcount(0.0);
      sp->nod[1]->setinitcount(0.0);
      sp = (Scoreperson *)sp->next;
    }
    first->calcsall(0, 1.0, score);
    return;
  }
  if (id == "robdom") {
    first->calcsrobdom(0, 0.0, score);
    return;
  }
  if (id == "mnallele") {
    Uint nfounders = fam->numf;
    if (options->sexlinked) {
      // Sex-linked: count the persons in front of fam->firstdescendant
      // instead of using fam->numf
      nfounders = 0;
      Person *p = fam->first;
      while (p != fam->firstdescendant) {
        nfounders++;
        p = p->next;
      }
    }
    first->calcsmnallele(0, 2.0*nfounders, score);
    return;
  }
  assertinternal(false);
}

// Creates one Scoreperson for every person of fam (first is the one created
// for fam->first and is handed to the constructor of all the others) and
// sets firstdesc to the Scoreperson that follows the last founder.
void Calcqualitativescore::setupscorefamily(Family *fam) {
  first = new Scoreperson(fam->first, 0);
  // Start from the first person when it is a founder. Previously lastfounder
  // stayed null, and was dereferenced below, when none of the persons after
  // the first one was a founder.
  Scoreperson *lastfounder = fam->first->founder() ? first : 0;
  for (Person *p = fam->first->next; p != 0; p = p->next) {
    Scoreperson *sp = new Scoreperson(p, first);
    if (p->founder()) lastfounder = sp;
  }
  // A family without any founder cannot be handled
  assertinternal(lastfounder != 0);
  firstdesc = (Scoreperson *)lastfounder->next;
}

// Returns the calc that findcalc() finds for sc, or a newly created
// Calcqualitativescore when findcalc() returns null.
Calcqualitativescore *
Calcqualitativescore::getcalcqualitativescore(const string &sc) {
  Calc *existing = findcalc(sc);
  if (existing == 0) return new Calcqualitativescore(sc);
  return (Calcqualitativescore *)existing;
}

///////////////////////////////////////////////////////////////////////////
// Calcparentspecificscore
// Runs Calcscore::operator() with the trait value of the wife of every
// asymmetric founder couple temporarily set to 1. Those wives must not have
// a trait value beforehand; afterwards it is set to NOTRAITVALUE.
void Calcparentspecificscore::operator() (Family *fam) {
  // marked[k] tells whether the wife of the k-th founder couple was changed
  Boolvector marked;
  for (Foundercouple *fc = fam->firstfoundercouple; fc != 0; fc = fc->next) {
    const bool asym = fc->asymmetric();
    if (asym) {
      assertinternal(!fc->wife->hastraitvalue());
      fc->wife->traitvalue = 1;
    }
    marked.push_back(asym);
  }

  Calcscore::operator()(fam);

  unsigned int k = 0;
  for (Foundercouple *fc = fam->firstfoundercouple; fc != 0;
       fc = fc->next, k++) {
    if (!marked[k]) continue;
    assertinternal(fc->wife->traitvalue == 1);
    fc->wife->traitvalue = NOTRAITVALUE;
  }
}

// Parent specific pairs score: calcspairs_ps() is started from firstdesc
// with the weights w_mm, w_mf and w_ff. The family argument is not used and
// "ps" is the only id accepted.
void Calcparentspecificscore::calc(DoubleVec score, Family *) {
  assertinternal(id == "ps");
  firstdesc->calcspairs_ps(0, 0, score, w_mm, w_mf, w_ff);
}

// Returns the calc that findcalc() finds for description(sc, wmm, wmf, wff),
// or a newly created Calcparentspecificscore when there is none.
Calcparentspecificscore *Calcparentspecificscore::
getcalcparentspecificscore(const string &sc, Double wmm, Double wmf, Double wff) {
  Calc *existing = findcalc(description(sc, wmm, wmf, wff));
  if (existing == 0) return new Calcparentspecificscore(sc, wmm, wmf, wff);
  return (Calcparentspecificscore *)existing;
}

///////////////////////////////////////////////////////////////////////////
// Calcgenpairs
// Passes sc on to the Calcscore constructor; nothing else is set up here.
Calcgenpairs::Calcgenpairs(const string &sc)
    : Calcscore(sc) {
}

// Builds the key of the pair (a, b): a and b separated by a single space.
string Calcgenpairs::pair2string(const string &a, const string &b) const {
  string key(a);
  key += " ";
  key += b;
  return key;
}

// Creates one Genpairsperson for every person of fam. first is the one
// created for fam->first and is handed to the constructor of all the others.
// Every person gets the two weight tables, the same index variable,
// fam->num - 1 and its own informative flag.
void Calcgenpairs::setupscorefamily(Family *fam) {
  Uint index = 0;
  Person *head = fam->first;
  first = new Genpairsperson(head, 0, pair2weight1, pair2weight2, index,
                             fam->num - 1, informative[head->id]);
  Person *p = head->next;
  while (p != 0) {
    new Genpairsperson(p, first, pair2weight1, pair2weight2, index,
                       fam->num - 1, informative[p->id]);
    p = p->next;
  }
}

// General pairs score: calcgenspairs() is started from first. The family
// argument is not used.
void Calcgenpairs::calc(DoubleVec score, Family * /*fam*/) {
  first->calcgenspairs(0, 0.0, score);
}

// Writes the weights of the pairs of fam to the pairscores file. The file
// is named describe() + ".sc" and opened the first time it is needed. A pair
// is written (family id, the two person ids, then both weights) only when at
// least one of its two weights is nonzero.
void Calcgenpairs::printpairscores(Family *fam) {
  if (!pairscores.assigned()) {
    pairscores.setname(describe() + ".sc");
    assertcond(pairscores.fileok(), string("Cannot open pairscores file ") +
               pairscores.name);
    pairscores.open();
  }
  for (Person *p = fam->first; p != 0; p = p->next) {
    Person *q = p->next;
    while (q != 0) {
      const string key = pair2string(p->id, q->id);
      Float &w1 = pair2weight1[key];
      Float &w2 = pair2weight2[key];
      if (w1 != 0 || w2 != 0) {
        pairscores << fam->id << "\t" << p->id << "\t" << q->id << "\t";
        fmtout(pairscores, 9, 7, w1);
        fmtout(pairscores, 9, 7, w2);
        pairscores << "\n";
      }
      q = q->next;
    }
  }
}

///////////////////////////////////////////////////////////////////////////
// Calcgenpairsfile
// Reads the pair weights from the weight file fn. The file is a whitespace
// separated sequence of entries "person1 person2 weight". Each entry is
// stored under both orderings of the pair: weight in pair2weight1 and
// 2*weight in pair2weight2.
//
// The identifiers are read into strings, so their length is no longer
// limited by a fixed size buffer that an overlong token could overrun. An
// entry that cannot be read completely is reported through assertcond() and
// is not stored.
Calcgenpairsfile::Calcgenpairsfile(const string &sc, const string &fn) :
    Calcgenpairs(sc), filename(fn) {
  Infile weightfile;
  weightfile.setname(filename);
  weightfile.optcheck("GENPAIRS weightfile");
  weightfile.open();
  Uint entrynum = 0;
  weightfile >> ws;
  while (!weightfile.eof() && !weightfile.fail()) {
    entrynum++;
    string per1, per2;
    Double weight = 0;
    weightfile >> per1;
    weightfile >> per2;
    weightfile >> weight;
    if (weightfile.fail()) {
      std::ostringstream msg;
      msg << "Malformed entry " << entrynum << " in GENPAIRS weightfile "
          << fn;
      assertcond(false, msg.str());
      // Never store a half read entry, whatever assertcond() does
      break;
    }
    const string pair1 = pair2string(per1, per2);
    const string pair2 = pair2string(per2, per1);
    pair2weight1[pair1] = pair2weight1[pair2] = weight;
    pair2weight2[pair1] = pair2weight2[pair2] = 2*weight;
    // Skip trailing whitespace so that eof is seen after the last entry
    weightfile >> ws;
  }
  weightfile.close();
}

// Returns the calc that findcalc() finds for sc + " " + fn, or a newly
// created Calcgenpairsfile when there is none.
Calcgenpairsfile *Calcgenpairsfile::getcalcgenpairsfile(const string &sc,
                                                        const string &fn) {
  Calc *existing = findcalc(sc + " " + fn);
  if (existing == 0) return new Calcgenpairsfile(sc, fn);
  return (Calcgenpairsfile *)existing;
}

// Marks the informative persons of fam and then runs Calcgenpairs on it.
// All persons start out as not informative. Both members of a pair become
// informative when the kinship of the pair is positive and at least one of
// its two weights is nonzero.
void Calcgenpairsfile::operator()(Family *fam) {
  fam->calckinship();
  Person *p;
  for (p = fam->first; p != 0; p = p->next)
    informative[p->id] = false;
  for (p = fam->first; p != 0; p = p->next) {
    for (Person *q = p->next; q != 0; q = q->next) {
      const Float ks = fam->kinship->getkinship(p->nmrk, q->nmrk);
      if (!(ks > 0)) continue;
      const string key = pair2string(p->id, q->id);
      if (pair2weight1[key] != 0 || pair2weight2[key] != 0) {
        informative[q->id] = true;
        informative[p->id] = true;
      }
    }
  }
  Calcgenpairs::operator()(fam);
}

///////////////////////////////////////////////////////////////////////////
// Calcgenpairsshared
// Weight of a pair with kinship ks and trait values x1 and x2, chosen by id:
//   "qtlscore":          with G = ks*shared and d = 1 - G*G,
//                        .5*G/d + (-(x1^2 + x2^2)*G + x1*x2*(1 + G*G))/(2*d^2)
//   "qtlhe":             -(x1 - x2)^2 + 2*(variance - ks*shared)
//   "qtlnhe", "qtlwpc":  x1*x2 - ks*shared
// Any other id is an internal error; 0 is returned in that case so that the
// function never runs off its end without a return value.
Float Calcgenpairsshared::calcweight(Float ks, Float x1, Float x2) {
  if (id == "qtlscore") {
    Float G = ks*shared;
    Float G2 = G*G;
    Float d = 1 - G2;
    return .5*G/d + (-(x1*x1 + x2*x2)*G + x1*x2*(1 + G2))/(2.0*d*d);
  }
  else if (id == "qtlhe") {
    Float d = x1 - x2;
    return -d*d + 2*(variance - ks*shared);
  }
  else if (id == "qtlnhe" || id == "qtlwpc")
    return x1*x2 - ks*shared;
  assertinternal(false);
  return 0;
}

// Finds all pairs of fam that potentially share and calculates their
// weights, then runs Calcgenpairs on the family. A pair is used when both
// members have a trait value and their kinship is positive. Both members are
// then marked informative, pair2weight1 gets calcweight() of the pair and
// pair2weight2 twice that value. Both weight tables are emptied first.
void Calcgenpairsshared::operator()(Family *fam) {
  pair2weight1.clear();
  pair2weight2.clear();
  fam->calckinship();
  for (Person *r = fam->first; r != 0; r = r->next)
    informative[r->id] = false;
  for (Person *p = fam->first; p != 0; p = p->next) {
    if (!p->hastraitvalue()) continue;
    const Float x_p = p->traitvalue;
    for (Person *q = p->next; q != 0; q = q->next) {
      if (!q->hastraitvalue()) continue;
      const Float ks = fam->kinship->getkinship(p->nmrk, q->nmrk);
      if (!(ks > 0)) continue;
      const Float x_q = q->traitvalue;
      const string key = pair2string(p->id, q->id);
      informative[q->id] = true;
      informative[p->id] = true;
      const Float w = calcweight(ks, x_p, x_q);
      pair2weight1[key] = w;
      pair2weight2[key] = 2*w;
    }
  }
  Calcgenpairs::operator()(fam);
}

// Returns the calc that findcalc() finds for description(sc, shr, var), or a
// newly created Calcgenpairsshared when there is none.
Calcgenpairsshared *Calcgenpairsshared::getcalcgenpairsshared(const string &sc,
                                                              Double shr,
                                                              Double var) {
  Calc *existing = findcalc(description(sc, shr, var));
  if (existing == 0) return new Calcgenpairsshared(sc, shr, var);
  return (Calcgenpairsshared *)existing;
}

///////////////////////////////////////////////////////////////////////////
// Calcgenpairsvc
// Static member, so there is a single FindIBD for all Calcgenpairsvc
// objects. It starts out null; operator() replaces it whenever it is null or
// its curfamily() differs from the family being processed.
FindIBD *Calcgenpairsvc::findIBD = 0;
// Returns exp(-(u^2/(1 - rho) + v^2/(1 + rho))/4)/sqrt((1 - rho)*(1 + rho))
// with u = x1 - x2, v = x1 + x2 and
//   rho = ks*shared + pi*sigma2_g + sigma2_d*[pi > .95].
// NOTE(review): this looks like a bivariate normal density with correlation
// rho, up to a constant factor - confirm.
Float Calcgenpairsvc::probpair(Float ks, Float x1, Float x2, Float pi) {
  // The sigma2_d term is switched on only when pi exceeds .95
  const int withdominance = pi > .95 ? 1 : 0;
  const Float rho = ks*shared + pi*sigma2_g + withdominance*sigma2_d;
  const Float diff = x1 - x2;
  const Float sum = x1 + x2;
  const Float quadform = diff*diff/(1 - rho) + sum*sum/(1 + rho);
  const Float det = (1 - rho)*(1 + rho);

  return exp(-.25*quadform)/sqrt(det);
}

// Calculates the two weights of a pair from p1 and p2 (used as the weights
// of the pi = .5 and pi = 1 terms below) and the trait values x1 and x2.
// px0, px1 and px2 are probpair() at pi = 0, .5 and 1, and px is their
// mixture with weights 1 - p1 - p2, p1 and p2. Depending on id:
//   "qtlgenpairsncp": w = (mu_A - mu_0)/o2_0, weight1 = .5*w, weight2 = w
//   "qtlpairs":       weight_k = log(px_k/px) - log(px0/px)
//   "qtlpairsncp":    weight_k = px_k/px - px0/px
// For any other id both weights are left untouched.
void Calcgenpairsvc::calcweight(Float p1, Float p2, Float x1, Float x2,
                                Float &weight1, Float &weight2) {
  const Float ks = .5*p1 + p2;
  const Float px0 = probpair(ks, x1, x2, 0);
  const Float px1 = probpair(ks, x1, x2, .5);
  const Float px2 = probpair(ks, x1, x2, 1);
  const Float px = (1 - p1 - p2)*px0 + p1*px1 + p2*px2;

  if (id == "qtlgenpairsncp") {
    const Float mu_0 = .5*p1 + p2;
    const Float o2_0 = .25*p1 + p2 - mu_0*mu_0;
    const Float mu_A = (.5*px1*p1 + px2*p2)/px;
    const Float w = (mu_A - mu_0)/o2_0;
    weight1 = .5*w;
    weight2 = w;
    return;
  }

  const bool logratio = id == "qtlpairs";
  if (!logratio && id != "qtlpairsncp") return;

  const Float w0 = px0/px;
  const Float w1 = px1/px;
  const Float w2 = px2/px;
  if (logratio) {
    weight1 = log(w1) - log(w0);
    weight2 = log(w2) - log(w0);
  }
  else {
    weight1 = w1 - w0;
    weight2 = w2 - w0;
  }
}

// Finds all pairs of fam that potentially share, calculates their weights
// and runs Calcgenpairs on the family.
//
// The shared FindIBD is rebuilt (and findprior(0, 0) run) when there is none
// yet or when it belongs to another family. A pair is used when both members
// have a trait value and its prior (p1, p2) from priorresult() is not
// degenerate. Both members are then marked informative and calcweight()
// fills in the two weights of the pair. Both weight tables are emptied first.
void Calcgenpairsvc::operator()(Family *fam) {
  if (findIBD == 0 || fam != findIBD->curfamily()) {
    delete findIBD;
    findIBD = new FindIBD(fam, "qtl");
    findIBD->findprior(0, 0);
  }
  const Double tolerance = 1e-6;
  pair2weight1.clear();
  pair2weight2.clear();
  for (Person *p = fam->first; p != 0; p = p->next)
    informative[p->id] = false;
  for (Person *p = fam->first; p != 0; p = p->next)
    if (p->hastraitvalue()) {
      Float x_p = p->traitvalue;
      for (Person *q = p->next; q != 0; q = q->next) {
        if (q->hastraitvalue()) {
          Float p1, p2;
          findIBD->priorresult(p, q, p1, p2);
          // Skip the pair when p1 + p2 is 0, p1 is 1 or p2 is 1 (within
          // tolerance). The three tests used to be joined with ||, which
          // is always true, so such pairs reached calcweight(), where
          // "qtlgenpairsncp" divides by o2_0 = .25*p1 + p2 - (.5*p1 + p2)^2,
          // which is 0 for p1 = p2 = 0 and for p1 = 1, p2 = 0.
          if (fabs(p1 + p2) > tolerance && fabs(p1 - 1) > tolerance &&
              fabs(p2 - 1) > tolerance) {
            Float x_q = q->traitvalue;
            string key = pair2string(p->id, q->id);
            informative[p->id] = informative[q->id] = true;
            calcweight(p1, p2, x_p, x_q, pair2weight1[key], pair2weight2[key]);
          }
        }
      }
    }
  Calcgenpairs::operator()(fam);
}

// Returns the calc that findcalc() finds for description(sc, s, a, d), or a
// newly created Calcgenpairsvc when there is none. The second argument is
// not used.
Calcgenpairsvc *Calcgenpairsvc::getcalcgenpairsvc(const string &sc,
                                                  Double /*s2*/, Double s,
                                                  Double a, Double d) {
  Calc *existing = findcalc(description(sc, s, a, d));
  if (existing == 0) return new Calcgenpairsvc(sc, s, a, d);
  return (Calcgenpairsvc *)existing;
}
