#include <fstream>
#include "convolve.h"
#include "vecutil.h"
#include "utils.h"

#define TOL 1.0e-12

// Exchanges the values of a and b through a temporary copy; used by
// Convolver::convolve to flip the current and next work arrays.
void swap(DoubleVec &a, DoubleVec &b) {
  DoubleVec held(b);
  b = a;
  a = held;
}

/**
 * Allocates the four work arrays (each created with NEWVEC using nbins):
 * bin/binprob hold the current support and its probabilities, and
 * binnew/binprobnew receive the next step computed by convolve().
 *
 * @param nb  number of grid points (nbins)
 * @param nf  number of distributions convolve() combines (nfam)
 */
Convolver::Convolver(Uint nb, Uint nf) : nbins(nb), nfam(nf) {
  NEWVEC(Double, bin, nbins);
  NEWVEC(Double, binprob, nbins);
  NEWVEC(Double, binnew, nbins);
  NEWVEC(Double, binprobnew, nbins);
  // Give the bookkeeping members defined values so nothing reads
  // indeterminate state before convolve() has run.  With dobinning false and
  // nbinsused 0, a premature calcpvalue() fails its assertinternal check
  // instead of indexing with garbage; convolve() assigns these before use.
  nbinsused = 0;
  dobinning = false;
  binmin = 0.0;
  binwidth = 0.0;
}

// Releases the work arrays created with NEWVEC in the constructor, in the
// reverse order of their allocation.
Convolver::~Convolver() {
  DELETEVEC(binprobnew);
  DELETEVEC(binnew);
  DELETEVEC(binprob);
  DELETEVEC(bin);
}

/**
 * Convolves nfam discrete score distributions, treating them as independent
 * (probabilities of combined outcomes are multiplied).
 *
 * Distribution i has m[i] outcomes x[i][k] with probabilities p[i][k].  The
 * distribution of the running sum is kept in bin (support points) and
 * binprob (masses).  binnew/binprobnew receive the next step and are swapped
 * in after each distribution.
 *
 * In binned mode the support is a grid of nbins equally spaced points
 * starting at binmin.  After every step the grid's range is extended by
 * adding the new distribution's minimum and maximum outcome.
 *
 * On return, for j < nbinsused, binprob[j] = 1 - sum_{l<j} (mass of bin l),
 * i.e. the mass at bin j and above.
 *
 * NOTE(review): dobinning is set to true here and nothing in this function
 * sets it back to false, so the exact (add()-based) branch is currently
 * unreachable.
 */
void Convolver::convolve(DoubleMat x, DoubleMat p, UintVec m) {
  zero(binprob, nbins);
  zero(binprobnew, nbins);
  nbinsused = 0;
  dobinning = true;

  // Always assigned on the first iteration (i == 0 rebuilds the grid), so
  // start from a defined value instead of copying the member binmin, which
  // has not been assigned yet on the first call of a fresh Convolver.
  Double binmax = 0.0;
  for (Uint i = 0; i < nfam; i++) {
    // Bin when already binning, or when exact enumeration of this step would
    // produce more than 50000 sums.
    if (dobinning || m[i]*max_(1, nbinsused) > 50000) {
      Uint n;  // number of entries of the current distribution to convolve
      if (!dobinning || i == 0) {
        // Fresh grid: its range spans the current support plus distribution i.
        dobinning = true;
        binmin = vecmin<Double>(bin, nbinsused) + vecmin<Double>(x[i], m[i]);
        binmax = vecmax<Double>(bin, nbinsused) + vecmax<Double>(x[i], m[i]);
        n = nbinsused;
      }
      else {
        // Existing grid: extend its range by distribution i's extremes.
        binmin += vecmin<Double>(x[i], m[i]);
        binmax += vecmax<Double>(x[i], m[i]);
        n = nbins;
      }
      nbinsused = nbins;
      binwidth = max_(TOL, (binmax - binmin)/(nbins - 1));
      for (Uint j = 0; j < nbins; j++)
        binnew[j] = j*binwidth + binmin;
      zero(binprobnew, nbins);
      if (n == 0) {
        // Nothing accumulated yet: the result is distribution i itself.
        for (Uint k = 0; k < m[i]; k++)
          binprobnew[getbinidx(x[i][k])] += p[i][k];
      }
      else {
        // Combine every point carrying more than TOL mass with each outcome
        // of distribution i, assigning the sum to its grid bin.
        for (Uint j = 0; j < n; j++)
          if (binprob[j] > TOL)
            for (Uint k = 0; k < m[i]; k++)
              binprobnew[getbinidx(x[i][k] + bin[j])] += p[i][k]*binprob[j];
      }
    }
    else {
      // Exact enumeration: merge every sum into the sorted support via add().
      Uint nusednew = 0;
      if (nbinsused == 0)
        for (Uint k = 0; k < m[i]; k++)
          add(x[i][k], p[i][k], nusednew);
      else
        for (Uint j = 0; j < nbinsused; j++)
          for (Uint k = 0; k < m[i]; k++)
            add(bin[j] + x[i][k], binprob[j]*p[i][k], nusednew);
      nbinsused = nusednew;
    }
    // The freshly computed arrays become the current distribution.
    swap(bin, binnew);
    swap(binprob, binprobnew);
  }
  // Convert per-bin masses into upper-tail probabilities.
  Double cumsum = 0.0;
  for (Uint i = 0; i < nbinsused; i++) {
    Double tmp = binprob[i];
    binprob[i] = 1.0 - cumsum;
    cumsum += tmp;
  }
}

