package me.koba1.led.audio;

import java.time.Instant;
import java.util.*;
import java.util.concurrent.locks.ReentrantLock;

public final class AudioProcessor {

    /** Capture sample rate in Hz, shared by the ring buffer, the capture config and the FFT. */
    private static final int SAMPLE_RATE = 48000;

    /** Number of capture channels requested from the capture backend. */
    private static final int CHANNELS = 2;

    /** Analysis window length in frames (FFT size). */
    private static final int N_FFT = 2048;

    /** Pause between analysis passes: roughly 30 updates per second. */
    private static final long FRAME_INTERVAL_MS = (long) (1000.0 / 30.0);

    // -------------------- Shared state (thread-safe) --------------------

    /**
     * Result of one analysis pass.
     *
     * <p>The single shared instance is private and is only touched while holding
     * {@code AUDIO_LOCK}; callers receive independent copies from {@link #getAudioStateCopy()}.
     */
    public static final class AudioState {
        /** Wall-clock time of the update, in epoch seconds (millisecond resolution). */
        public double ts;
        /** RMS of the mono downmix over the analysis window. */
        public float rms;
        /** Largest absolute sample of the mono downmix over the analysis window. */
        public float peak;
        /** Normalized band levels in [0, 1], keyed by band name, in band-definition order. */
        public final Map<String, Float> bands = new LinkedHashMap<>();
    }

    private static final AudioState AUDIO_STATE = new AudioState();
    private static final ReentrantLock AUDIO_LOCK = new ReentrantLock();

    /**
     * Returns a consistent copy of the most recently published analysis result.
     *
     * <p>Before the audio thread has published anything, all numeric fields are zero and
     * {@code bands} is empty.
     *
     * @return a fresh {@link AudioState} that the caller owns and may mutate freely
     */
    public static AudioState getAudioStateCopy() {
        AUDIO_LOCK.lock();
        try {
            AudioState out = new AudioState();
            out.ts = AUDIO_STATE.ts;
            out.rms = AUDIO_STATE.rms;
            out.peak = AUDIO_STATE.peak;
            out.bands.putAll(AUDIO_STATE.bands);
            return out;
        } finally {
            AUDIO_LOCK.unlock();
        }
    }

    // -------------------- Bands --------------------

    /** Band name to {@code {lowHz, highHz}}; a bin belongs to a band when {@code lowHz <= f < highHz}. */
    private static final LinkedHashMap<String, int[]> BAND_RANGES = new LinkedHashMap<>();
    static {
        BAND_RANGES.put("sub_bass",  new int[]{20, 60});
        BAND_RANGES.put("bass",      new int[]{60, 120});
        BAND_RANGES.put("low_mid",   new int[]{120, 250});
        BAND_RANGES.put("mid",       new int[]{250, 500});
        BAND_RANGES.put("upper_mid", new int[]{500, 2000});
        BAND_RANGES.put("presence",  new int[]{2000, 4000});
        BAND_RANGES.put("treble",    new int[]{4000, 8000});
        BAND_RANGES.put("air",       new int[]{8000, 16000});
    }

    /** Tuning knobs for {@link #bandsToDict}. */
    public static final class BandCfg {
        /** Smoothing coefficient used when the band level rises. */
        public float attack = 0.55f;
        /** Smoothing coefficient used when the band level falls. */
        public float release = 0.18f;
        /** Per-call multiplicative decay applied to the running reference level. */
        public float refDecay = 0.995f;
        /** Amount subtracted from each band's RMS before smoothing (clamped at 0). */
        public float noiseFloor = 1e-4f;
        /** Added to the reference before dividing, to avoid division by zero. */
        public float eps = 1e-9f;
    }

    /**
     * Per-band state carried between {@link #bandsToDict} calls.
     *
     * <p>Not thread-safe; confine an instance to one thread.
     */
    public static final class BandState {
        public final BandCfg cfg = new BandCfg();
        /** Smoothed level per band, initialised to 0. */
        public final Map<String, Float> smooth = new HashMap<>();
        /** Running (decaying) reference level per band, initialised to 1e-3. */
        public final Map<String, Float> ref = new HashMap<>();

