#include "findnormal.h"
#include "warning.h"
#include "vecutil.h"
#include "kinship.h"
#include "family.h"

// Destructor: frees the trait-value vector and, through cleanup(), the
// per-family work buffers (ibd, E, L, y).
FindNormalLoglik::~FindNormalLoglik() {
  DELETEVEC(traitvalues);
  cleanup();
}

// Recursively enumerates the inheritance vectors over the persons in `pl`
// and, once every person has been processed (pl == 0), stores the
// log-likelihood of the family's trait values for that vector in S[v].
//
// The covariance matrix of the phenotyped persons is Cholesky-factorised
// incrementally (E = L L^T, forward-solved L y = x): each phenotyped
// person appends one row to the shared buffer L and one entry to y. Work
// done for a prefix of the person list is therefore reused by every
// completion of that prefix. Rows/entries at index >= pi are scratch and
// are overwritten by sibling branches of the recursion.
//
// fam      family being analysed; fam->first defines the person order.
// S        output array indexed by inheritance vector.
// pl       persons still to be processed; 0 terminates the recursion.
// v        inheritance-vector bits chosen so far.
// pi       number of phenotyped persons already added to L / y.
// log_det  sum of log pivots so far, i.e. log det of the leading
//          pi x pi block of the covariance matrix.
// xEx      running y^T y, i.e. x^T E^{-1} x for that block.
// sigma2, shared, sigma2_g, sigma2_d
//          variance components: each diagonal entry is their sum; an
//          off-diagonal entry is kinship*shared + oibd*sigma2_g/2, plus
//          sigma2_d when oibd == 2.
//
// NOTE(review): there is no guard for a non-positive pivot (covariance
// not positive definite); log/sqrt below would then produce NaN.
void FindNormalLoglik::calcloglik(Family *fam, DoubleVec S, Plist *pl, IV v,
                                  Uint pi, Double log_det, Double xEx,
                                  Double sigma2, Double shared,
                                  Double sigma2_g, Double sigma2_d) {
  if (pl == 0) {
    // Log of the multivariate normal density without the constant
    // -n/2*log(2*pi) term; the full density is kept below for reference.
//    S[v] = pow(2*M_PI, -Double(fam->nqtl)/2.0)*exp(-.5*xEx)/sqrt(det);
    S[v] = -.5*(xEx + log_det);
  }
  else {
    Person *p = pl->p;
    // A meiosis bit is only enumerated when its mask is non-zero.
    IV Kf = p->patmask ? 1 : 0;
    IV Km = p->matmask ? 1 : 0;
    
    for (IV K1 = 0; K1 <= Km; K1++) {
      if (K1) v += p->matmask;
      // nod[1] is inherited from the mother's nod[K1].
      if (p->mother != 0) p->nod[1] = p->mother->nod[K1];
      for (IV K0 = 0; K0 <= Kf; K0++) {
        if (K0) v += p->patmask;
        // nod[0] is inherited from the father's nod[K0].
        if (p->father != 0) p->nod[0] = p->father->nod[K0];
        // Per-branch copies so sibling branches start from the same state.
        Uint newpi = pi;
        Double newdet = log_det;
        Double newxEx = xEx;
        if (p->hastraitvalue()) {
          newpi++;

          // Forward substitution: y[pi] = (x[pi] - sum_q L[pi][q]*y[q]) / L[pi][pi].
          y[pi] = traitvalues[pi];
          Uint qi = 0;
          // Walk the earlier phenotyped persons in family order; they are
          // exactly rows 0..pi-1 of L.
          for (Person *q = fam->first; qi < pi; q = q->next)
            if (q->hastraitvalue()) {
              // oibd: number of nod values p and q have in common
              // (0, 1, or 2 when both nods match in either orientation;
              // && binds tighter than || below).
              Uint oibd = 1;
              if (p->nod[0] != q->nod[0] && p->nod[0] != q->nod[1] &&
                  p->nod[1] != q->nod[0] && p->nod[1] != q->nod[1])
                oibd = 0;
              else if (p->nod[0] == q->nod[0] && p->nod[1] == q->nod[1] ||
                       p->nod[0] == q->nod[1] && p->nod[1] == q->nod[0])
                oibd = 2;
              
              // Covariance of p and q, then the Cholesky off-diagonal update.
              L[pi][qi] = fam->kinship->getkinship(p->nmrk, q->nmrk)*shared +
                oibd*sigma2_g/2.0 + (oibd == 2 ? sigma2_d : 0);
              for (Uint k = 0; k < qi; k++) L[pi][qi] -= L[pi][k]*L[qi][k];
              L[pi][qi] /= L[qi][qi];

              y[pi] -= L[pi][qi]*y[qi];
              
              qi++;
            }

          // Diagonal pivot. Its log is accumulated before taking the square
          // root, so newdet sums log(L[pi][pi]^2) = log det contributions.
          L[pi][pi] = sigma2 + shared + sigma2_g + sigma2_d;
          for (Uint k = 0; k < pi; k++) L[pi][pi] -= L[pi][k]*L[pi][k];
          newdet += log(L[pi][pi]);
          L[pi][pi] = sqrt(L[pi][pi]);
          y[pi] /= L[pi][pi];
          newxEx += y[pi]*y[pi];
          
        }
        calcloglik(fam, S, pl->next, v, newpi, newdet, newxEx, sigma2, shared,
                   sigma2_g, sigma2_d);
      }
      // Clear the paternal bit before choosing the next maternal bit.
      v &= ~p->patmask;
    }
  }
}