/**
 * Adds outcome x with probability p to the sorted exact support held in
 * binnew/binprobnew[0, binsused).
 *
 * If an existing point lies within 0.001 of x, p is added to its mass.
 * Otherwise x is inserted in sorted position and binsused is incremented.
 * Only valid when not binning; a binned call trips assertinternal.
 *
 * @param x         support value
 * @param p         probability mass for x
 * @param binsused  number of occupied entries; incremented on insertion
 */
void Convolver::add(Double x, Double p, Uint &binsused) {
  if (dobinning) {
    assertinternal(false);
    return;
  }
  // First slot that x does not lie more than 0.001 beyond.
  Uint j = 0;
  while (j < binsused && x >= binnew[j] + .001) j++;
  if (j < binsused && fabs(binnew[j] - x) < 0.001) {
    binprobnew[j] += p;
    return;
  }
  // Inserting shifts entries up to index binsused, which must still lie
  // inside the nbins-long arrays; refuse rather than write past the end.
  if (binsused >= nbins) {
    assertinternal(false);
    return;
  }
  for (Uint i = binsused; i > j; i--) {
    binnew[i] = binnew[i - 1];
    binprobnew[i] = binprobnew[i - 1];
  }
  binsused++;
  binnew[j] = x;
  binprobnew[j] = p;
}

/**
 * Returns the upper-tail probability for an observed total x, using the
 * distribution computed by convolve().
 *
 * Binned mode:
 *  - x is linearly interpolated between the two grid points bracketing it.
 *  - x below the first grid point gives 1.0.
 *  - x whose bin index lies past the grid gives 0.0.
 *  - x beyond the last grid point but still in its bin gives that point's
 *    value.
 *
 * Exact mode: returns the entry for the first support point
 * bin[i] >= x - 0.001.
 */
Double Convolver::calcpvalue(Double x) const {
  if (dobinning) {
    int idx = getbinidx(x);
    // Check the lower bound first and compare signed-to-signed: comparing a
    // negative int with the unsigned nbins converts it to a huge value, which
    // would report 0.0 for an x that lies below the whole grid.
    if (idx < 0) return 1.0;
    if (idx >= static_cast<int>(nbins)) return 0.0;
    int a, b;
    if (x < bin[idx]) {
      a = idx - 1;
      b = idx;
    }
    else {
      a = idx;
      b = idx + 1;
    }
    if (a < 0) return 1.0;
    if (b > static_cast<int>(nbins) - 1) return binprob[idx];
    return binprob[a] + (x - bin[a])/(bin[b] - bin[a])*
      (binprob[b] - binprob[a]);
  }
  else {
    Uint i;
    for (i = 0; i < nbinsused && x > bin[i] + 0.001; i++);
    assertinternal(i < nbinsused);
    return binprob[i];
  }
  assertinternal(false);
}

void Convolver::write(const string &filename) const {
  ofstream f(filename.c_str());
  assertcond(f, "Unable to open file " + filename +
             " for writing exact p-values");
  for (Uint i = 0; i < nbins; i++) f << bin[i] << "\t" << binprob[i] << "\n";
}
