#include "vecutil.h"
#include "warning.h"
#include "files.h"
#include "options.h"

// Runs nbits in-place add/subtract butterfly passes over d.
// The first pass pairs neighbouring elements; each later pass doubles the
// distance between the two elements of a pair.  A pair (a, b) is replaced
// by (a + b, a - b).  No twiddle factors are involved.
// Every pass walks blocks of 2 * stride elements until the block start
// reaches N, so N is expected to be a multiple of 2^nbits.
void fftseg(FloatVec d, Uint N, Uint nbits) {
  IV stride = 1;
  for (Uint pass = 0; pass < nbits; pass++) {
    const IV blocklen = stride << 1;
    for (IV base = 0; base < N; base += blocklen) {
      const IV mid = base + stride;
      for (IV lo = base; lo < mid; lo++) {
        const Float a = d[lo];
        const Float b = d[lo + stride];
        d[lo] = a + b;
        d[lo + stride] = a - b;
      }
    }
    stride <<= 1;
  }
}

// Transforms s into d; d and s may be the same vector, in which case the
// transform is done in place without copying.
// Vectors shorter than NMAX are handed straight to fftseg.  Longer ones are
// split in two halves that are transformed recursively and then merged by a
// single butterfly pass at distance N / 2.
void fftold(FloatVec d, FloatVec s, Uint N, Uint nbits) {
  const Uint NMAX = 1024*4;
  if (d != s) copyvec(d, s, N);
  if (N < NMAX) {
    fftseg(d, N, nbits);
    return;
  }
  const Uint half = N >> 1;
  // The data already lives in d, so both recursive calls work in place.
  fftold(d, d, half, nbits - 1);
  fftold(d + half, d + half, half, nbits - 1);
  // Merge the two transformed halves.
  for (Uint j = 0; j < half; j++) {
    const Float a = d[j];
    const Float b = d[j + half];
    d[j] = a + b;
    d[j + half] = a - b;
  }
}

// Radix-2 passes over d for the bit range from .. to-1, in place.
// The first pass uses pair distance POW2[from] (presumably 2^from; the table
// is defined elsewhere) and every following pass doubles it.  Each pair
// (a, b) becomes (a + b, a - b).
inline void fftseg2(FloatVec d, Uint N, IV from, IV to) {
  IV stride = POW2[from];
  for (Uint bit = from; bit < to; bit++) {
    const IV blocklen = stride << 1;
    for (IV base = 0; base < N; base += blocklen) {
      const IV mid = base + stride;
      for (IV lo = base; lo < mid; lo++) {
        const Float a = d[lo];
        const Float b = d[lo + stride];
        d[lo] = a + b;
        d[lo + stride] = a - b;
      }
    }
    stride = blocklen;
  }
}

// Radix-4 passes over d for the bit range from .. to-1, in place.
// Each sweep over d applies two consecutive butterfly levels (distances k and
// 2k) to groups of four elements d[v], d[v+k], d[v+2k], d[v+3k], then
// multiplies k by 4.  The start distance is POW2[from] (presumably 2^from;
// the table is defined elsewhere).
// The loop counter advances by 2 per sweep, so to - from is expected to be a
// multiple of 2; with any other range the final sweep still applies two full
// levels.
// The temporaries no longer carry the `register` specifier: it is deprecated
// since C++11 and not accepted any more in C++17.
inline void fftseg4(FloatVec d, Uint N, IV from, IV to) {
  IV k = POW2[from];
  for (Uint i = from; i < to; i += 2) {
    const IV k2 = k << 1;
    const IV k3 = k2 + k;
    const IV nextk = k << 2;
    for (IV u = 0; u < N; u += nextk) {
      for (IV v = u; v < u + k; v++) {
        // Load the group of four.
        const Float x0 = d[v];
        const Float x1 = d[v + k];
        const Float x2 = d[v + k2];
        const Float x3 = d[v + k3];
        // Level 1: pairs at distance k.
        const Float sum01 = x0 + x1;
        const Float dif01 = x0 - x1;
        const Float sum23 = x2 + x3;
        const Float dif23 = x2 - x3;
        // Level 2: pairs at distance 2k, written straight back.
        d[v] = sum01 + sum23;
        d[v + k2] = sum01 - sum23;
        d[v + k] = dif01 + dif23;
        d[v + k3] = dif01 - dif23;
      }
    }
    k = nextk;
  }
}

