package me.koba1.led.audio;

public final class Fft {

    public static final class FftResult {
        public final float[] freqsHz; // n_fft/2+1
        public final float[] amps;    // abs(spec)
        public FftResult(float[] f, float[] a) { this.freqsHz = f; this.amps = a; }
    }

    public static FftResult fftHzAmpFromBuffer(float[][] stereoBuffer, int sampleRate, int nFft) {
        if (stereoBuffer.length == 0 || stereoBuffer[0].length < 1) {
            throw new IllegalArgumentException("stereoBuffer shape legyen (samples, channels)");
        }
        int channels = stereoBuffer[0].length;

        // last nFft frames (or zero-pad in front)
        float[] mono = new float[nFft];
        int n = stereoBuffer.length;
        int start = Math.max(0, n - nFft);
        int pad = nFft - (n - start);

        int outIdx = 0;
        for (int i = 0; i < pad; i++) mono[outIdx++] = 0.0f;

        for (int i = start; i < n; i++) {
            float sum = 0.0f;
            for (int c = 0; c < channels; c++) sum += stereoBuffer[i][c];
            mono[outIdx++] = sum / channels;
        }

        // DC remove
        float mean = 0f;
        for (float v : mono) mean += v;
        mean /= nFft;
        for (int i = 0; i < nFft; i++) mono[i] -= mean;

        // Hann window
        float[] w = hann(nFft);
        for (int i = 0; i < nFft; i++) mono[i] *= w[i];

        // FFT
        Complex[] spec = FFT.fftReal(mono);

        int bins = nFft / 2 + 1;
        float[] amps = new float[bins];
        for (int k = 0; k < bins; k++) {
            double re = spec[k].re;
            double im = spec[k].im;
            amps[k] = (float)Math.hypot(re, im);
        }

        // freq axis
        float[] freqs = new float[bins];
        for (int k = 0; k < bins; k++) {
            freqs[k] = (float)(k * (sampleRate / (double)nFft));
        }

        return new FftResult(freqs, amps);
    }

    private static float[] hann(int n) {
        float[] w = new float[n];
        if (n == 1) { w[0] = 1f; return w; }
        for (int i = 0; i < n; i++) {
            w[i] = (float)(0.5 - 0.5 * Math.cos(2.0 * Math.PI * i / (n - 1)));
        }
        return w;
    }

    public static final class Complex {
        public final double re, im;
        public Complex(double re, double im) { this.re = re; this.im = im; }
        public Complex plus(Complex b) { return new Complex(this.re + b.re, this.im + b.im); }
        public Complex minus(Complex b) { return new Complex(this.re - b.re, this.im - b.im); }
        public Complex times(Complex b) {
            return new Complex(this.re * b.re - this.im * b.im, this.re * b.im + this.im * b.re);
        }
    }

    public static final class FFT {
        public static Complex[] fft(Complex[] x) {
            int n = x.length;
            if ((n & (n - 1)) != 0) throw new IllegalArgumentException("n must be power of 2");
            if (n == 1) return new Complex[]{ x[0] };

            Complex[] even = new Complex[n / 2];
            Complex[] odd  = new Complex[n / 2];
            for (int k = 0; k < n / 2; k++) {
                even[k] = x[2 * k];
                odd[k]  = x[2 * k + 1];
            }

            Complex[] q = fft(even);
            Complex[] r = fft(odd);

            Complex[] y = new Complex[n];
            for (int k = 0; k < n / 2; k++) {
                double kth = -2.0 * k * Math.PI / n;
                Complex wk = new Complex(Math.cos(kth), Math.sin(kth));
                Complex t = wk.times(r[k]);
                y[k] = q[k].plus(t);
                y[k + n / 2] = q[k].minus(t);
            }
            return y;
        }

        public static Complex[] fftReal(float[] real) {
            int n = real.length;
            Complex[] x = new Complex[n];
            for (int i = 0; i < n; i++) x[i] = new Complex(real[i], 0.0);
            return fft(x);
        }
    }
}