// Returns true if any descendant of p (child, grandchild, ...) has a trait
// value. Children are checked in list order, each child before its own
// descendants, and the search stops at the first phenotyped one found.
bool FindNormalLoglik::hasphenotypeddesc(Person *p) const {
  for (Plist *kid = p->children; kid != 0; kid = kid->next) {
    if (kid->p->hastraitvalue() || hasphenotypeddesc(kid->p))
      return true;
  }
  return false;
}

// Copies the trait values of the phenotyped members of `fam` into a freshly
// allocated array `traitvalues` of length fam->nqtl. Entry i belongs to the
// i-th person, in fam->first list order, for which hastraitvalue() is true.
// calcloglik() and calc() index traitvalues in that same order.
void FindNormalLoglik::extracttraitvalues(Family *fam) {
  DELETEVEC(traitvalues);
  traitvalues = new Double[fam->nqtl];
  Uint pi = 0;
  for (Person *p = fam->first; p != 0; p = p->next)
    if (p->hastraitvalue()) {
      // Check the bound before writing, so that a family whose nqtl
      // undercounts its phenotyped members is reported instead of
      // overrunning the buffer first.
      assertinternal(pi < fam->nqtl);
      traitvalues[pi] = p->traitvalue;
      pi++;
    }
}

// Prepares this object for a computation on `fam`.
//
// Calls fam->calckinship(), then reallocates the work buffers ibd, E, L and
// y whenever fam->nqtl exceeds the size (curnqtl) they were last allocated
// for. Buffers are only ever grown, never shrunk, so a smaller family reuses
// the existing ones. Finally reloads `traitvalues` for this family.
void FindNormalLoglik::initialize(Family *fam) {
  fam->calckinship();
  const bool buffers_too_small = fam->nqtl > curnqtl;
  if (buffers_too_small) {
    cleanup();
    curnqtl = fam->nqtl;
    ibd = newsymmetricmatrix<Uint>(fam->nqtl);
    E = newsymmetricmatrix<Double>(fam->nqtl);
    L = newsymmetricmatrix<Double>(fam->nqtl);
    NEWVEC(Double, y, fam->nqtl);
  }
  extracttraitvalues(fam);
}

// Releases the work buffers allocated by initialize(): the matrices ibd, E
// and L (from newsymmetricmatrix) and the vector y (from NEWVEC).
// The matrix pointers are reset to 0 afterwards, so calling cleanup() again
// before the next allocation cannot free them a second time.
void FindNormalLoglik::cleanup() {
  // deletematrix is the single matching release for newsymmetricmatrix.
  // L must not additionally go through DELETEVEC: that would delete[] the
  // same allocation twice.
  deletematrix(ibd);
  ibd = 0;
  deletematrix(E);
  E = 0;
  deletematrix(L);
  L = 0;
  DELETEVEC(y);
}