// Radix-8 passes over d for the bit range from .. to-1, in place.
// Each sweep over d applies three consecutive butterfly levels (distances k,
// 2k and 4k) to groups of eight elements d[v + j*k], j = 0..7, then
// multiplies k by 8.  The start distance is POW2[from] (presumably 2^from;
// the table is defined elsewhere).
// The loop counter advances by 3 per sweep, so to - from is expected to be a
// multiple of 3; with any other range the final sweep still applies three
// full levels.
// The temporaries no longer carry the `register` specifier: it is deprecated
// since C++11 and not accepted any more in C++17.
inline void fftseg8(FloatVec d, Uint N, IV from, IV to) {
  IV k = POW2[from];
  for (Uint i = from; i < to; i += 3) {
    const IV k2 = k << 1;
    const IV k3 = k2 + k;
    const IV k4 = k3 + k;
    const IV k5 = k4 + k;
    const IV k6 = k5 + k;
    const IV k7 = k6 + k;
    const IV nextk = k << 3;
    for (IV u = 0; u < N; u += nextk) {
      for (IV v = u; v < u + k; v++) {
        // Load the group of eight.
        const Float x0 = d[v];
        const Float x1 = d[v + k];
        const Float x2 = d[v + k2];
        const Float x3 = d[v + k3];
        const Float x4 = d[v + k4];
        const Float x5 = d[v + k5];
        const Float x6 = d[v + k6];
        const Float x7 = d[v + k7];
        // Level 1: pairs at distance k.
        const Float a0 = x0 + x1;
        const Float a1 = x0 - x1;
        const Float a2 = x2 + x3;
        const Float a3 = x2 - x3;
        const Float a4 = x4 + x5;
        const Float a5 = x4 - x5;
        const Float a6 = x6 + x7;
        const Float a7 = x6 - x7;
        // Level 2: pairs at distance 2k.
        const Float b0 = a0 + a2;
        const Float b2 = a0 - a2;
        const Float b1 = a1 + a3;
        const Float b3 = a1 - a3;
        const Float b4 = a4 + a6;
        const Float b6 = a4 - a6;
        const Float b5 = a5 + a7;
        const Float b7 = a5 - a7;
        // Level 3: pairs at distance 4k, written straight back.
        d[v] = b0 + b4;
        d[v + k4] = b0 - b4;
        d[v + k] = b1 + b5;
        d[v + k5] = b1 - b5;
        d[v + k2] = b2 + b6;
        d[v + k6] = b2 - b6;
        d[v + k3] = b3 + b7;
        d[v + k7] = b3 - b7;
      }
    }
    k = nextk;
  }
}

