#include "files.h"
#include "model.h"
#include "assocmodel.h"
#include "assocdist.h"
#include "family.h"
#include "fmtout.h"
#include "options.h"
#include "calcassoc.h"
#include "fishersexact.h"

///////////////////////////////////////////////////////////////////////////
// Assocmodel

// Constructs an association model of type `p` with explicit weights.
// When both wa and wu are zero the weighting object is a Calcassocfrac,
// otherwise a Calcassocsimple built from (wa, wu). The model owns `aw`
// (deleted in ~Assocmodel). The three score distributions (RAW, FAMILY,
// LINKAGE) are obtained from Assocdist::getassocdist for that object.
// patfile and ctrlfile are accepted for interface compatibility but are
// not used here.
Assocmodel::Assocmodel(const string &p, Double wa, Double wu, bool estfrq,
                       const string &patfile, const string &ctrlfile,
                       const string &of, const string &fof) :
    Model(p), zscore_raw(0), fzscore_raw(0), zscore_family(0),
    fzscore_family(0), zscore_linkage(0), fzscore_linkage(0)
/*, kurtosis(0)*/ {
  // Default output name: assoc.est.<pt> or assoc.dat.<pt>.
  const string defaultname =
    string("assoc.") + (estfrq ? "est." : "dat.") + pt;
  setfiles(of, fof, defaultname, "association test output");

  const bool noweights = (wa == 0 && wu == 0);
  if (noweights)
    aw = new Calcassocfrac(estfrq);
  else
    aw = new Calcassocsimple(estfrq, wa, wu);

  distribution = Assocdist::getassocdist(pt, Assocdist::RAW, aw);
  distfamily = Assocdist::getassocdist(pt, Assocdist::FAMILY, aw);
  distlinkage = Assocdist::getassocdist(pt, Assocdist::LINKAGE, aw);
}

// Constructs an association model of type `p` whose weighting object is a
// Calcassocfile built from the weight file name `wf`. The model owns `aw`
// (deleted in ~Assocmodel). The RAW, FAMILY and LINKAGE distributions are
// obtained from Assocdist::getassocdist for that object.
// patfile and ctrlfile are accepted for interface compatibility but are
// not used here.
Assocmodel::Assocmodel(const string &p, const string &wf, bool estfrq,
                       const string &patfile, const string &ctrlfile,
                       const string &of, const string &fof) :
    Model(p), zscore_raw(0), fzscore_raw(0), zscore_family(0),
    fzscore_family(0), zscore_linkage(0), fzscore_linkage(0)
  /*, kurtosis(0)*/ {
  // Default output name: assoc.est.<pt> or assoc.dat.<pt>.
  const string defaultname =
    string("assoc.") + (estfrq ? "est." : "dat.") + pt;
  setfiles(of, fof, defaultname, "association test output");

  aw = new Calcassocfile(estfrq, wf);

  distribution = Assocdist::getassocdist(pt, Assocdist::RAW, aw);
  distfamily = Assocdist::getassocdist(pt, Assocdist::FAMILY, aw);
  distlinkage = Assocdist::getassocdist(pt, Assocdist::LINKAGE, aw);
}

// Releases the weighting object allocated with new in either constructor.
// The Assocdist objects returned by Assocdist::getassocdist are not
// deleted here.
// NOTE(review): aw is an owning raw pointer; unless the class declaration
// disables copying, a copied Assocmodel would delete it twice - confirm.
Assocmodel::~Assocmodel() { delete aw; }


// Echoes the model specification as one message of the form
//   MODEL <pt> assoc [<weighting description> ]<outfile> <foutfile>
// The weighting description comes from aw->describe() and is omitted
// (together with its separator) when empty.
void Assocmodel::print() const {
  string awdesc = aw->describe();
  // The description carries its own trailing separator, so no extra space
  // is added below; previously the message always contained a double space.
  if (!awdesc.empty()) awdesc += " ";
  message("MODEL " + pt + " assoc " + awdesc +
          outfile.name + " " + foutfile.name);
}

// Allocates the score tables used by getresults():
//   zscore[marker][allele]            - zero-initialised
//   fzscore[family][marker][allele]   - not initialised; getresults()
//                                       assigns every entry
// The storage is released again by cleanup().
void Assocmodel::initialize(DoubleMat &zscore, DoubleMat *&fzscore) {
  const Uint nmarker = map->num;

  NEWVEC(DoubleVec, zscore, nmarker);
  for (Uint m = 0; m < nmarker; m++) {
    const Uint nall = map->numallele[m];
    NEWVEC(Double, zscore[m], nall);
    zero(zscore[m], nall);
  }

  const Uint nf = nfam();
  fzscore = new DoubleMat[nf];
  for (Uint ifam = 0; ifam < nf; ifam++) {
    NEWVEC(DoubleVec, fzscore[ifam], nmarker);
    for (Uint m = 0; m < nmarker; m++)
      NEWVEC(Double, fzscore[ifam][m], map->numallele[m]);
  }
}