        public BandState() {
            for (String k : BAND_RANGES.keySet()) {
                smooth.put(k, 0.0f);
                ref.put(k, 1e-3f);
            }
        }
    }

    /**
     * Reduces a spectrum to smoothed, auto-normalized per-band levels.
     *
     * <p>For each band this does the following:
     * <ol>
     *   <li>computes the RMS of the finite amplitudes whose frequency lies in the band's
     *       {@code [lowHz, highHz)} range;</li>
     *   <li>subtracts {@code cfg.noiseFloor};</li>
     *   <li>applies an attack/release one-pole smoother;</li>
     *   <li>divides by a slowly decaying running maximum, so the output adapts to overall
     *       loudness.</li>
     * </ol>
     *
     * @param freqsHz bin frequencies in Hz
     * @param amps    bin amplitudes, parallel to {@code freqsHz}; only the first
     *                {@code min(freqsHz.length, amps.length)} bins are considered
     * @param state   smoother/reference state, updated in place
     * @return band name to level clamped to [0, 1], in band-definition order
     * @throws NullPointerException if any argument is null
     */
    public static Map<String, Float> bandsToDict(float[] freqsHz, float[] amps, BandState state) {
        Objects.requireNonNull(freqsHz, "freqsHz");
        Objects.requireNonNull(amps, "amps");
        Objects.requireNonNull(state, "state");

        BandCfg cfg = state.cfg;
        // Tolerate mismatched array lengths instead of throwing ArrayIndexOutOfBoundsException.
        int bins = Math.min(freqsHz.length, amps.length);
        Map<String, Float> out = new LinkedHashMap<>();

        for (Map.Entry<String, int[]> e : BAND_RANGES.entrySet()) {
            String name = e.getKey();
            int f0 = e.getValue()[0];
            int f1 = e.getValue()[1];

            double sumSq = 0.0;
            int cnt = 0;
            for (int i = 0; i < bins; i++) {
                float f = freqsHz[i];
                float amp = amps[i];
                // One NaN/Inf bin would otherwise poison smooth/ref for this band permanently.
                if (f >= f0 && f < f1 && Float.isFinite(amp)) {
                    sumSq += (double) amp * amp;
                    cnt++;
                }
            }

            float x = (cnt > 0) ? (float) Math.sqrt(sumSq / cnt) : 0.0f;
            x = Math.max(0.0f, x - cfg.noiseFloor);

            float prev = state.smooth.getOrDefault(name, 0.0f);
            float coeff = (x > prev) ? cfg.attack : cfg.release;
            float s = prev + coeff * (x - prev);
            state.smooth.put(name, s);

            // The reference tracks the recent maximum, decaying slowly so quiet passages re-normalize.
            float rPrev = state.ref.getOrDefault(name, 1e-3f);
            float r = Math.max(rPrev * cfg.refDecay, Math.max(s, 1e-6f));
            state.ref.put(name, r);

            float y = s / (r + cfg.eps);
            out.put(name, Math.min(1.0f, Math.max(0.0f, y)));
        }

        return out;
    }

    // -------------------- Thread handle --------------------

    /**
     * Handle for a running audio analysis thread.
     *
     * <p>Running {@code stop} sets a flag the worker checks once per pass. The worker then exits
     * within about one frame interval and stops the capture. The worker also exits if
     * {@code thread} is interrupted.
     */
    public static final class AudioThreadHandle {
        public final Runnable stop;
        public final Thread thread;
        private AudioThreadHandle(Runnable stop, Thread thread) { this.stop = stop; this.thread = thread; }
    }