// Radix-16 passes over d for the bit range from .. to-1, in place.
// Each sweep over d applies four consecutive butterfly levels (distances k,
// 2k, 4k and 8k) to groups of sixteen elements d[v + j*k], j = 0..15, then
// multiplies k by 16.  The start distance is POW2[from] (presumably 2^from;
// the table is defined elsewhere).
// The loop counter advances by 4 per sweep, so to - from is expected to be a
// multiple of 4; with any other range the final sweep still applies four
// full levels.
// The temporaries no longer carry the `register` specifier: it is deprecated
// since C++11 and not accepted any more in C++17.
inline void fftseg16(FloatVec d, Uint N, IV from, IV to) {
  IV k = POW2[from];
  for (Uint i = from; i < to; i += 4) {
    IV u = 0;
    // Offsets of the sixteen group members: kj == j * k.
    IV k2 = k << 1;
    IV k3 = k2 + k;
    IV k4 = k3 + k;
    IV k5 = k4 + k;
    IV k6 = k5 + k;
    IV k7 = k6 + k;
    IV k8 = k7 + k;
    IV k9 = k8 + k;
    IV k10 = k9 + k;
    IV k11 = k10 + k;
    IV k12 = k11 + k;
    IV k13 = k12 + k;
    IV k14 = k13 + k;
    IV k15 = k14 + k;
    IV nextk = k << 4;
    while (u < N) {
      for (IV v = u; v < u + k; v++) {
        // t1..t16 hold the group members (t(j+1) belongs to offset j * k);
        // t is scratch.
        Float t1, t2, t3, t4, t5, t6, t7, t8, t9,
          t10, t11, t12, t13, t14, t15, t16, t;

        // Level 1: load and combine pairs at distance k.
        t = d[v];
        t2 = d[v + k];
        t1 = t + t2;
        t2 = t - t2;
        t = d[v + k2];
        t4 = d[v + k3];
        t3 = t + t4;
        t4 = t - t4;
        t = d[v + k4];
        t6 = d[v + k5];
        t5 = t + t6;
        t6 = t - t6;
        t = d[v + k6];
        t8 = d[v + k7];
        t7 = t + t8;
        t8 = t - t8;
        t = d[v + k8];
        t10 = d[v + k9];
        t9 = t + t10;
        t10 = t - t10;
        t = d[v + k10];
        t12 = d[v + k11];
        t11 = t + t12;
        t12 = t - t12;
        t = d[v + k12];
        t14 = d[v + k13];
        t13 = t + t14;
        t14 = t - t14;
        t = d[v + k14];
        t16 = d[v + k15];
        t15 = t + t16;
        t16 = t - t16;

        // Level 2: pairs at distance 2k.
        t = t1;
        t1 += t3;
        t3 = t - t3;
        t = t2;
        t2 += t4;
        t4 = t - t4;
        t = t5;
        t5 += t7;
        t7 = t - t7;
        t = t6;
        t6 += t8;
        t8 = t - t8;
        t = t9;
        t9 += t11;
        t11 = t - t11;
        t = t10;
        t10 += t12;
        t12 = t - t12;
        t = t13;
        t13 += t15;
        t15 = t - t15;
        t = t14;
        t14 += t16;
        t16 = t - t16;

        // Level 3: pairs at distance 4k.
        t = t1;
        t1 += t5;
        t5 = t - t5;
        t = t2;
        t2 += t6;
        t6 = t - t6;
        t = t3;
        t3 += t7;
        t7 = t - t7;
        t = t4;
        t4 += t8;
        t8 = t - t8;
        t = t9;
        t9 += t13;
        t13 = t - t13;
        t = t10;
        t10 += t14;
        t14 = t - t14;
        t = t11;
        t11 += t15;
        t15 = t - t15;
        t = t12;
        t12 += t16;
        t16 = t - t16;

        // Level 4: pairs at distance 8k, written straight back.
        d[v]       = t1 + t9;
        d[v + k8]  = t1 - t9;
        d[v + k]   = t2 + t10;
        d[v + k9]  = t2 - t10;
        d[v + k2]  = t3 + t11;
        d[v + k10] = t3 - t11;
        d[v + k3]  = t4 + t12;
        d[v + k11] = t4 - t12;
        d[v + k4]  = t5 + t13;
        d[v + k12] = t5 - t13;
        d[v + k5]  = t6 + t14;
        d[v + k13] = t6 - t14;
        d[v + k6]  = t7 + t15;
        d[v + k14] = t7 - t15;
        d[v + k7]  = t8 + t16;
        d[v + k15] = t8 - t16;
      }
      u += nextk;
    }
    k = nextk;
  }
}

// Pointer type matching the signature shared by fftseg2, fftseg4, fftseg8
// and fftseg16.
typedef void(*fftsegfun)(FloatVec , Uint, IV, IV);