// Frees the tables allocated by initialize() and resets both handles to 0.
// This restores the state established by the constructors (all score
// pointers start as 0). Any later use of the tables then fails on a null
// pointer instead of silently reading freed memory.
void Assocmodel::cleanup(DoubleMat &zscore, DoubleMat *&fzscore) {
  for (Uint gam = 0; gam < map->num; gam++)
    DELETEVEC(zscore[gam]);
  DELETEVEC(zscore);
  zscore = 0;

  for (Uint fam = 0; fam < nfam(); fam++) {
    for (Uint gam = 0; gam < map->num; gam++)
      DELETEVEC(fzscore[fam][gam]);
    DELETEVEC(fzscore[fam]);
  }
  // Allocated with new[] in initialize(), hence delete[] rather than DELETEVEC.
  delete [] fzscore;
  fzscore = 0;
}

// Returns the weighted count of `allele` at marker `gam` in family `ifam`.
// Every person p with a non-zero weight aw->assign(p, gam) adds that weight
// once for each of p->gen[0][gam] and p->gen[1][gam] equal to `allele`.
Double Assocmodel::calcT(Uint ifam, Uint gam, int allele) const {
  Double total = 0.0;
  Family *family = distribution->families[ifam];
  for (Person *person = family->first; person != 0; person = person->next) {
    const Double weight = aw->assign(person, gam);
    if (weight == 0) continue;
    // Two separate additions, one per gen entry (not 2*weight).
    if (person->gen[0][gam] == allele) total += weight;
    if (person->gen[1][gam] == allele) total += weight;
  }
  return total;
}

// Fills the score tables for the distribution `ad`.
//
// For each family ifam, marker m and allele a:
//   T      = calcT(ifam, m, a + 1)                (calcT takes 1-based alleles)
//   mu     = ad->getET(ifam)[m][a]
//   famvar = ad->getET2(ifam)[m][a] - mu*mu
// If famvar < |1e-10 * mu| the family's entry is set to 0 and the family is
// left out of the sums. Otherwise fzscore[ifam][m][a] = T - mu, which is
// also added to zscore[m][a], and famvar is added to the summed variance.
// Afterwards the summed and per-family deviations are divided by
// sqrt(summed variance) wherever that sum is > 0.
//
// zscore must start zeroed (initialize() guarantees this); every entry of
// fzscore is overwritten.
void Assocmodel::getresults(Assocdist *ad,
                            DoubleMat zscore, DoubleMat *fzscore) {
  // Summed family variance per marker and allele.
  DoubleMat sumvar;
  NEWVEC(DoubleVec, sumvar, map->num);
  for (Uint m = 0; m < map->num; m++) {
    NEWVEC(Double, sumvar[m], map->numallele[m]);
    zero(sumvar[m], map->numallele[m]);
  }

  // Pass 1: per-family deviations, their sum, and the summed variance.
  for (Uint ifam = 0; ifam < nfam(); ifam++) {
    DoubleMat ET = ad->getET(ifam);
    DoubleMat ET2 = ad->getET2(ifam);
    for (Uint m = 0; m < map->num; m++) {
      for (Uint a = 0; a < map->numallele[m]; a++) {
        // Evaluated for every family, including those skipped below.
        Double T = calcT(ifam, m, a + 1);
        Double mu = ET[m][a];
        Double famvar = ET2[m][a] - mu*mu;
        Double &famscore = fzscore[ifam][m][a];
        if (famvar < fabs(1e-10*mu)) {
          famscore = 0;
          continue;
        }
        sumvar[m][a] += famvar;
        famscore = T - mu;
        zscore[m][a] += famscore;
      }
    }
  }

  // Pass 2: scale summed and per-family deviations by sqrt(summed variance).
  for (Uint m = 0; m < map->num; m++) {
    for (Uint a = 0; a < map->numallele[m]; a++) {
      if (sumvar[m][a] > 0) {
        zscore[m][a] /= sqrt(sumvar[m][a]);
        for (Uint ifam = 0; ifam < nfam(); ifam++)
          fzscore[ifam][m][a] /= sqrt(sumvar[m][a]);
      }
    }
  }

  for (Uint m = 0; m < map->num; m++)
    DELETEVEC(sumvar[m]);
  DELETEVEC(sumvar);
}