// Computes the log-likelihood of the family's trait values for every
// inheritance vector and writes it into S.
//
// Only persons that have a trait value, or that have children and a
// phenotyped descendant, are enumerated by calcloglik(). The meiosis bits
// (patmask | matmask) of everyone else are collected in
// `uninformativemask`, and expandiv() then fills in S for all fam->numiv
// inheritance vectors.
void FindNormalLoglik::calc(Family *fam, DoubleVec S, Double sigma2,
                            Double shared, Double sigma2_g, Double sigma2_d) {
  initialize(fam);

  // Build the list of informative persons in family order, appending
  // through a pointer to the current tail link.
  Plist *informative = 0;
  Plist **tail = &informative;
  Uint uninformativemask = 0;
  for (Person *p = fam->first; p != 0; p = p->next) {
    const bool keep = (p->children != 0 && hasphenotypeddesc(p)) ||
                      p->hastraitvalue();
    if (!keep) {
      uninformativemask |= p->patmask | p->matmask;
      continue;
    }
    *tail = new Plist(p, 0);
    tail = &(*tail)->next;
  }

  // Start the recursion with no bits set, no phenotyped persons factorised,
  // zero log-determinant and zero quadratic form.
  calcloglik(fam, S, informative, 0, 0, 0, 0.0, sigma2, shared,
             sigma2_g, sigma2_d);
  // NOTE(review): assumes ~Plist releases the rest of the chain — confirm.
  delete informative;

  expandiv(S, uninformativemask, fam->numiv);
}

// Returns -0.5 * (log det E + x^T E^{-1} x) for the family's trait vector x,
// i.e. the normal log-density without its constant term. E is built from
// the supplied per-pair sharing values rather than from one inheritance
// vector.
//
// ibd1[k], ibd2[k] describe the k-th pair (p, q) of phenotyped persons, with
// pairs enumerated in family order as p runs over rows and q over the
// persons before p. Presumably they are the probabilities of sharing one /
// two alleles IBD — confirm with callers.
// Covariance model: diagonal = sigma2 + shared + sigma2_g + sigma2_d;
// off-diagonal = kinship*shared + (ibd1/2 + ibd2)*sigma2_g + ibd2*sigma2_d.
Double FindNormalLoglik::calc(Family *fam, DoubleVec ibd1, DoubleVec ibd2,
                              Double sigma2, Double shared, Double sigma2_g,
                              Double sigma2_d) {
  initialize(fam);
  const Double variance = sigma2 + shared + sigma2_g + sigma2_d;

  // Fill the lower triangle of the covariance matrix E.
  Uint row = 0;
  Uint pair = 0;
  for (Person *p = fam->first; p != 0; p = p->next) {
    if (!p->hastraitvalue()) continue;
    Uint col = 0;
    for (Person *q = fam->first; col < row; q = q->next) {
      assertinternal(p != q);
      if (!q->hastraitvalue()) continue;
      E[row][col] = (fam->kinship->getkinship(p->nmrk, q->nmrk)*shared +
                     (ibd1[pair]/2.0 + ibd2[pair])*sigma2_g +
                     ibd2[pair]*sigma2_d);
      col++;
      pair++;
    }
    E[row][row] = variance;
    row++;
  }

  // Cholesky factorisation E = L L^T row by row, with forward substitution
  // L y = x alongside. log det E is the sum of the logs of the squared
  // diagonal pivots, and x^T E^{-1} x = y^T y.
  Double logdet = 0.0;
  Double quadform = 0.0;
  for (Uint j = 0; j < fam->nqtl; j++) {
    for (Uint i = 0; i < j; i++) {
      L[j][i] = E[j][i];
      for (Uint k = 0; k < i; k++)
        L[j][i] -= L[j][k]*L[i][k];
      L[j][i] /= L[i][i];
    }
    Double &pivot = L[j][j];
    Double &yj = y[j];
    pivot = E[j][j];
    yj = traitvalues[j];
    for (Uint k = 0; k < j; k++) {
      pivot -= L[j][k]*L[j][k];
      yj -= L[j][k]*y[k];
    }
    logdet += log(pivot);   // log of the squared diagonal entry
    pivot = sqrt(pivot);
    yj /= pivot;
    quadform += yj*yj;
  }
  return -.5*(logdet + quadform);
}