// Pass schedule written by createorder() and read by fft(): each entry is the
// number of bits handed to one int2fftseg() call.
int order[32];
// Number of valid entries in order[].
Uint norder;
// The nbits value for which order[] was last built; fft() rebuilds the
// schedule whenever it is called with a different nbits.
// NOTE(review): MAXBITS here is a file-scope name declared outside the visible
// code (fft() declares its own local MAXBITS).  The initial value only works
// as "no schedule built yet" if fft() is never first called with
// nbits == MAXBITS + 1 - confirm.
Uint curbits = MAXBITS + 1;

// Dispatches to the pass routine that handles i bits per sweep:
// 1 -> fftseg2, 2 -> fftseg4, 3 -> fftseg8, 4 -> fftseg16.
// Any other i is reported through fatal().
void int2fftseg(int i, FloatVec d, Uint N, IV from, IV to) {
  if (i == 1)
    fftseg2(d, N, from, to);
  else if (i == 2)
    fftseg4(d, N, from, to);
  else if (i == 3)
    fftseg8(d, N, from, to);
  else if (i == 4)
    fftseg16(d, N, from, to);
  else
    fatal(string("Illegal number of bits to unloop (") + i + ")");
}

// Builds the pass schedule for nbits bits in the globals order[] / norder.
// Each entry is the number of bits one int2fftseg() call handles; the entries
// always sum to nbits.
//   - nbits == 0: empty schedule (no pass is needed).  Previously this stored
//     a 0 entry, which int2fftseg() rejects.
//   - nbits <= maxunloop: a single entry of nbits.
//   - otherwise: entries of maxunloop are peeled off while more than
//     2 * maxunloop bits remain, and the remainder is split into two parts
//     that differ by at most one bit (larger part first).  For maxunloop in
//     1..4 this gives the same schedule as the former explicit case list.
// options->maxunloop == 0 is rejected: the peeling loop could never make
// progress and would run past the end of order[].
void createorder(Uint nbits) {
  const Uint unloop = options->maxunloop;
  norder = 0;
  if (nbits == 0) return;
  if (unloop == 0) {
    assertinternal(false);
    return;
  }
  if (nbits <= unloop) {
    order[norder++] = nbits;
    return;
  }
  while (nbits > 2*unloop) {
    order[norder++] = unloop;
    nbits -= unloop;
  }
  // Here unloop < nbits <= 2*unloop: split what is left as evenly as possible.
  order[norder++] = (nbits + 1) / 2;
  order[norder++] = nbits / 2;
}

// Transforms s into d; when d and s are the same vector the transform is done
// in place without copying.
// The pass schedule in order[] is rebuilt only when nbits differs from the
// value used on the previous call (curbits).
// For nbits <= MAXBITS the schedule covers all bits.  For larger nbits the
// lowest MAXBITS bits are handled first, one NMAX-element block at a time,
// and the schedule covers the remaining nbits - MAXBITS bits over the whole
// vector.
void fft(FloatVec d, FloatVec s, Uint N, Uint nbits) {
  const Uint NMAX = 1024*4;
  const Uint MAXBITS = 12;
  const bool blockwise = MAXBITS < nbits;

  if (nbits != curbits) {
    curbits = nbits;
    createorder(blockwise ? nbits - MAXBITS : nbits);
  }
  if (d != s) copyvec(d, s, N);

  Uint lastbit = 0;
  if (blockwise) {
    // Lowest MAXBITS bits, block by block.
    for (Uint offset = 0; offset < N; offset += NMAX)
      int2fftseg(options->maxunloop, d + offset, NMAX, 0, MAXBITS);
    lastbit = MAXBITS;
  }
  // Scheduled passes over the whole vector.
  for (Uint step = 0; step < norder; step++) {
    const int width = order[step];
    int2fftseg(width, d, N, lastbit, lastbit + width);
    lastbit += width;
  }
  assertinternal(lastbit == nbits);
}