void Assocmodel::output() {
  initialize(zscore_raw, fzscore_raw);
  initialize(zscore_family, fzscore_family);
  initialize(zscore_linkage, fzscore_linkage);
  getresults((Assocdist *)distribution, zscore_raw, fzscore_raw);
  getresults(distfamily, zscore_family, fzscore_family);
  getresults(distlinkage, zscore_linkage, fzscore_linkage);
  Output::output();
  cleanup(zscore_raw, fzscore_raw);
  cleanup(zscore_family, fzscore_family);
  cleanup(zscore_linkage, fzscore_linkage);
}

// Writes the table body, one row per allele of each selected marker.
// A marker is selected when pt == "spt" or map->shouldfindp[gam] is set.
// Each row contains, in order:
//   - the family id (only when famflag is set)
//   - the marker position
//   - the allele label: the 1-based index when options->datfile is
//     non-empty, otherwise map->origalleles
//   - the per-family (famline) or total (totline) columns
//   - the marker name
void Assocmodel::lines(ostream &f, Uint ifa, bool famflag) {
  const bool everymarker = (pt == "spt");
  for (Uint gam = 0; gam < map->num; gam++) {
    if (!everymarker && !map->shouldfindp[gam]) continue;
    for (Uint allele = 0; allele < map->numallele[gam]; allele++) {
      if (famflag) printfamid(f, getfamid(ifa));
      // Re-applied on every row; presumably the surrounding output calls
      // may change the stream's alignment.
      f.setf(ios::right, ios::adjustfield);
      fmtout(f, 7, 3, map->markerpos[0][gam], 0);
      if (options->datfile.size() > 0)
        fmtout(f, 8, allele + 1);
      else
        fmtout(f, 8, map->origalleles[gam][allele]);
      if (famflag)
        famline(f, ifa, gam, allele);
      else
        totline(f, gam, allele);
      f << map->markername[gam] << "\n";
    }
  }
}

// Writes the summary columns for one marker/allele, in this order:
//   - z_raw and the Fishersexact::test z value
//   - z_fam and its rescaled form z_fam * z_F / z_raw
//   - z_lin and its rescaled form z_lin * z_F / z_raw
//   - the four counts passed to Fishersexact::test
// NOTE(review): the rescaled columns divide by z_raw, which can be 0 (e.g.
// when getresults() skipped every family), giving inf/nan in the output -
// confirm whether a guard is wanted.
void Assocmodel::totline(ostream &f, Uint gam, Uint allele) {
  const Uint affA = aw->getaffA(gam, allele);
  const Uint affa = aw->getaffa(gam, allele);
  const Uint contA = aw->getcontA(gam, allele);
  const Uint conta = aw->getconta(gam, allele);
  const Double zFisher = Fishersexact::test(affA, affa, contA, conta);

  const Double zRaw = zscore_raw[gam][allele];
  const Double zFam = zscore_family[gam][allele];
  const Double zLin = zscore_linkage[gam][allele];

  fmtout(f, 9, 4, zRaw);
  fmtout(f, 9, 4, zFisher);
  fmtout(f, 9, 4, zFam);
  fmtout(f, 9, 4, zFam*zFisher/zRaw);
  fmtout(f, 9, 4, zLin);
  fmtout(f, 9, 4, zLin*zFisher/zRaw);

  const Uint counts[4] = {affA, affa, contA, conta};
  for (int i = 0; i < 4; i++)
    fmtout(f, 6, counts[i]);
  f << "  ";
}

// Writes the raw, family and linkage per-family scores of family `ifa` for
// marker `gam` and allele `allele`, followed by a two-space gap.
void Assocmodel::famline(ostream &f, Uint ifa, Uint gam, Uint allele) {
  const DoubleMat tables[3] = {
    fzscore_raw[ifa], fzscore_family[ifa], fzscore_linkage[ifa]
  };
  for (int k = 0; k < 3; k++)
    fmtout(f, 9, 4, tables[k][gam][allele]);
  f << "  ";
}

// Writes the column captions for the summary table rows produced by
// lines()/totline():
//   - the allele
//   - six z columns: raw, Fisher, family, rescaled family, linkage,
//     rescaled linkage
//   - the four allele counts
// The location and marker captions are presumably written by the caller -
// confirm.
void Assocmodel::totheader(ostream &f) {
  f << " allele"
       "   z_raw      z_F    z_fam    z*_fam   z_lnk    z*_lnk"
       " aff_A aff_a cnt_A cnt_a  ";
}

// Writes the column captions for the per-family rows produced by
// lines()/famline(): the allele and the raw, family and linkage scores.
// The family id, location and marker captions are presumably written by the
// caller - confirm.
void Assocmodel::famheader(ostream &f) {
  f << " allele"
       "   z_raw     z_fam    z_lnk"
       "  ";
}