    /**
     * Starts audio capture and a daemon thread that analyses it about 30 times per second.
     *
     * <p>Each pass publishes RMS, peak and band levels; read them with
     * {@link #getAudioStateCopy()}. If analysis throws, the thread terminates and the capture is
     * stopped.
     *
     * @return a handle used to stop the thread and join it
     */
    public static AudioThreadHandle startAudioThread() {
        // NOTE(review): arguments presumably are (buffer seconds, sample rate, channels) — confirm.
        AudioCapture.AudioRingBuffer ring = new AudioCapture.AudioRingBuffer(0.10, SAMPLE_RATE, CHANNELS);

        AudioCapture.AudioCaptureConfig cfg = new AudioCapture.AudioCaptureConfig();
        cfg.rate = SAMPLE_RATE;
        cfg.channels = CHANNELS;
        cfg.framesPerChunk = 512;

        AudioCapture.PulseMonitorCapture cap = new AudioCapture.PulseMonitorCapture(cfg, ring);
        cap.start();

        final BandState bandState = new BandState();
        final var stopFlag = new Object() { volatile boolean stop = false; };

        Thread t = new Thread(() -> {
            try {
                while (!stopFlag.stop && !Thread.currentThread().isInterrupted()) {
                    float[][] buf = ring.snapshot();
                    if (buf.length >= N_FFT) {
                        analyzeAndPublish(buf, bandState);
                    }
                    // Sleep even while the buffer is still filling; skipping it would spin a core.
                    try {
                        Thread.sleep(FRAME_INTERVAL_MS);
                    } catch (InterruptedException ie) {
                        // Restore the flag and leave: looping again would make every sleep throw
                        // immediately and busy-wait forever.
                        Thread.currentThread().interrupt();
                        break;
                    }
                }
            } finally {
                cap.stop();
            }
        }, "AudioThread");
        t.setDaemon(true);
        t.start();

        return new AudioThreadHandle(() -> stopFlag.stop = true, t);
    }

    /** Analyses the newest {@code N_FFT} frames of {@code buf} and publishes the result. */
    private static void analyzeAndPublish(float[][] buf, BandState bandState) {
        float[] mono = downmixTail(buf, N_FFT);

        double sumSq = 0.0;
        float peak = 0.0f;
        for (float v : mono) {
            sumSq += (double) v * v;
            float av = Math.abs(v);
            if (av > peak) {
                peak = av;
            }
        }
        float rms = (float) Math.sqrt(sumSq / mono.length);

        // NOTE(review): the FFT receives the full multi-channel buffer, not the mono downmix —
        // confirm this is what Fft.fftHzAmpFromBuffer expects.
        Fft.FftResult fr = Fft.fftHzAmpFromBuffer(buf, SAMPLE_RATE, N_FFT);
        Map<String, Float> bandDict = bandsToDict(fr.freqsHz, fr.amps, bandState);

        publish(rms, peak, bandDict);
    }

    /** Averages every channel of each of the last {@code n} frames of {@code buf} into mono. */
    private static float[] downmixTail(float[][] buf, int n) {
        int start = buf.length - n;
        float[] mono = new float[n];
        for (int i = 0; i < n; i++) {
            float[] frame = buf[start + i];
            float sum = 0.0f;
            for (float v : frame) {
                sum += v;
            }
            mono[i] = (frame.length == 0) ? 0.0f : sum / frame.length;
        }
        return mono;
    }

    /** Replaces the shared state atomically with respect to {@link #getAudioStateCopy()}. */
    private static void publish(float rms, float peak, Map<String, Float> bands) {
        double ts = Instant.now().toEpochMilli() / 1000.0;
        AUDIO_LOCK.lock();
        try {
            AUDIO_STATE.ts = ts;
            AUDIO_STATE.rms = rms;
            AUDIO_STATE.peak = peak;
            AUDIO_STATE.bands.clear();
            AUDIO_STATE.bands.putAll(bands);
        } finally {
            AUDIO_LOCK.unlock();
        }
    }

    // -------------------- demo --------------------

    /** Prints a few analysis values every 200 ms for about 40 s, then stops the audio thread. */
    public static void main(String[] args) throws Exception {
        AudioThreadHandle h = startAudioThread();
        try {
            for (int i = 0; i < 200; i++) {
                AudioState s = getAudioStateCopy();
                System.out.printf(Locale.ROOT,
                        "ts=%.3f rms=%.5f peak=%.5f bass=%.3f mid=%.3f treble=%.3f%n",
                        s.ts, s.rms, s.peak,
                        s.bands.getOrDefault("bass", 0f),
                        s.bands.getOrDefault("mid", 0f),
                        s.bands.getOrDefault("treble", 0f)
                );
                Thread.sleep(200);
            }
        } finally {
            h.stop.run();
            h.thread.join(2000);
        }
    }
}