package me.koba1.led.audio;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.locks.ReentrantLock;

public final class AudioCapture {

    // -------------------- Ring buffer (float frames x channels) --------------------
    /**
     * Fixed-capacity circular store of interleaved float audio frames.
     *
     * <p>Capacity is expressed in frames; each frame holds {@code channels} samples.
     * {@link #push(float[][])} overwrites the oldest frames once the buffer is full and
     * {@link #snapshot()} returns a copy ordered from oldest to newest. Both operations
     * run under the same {@link ReentrantLock}.
     */
    public static final class AudioRingBuffer {
        private final int channels;
        private final int sizeFrames;
        /** Interleaved storage: sample {@code c} of frame {@code f} lives at {@code f * channels + c}. */
        private final float[] buf;
        /** Frame slot that the next pushed frame will be written to. */
        private int idxFrames = 0;
        private final ReentrantLock lock = new ReentrantLock();

        /**
         * Allocates a buffer of {@code round(seconds * sampleRate)} frames, but never fewer than one.
         *
         * @param seconds    requested capacity in seconds
         * @param sampleRate frames per second used to convert {@code seconds} into frames
         * @param channels   samples per frame
         */
        public AudioRingBuffer(double seconds, int sampleRate, int channels) {
            this.channels = channels;
            int requestedFrames = (int) Math.round(seconds * sampleRate);
            this.sizeFrames = requestedFrames < 1 ? 1 : requestedFrames;
            this.buf = new float[this.sizeFrames * channels];
        }

        /**
         * Appends a block of frames, overwriting the oldest content when necessary.
         *
         * <p>If the block holds at least as many frames as the buffer, only its last
         * {@code sizeFrames} frames are kept and the write position is reset to slot 0.
         *
         * @param block frames to append, indexed as {@code block[frame][channel]}; only the
         *              first {@code channels} samples of each row are read
         */
        public void push(float[][] block) {
            final int incoming = block.length;
            lock.lock();
            try {
                if (incoming >= sizeFrames) {
                    // The block alone fills the buffer: keep just its newest frames.
                    copyFrames(block, incoming - sizeFrames, 0, sizeFrames);
                    idxFrames = 0;
                    return;
                }

                // Write up to the physical end of the array, then wrap to slot 0 for the rest.
                final int roomBeforeWrap = sizeFrames - idxFrames;
                final int headCount = Math.min(incoming, roomBeforeWrap);
                copyFrames(block, 0, idxFrames, headCount);
                copyFrames(block, headCount, 0, incoming - headCount);

                idxFrames = (idxFrames + incoming) % sizeFrames;
            } finally {
                lock.unlock();
            }
        }

        /** Copies {@code count} frames from {@code block[srcFrame..]} into slots {@code dstFrame..}. */
        private void copyFrames(float[][] block, int srcFrame, int dstFrame, int count) {
            for (int k = 0; k < count; k++) {
                System.arraycopy(block[srcFrame + k], 0, buf, (dstFrame + k) * channels, channels);
            }
        }

        /**
         * Returns a copy of the whole buffer, oldest frame first.
         *
         * @return a new {@code [sizeFrames][channels]} array; slots that were never written are 0
         */
        public float[][] snapshot() {
            lock.lock();
            try {
                float[][] ordered = new float[sizeFrames][channels];
                // The oldest frame sits at the write position; frames before it are the newest.
                final int tailCount = sizeFrames - idxFrames;
                for (int k = 0; k < sizeFrames; k++) {
                    int slot = (k < tailCount) ? idxFrames + k : k - tailCount;
                    System.arraycopy(buf, slot * channels, ordered[k], 0, channels);
                }
                return ordered;
            } finally {
                lock.unlock();
            }
        }

        /** @return the buffer capacity in frames */
        public int sizeFrames() {
            return sizeFrames;
        }

        /** @return the number of samples stored per frame */
        public int channels() {
            return channels;
        }
    }

    // -------------------- Config + errors --------------------
    /**
     * Mutable settings for a capture session. {@code PulseMonitorCapture} forwards these
     * values to the {@code parec} command line and uses them to size the chunks it decodes.
     */
    public static final class AudioCaptureConfig {
        /** Sample rate handed to parec as {@code --rate}. */
        public int rate = 48000;
        /** Channel count handed to parec as {@code --channels}; also the width of each decoded frame. */
        public int channels = 2;
        /** Number of frames decoded and pushed to the ring buffer per chunk. */
        public int framesPerChunk = 512;
        /**
         * Sample format handed to parec as {@code --format}. The capture loop in this file
         * decodes the stream as 16-bit little-endian samples (2 bytes each), so other
         * formats are not decoded correctly by it.
         */
        public String format = "s16le";
        /** Value handed to parec as {@code --latency-msec}. */
        public int latencyMsec = 10;
    }

    /**
     * Unchecked exception thrown by the helpers in this file when running or reading from
     * an external PulseAudio command fails.
     */
    public static final class PulseError extends RuntimeException {

        /**
         * Creates an error without an underlying cause.
         *
         * @param message human-readable description of the failure
         */
        public PulseError(String message) {
            super(message);
        }

        /**
         * Creates an error that wraps the exception which triggered it.
         *
         * @param message human-readable description of the failure
         * @param cause   the original exception, retained for diagnostics
         */
        public PulseError(String message, Throwable cause) {
            super(message, cause);
        }
    }

    // -------------------- Helpers: pactl / parec --------------------
    /**
     * Runs an external command to completion and returns everything it wrote to stdout.
     *
     * <p>stderr is drained on a helper thread while stdout is being read, so a child that
     * writes a lot to stderr cannot block on a full pipe while this method is still
     * waiting for stdout to end.
     *
     * @param cmd the program and its arguments, passed to {@link ProcessBuilder} as-is
     * @return the command's stdout decoded as UTF-8
     * @throws PulseError if the command cannot be started, its output cannot be read, the
     *                    calling thread is interrupted while waiting, or the command exits
     *                    with a non-zero status (stdout and stderr are included in the message)
     */
    private static String runCmd(List<String> cmd) {
        Process p = null;
        try {
            p = new ProcessBuilder(cmd).redirectErrorStream(false).start();

            final InputStream errStream = p.getErrorStream();
            final ByteArrayOutputStream errBytes = new ByteArrayOutputStream();
            Thread errDrainer = new Thread(() -> {
                try {
                    errStream.transferTo(errBytes);
                } catch (IOException ignored) {
                    // The stream is closed when the process is destroyed; keep what was read so far.
                }
            }, "runCmd-stderr");
            errDrainer.setDaemon(true);
            errDrainer.start();

            String stdout = new String(p.getInputStream().readAllBytes(), StandardCharsets.UTF_8);
            int rc = p.waitFor();
            // join() makes the drainer's writes to errBytes visible to this thread.
            errDrainer.join();
            String stderr = new String(errBytes.toByteArray(), StandardCharsets.UTF_8);

            if (rc != 0) {
                throw new PulseError(
                        "Command failed: " + String.join(" ", cmd) +
                                "\nSTDOUT:\n" + stdout + "\nSTDERR:\n" + stderr
                );
            }
            return stdout;
        } catch (IOException e) {
            // An I/O failure is not an interruption: do not touch the interrupt flag here.
            throw new PulseError("Command failed: " + String.join(" ", cmd), e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new PulseError("Command failed: " + String.join(" ", cmd), e);
        } finally {
            // No effect on a process that already exited; reaps it on the failure paths.
            if (p != null) {
                p.destroy();
            }
        }
    }

    /**
     * Asks {@code pactl info} for the name of the default sink.
     *
     * <p>The output is scanned line by line for one whose trimmed text starts with
     * {@code "Default Sink:"} (compared case-insensitively); the text after the first
     * colon on that line, trimmed, is the result.
     *
     * @return the default sink name as reported by {@code pactl info}
     * @throws PulseError if {@code pactl info} fails or contains no such line
     */
    public static String getDefaultSinkName() {
        final String wantedPrefix = "default sink:";
        String[] infoLines = runCmd(List.of("pactl", "info")).split("\n");

        for (int i = 0; i < infoLines.length; i++) {
            String raw = infoLines[i];
            String normalized = raw.trim().toLowerCase(Locale.ROOT);
            if (!normalized.startsWith(wantedPrefix)) {
                continue;
            }
            // A matching line always contains a colon, so indexOf cannot return -1 here.
            int colon = raw.indexOf(':');
            return raw.substring(colon + 1).trim();
        }

        throw new PulseError("Couldn't find Default Sink in `pactl info`");
    }

    /**
     * Builds the monitor source name for a sink by appending {@code ".monitor"} to it.
     *
     * @param sinkName the sink name; a {@code null} value is rendered as the text {@code "null"}
     * @return {@code sinkName} followed by {@code ".monitor"}
     */
    public static String getMonitorSourceForSink(String sinkName) {
        final String monitorSuffix = ".monitor";
        StringBuilder name = new StringBuilder();
        name.append(sinkName);
        name.append(monitorSuffix);
        return name.toString();
    }

    // -------------------- Capture thread --------------------
    /**
     * Captures audio from the default sink's monitor source by running {@code parec} on a
     * background daemon thread and pushing decoded float frames into an {@link AudioRingBuffer}.
     *
     * <p>{@link #start()} launches the thread; {@link #stop()} asks it to finish, destroys the
     * {@code parec} process and waits up to two seconds for the thread to end. Failures inside
     * the capture loop are thrown as {@link PulseError} on the capture thread itself.
     */
    public static final class PulseMonitorCapture {
        /** The only sample format the decoder in this class understands. */
        private static final String SUPPORTED_FORMAT = "s16le";
        /** Size of one s16le sample in bytes. */
        private static final int BYTES_PER_SAMPLE = 2;

        private final AudioCaptureConfig cfg;
        private final AudioRingBuffer ring;
        private volatile boolean stop = false;
        private Thread thread;
        /** Written by the capture thread and read by {@link #stop()} on the caller's thread. */
        private volatile Process proc;

        /** Name of the monitor source being recorded; set once the capture thread has resolved it. */
        public volatile String sourceName;
        /** {@link System#nanoTime()}-based time, in seconds, of the most recent push to the ring buffer. */
        public volatile double lastPushMonoSec = 0.0;

        /**
         * @param cfg  capture settings; read when the capture thread starts
         * @param ring destination for the decoded frames
         */
        public PulseMonitorCapture(AudioCaptureConfig cfg, AudioRingBuffer ring) {
            this.cfg = cfg;
            this.ring = ring;
        }

        /** Starts the capture thread unless one is already alive. */
        public void start() {
            if (thread != null && thread.isAlive()) {
                return;
            }
            stop = false;
            thread = new Thread(this::runLoop, "PulseMonitorCapture");
            thread.setDaemon(true);
            thread.start();
        }

        /**
         * Requests the capture loop to end, destroys the {@code parec} process if one is
         * running and waits up to 2000 ms for the capture thread to finish.
         */
        public void stop() {
            stop = true;
            Process running = proc;
            if (running != null) {
                // Destroying the process ends the blocking read in the capture loop.
                running.destroy();
            }
            Thread worker = thread;
            if (worker != null) {
                try {
                    worker.join(2000);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
        }

        /**
         * Body of the capture thread: validates the configuration, resolves the monitor
         * source, launches {@code parec} and pushes one decoded block per complete chunk of
         * {@code framesPerChunk} frames until {@link #stop()} is called or the stream ends.
         *
         * @throws PulseError if the configuration cannot be decoded by this class, if the
         *                    default sink cannot be determined, or if {@code parec} fails or
         *                    ends while no stop was requested
         */
        private void runLoop() {
            // Read the mutable config once so the whole run uses consistent values.
            final int channels = cfg.channels;
            final int framesPerChunk = cfg.framesPerChunk;
            final String format = cfg.format;

            if (!SUPPORTED_FORMAT.equalsIgnoreCase(format)) {
                // The decoder below is hard-wired to 16-bit little-endian samples.
                throw new PulseError("Unsupported capture format '" + format
                        + "': only " + SUPPORTED_FORMAT + " can be decoded");
            }
            if (channels <= 0 || framesPerChunk <= 0) {
                // A non-positive chunk size would make the read loop below never advance.
                throw new PulseError("Invalid capture config: channels=" + channels
                        + ", framesPerChunk=" + framesPerChunk);
            }
            final int chunkBytes = framesPerChunk * channels * BYTES_PER_SAMPLE;

            String sink = getDefaultSinkName();
            this.sourceName = getMonitorSourceForSink(sink);

            List<String> cmd = new ArrayList<>();
            cmd.add("parec");
            cmd.add("-d"); cmd.add(sourceName);
            cmd.add("--format=" + format);
            cmd.add("--rate=" + cfg.rate);
            cmd.add("--channels=" + channels);
            cmd.add("--latency-msec=" + cfg.latencyMsec);

            Process started = null;
            try {
                started = new ProcessBuilder(cmd)
                        .redirectError(ProcessBuilder.Redirect.DISCARD)
                        .start();
                proc = started;

                InputStream in = started.getInputStream();
                byte[] chunk = new byte[chunkBytes];

                while (!stop) {
                    // Blocks until a full chunk is available or the stream ends.
                    int n = in.readNBytes(chunk, 0, chunkBytes);
                    if (n < chunkBytes) {
                        // End of stream: parec is gone, so retrying would only spin.
                        if (stop) {
                            break;
                        }
                        throw new PulseError("parec ended unexpectedly (end of stream)");
                    }
                    ring.push(decodeS16LEToFloat(chunk, 0, chunkBytes, channels));
                    lastPushMonoSec = monotonicSeconds();
                }
            } catch (IOException e) {
                // Read errors are expected once stop() has destroyed the process.
                if (!stop) {
                    throw new PulseError("parec failed", e);
                }
            } finally {
                if (started != null) {
                    started.destroy();
                }
            }
        }

        /**
         * Converts interleaved signed 16-bit little-endian PCM into float frames.
         *
         * @param src      bytes holding the PCM data
         * @param off      index of the first byte to decode
         * @param len      number of bytes to decode; a trailing partial frame is ignored
         * @param channels samples per frame
         * @return {@code [frames][channels]} samples, each computed as {@code sample / 32768.0}
         */
        private static float[][] decodeS16LEToFloat(byte[] src, int off, int len, int channels) {
            int samples = len / BYTES_PER_SAMPLE;
            int frames = samples / channels;
            float[][] out = new float[frames][channels];

            int i = off;
            for (int f = 0; f < frames; f++) {
                for (int c = 0; c < channels; c++) {
                    int lo = src[i++] & 0xFF; // low byte, taken as unsigned
                    int hi = src[i++];        // high byte keeps its sign
                    short s = (short) ((hi << 8) | lo);
                    out[f][c] = (float) (s / 32768.0);
                }
            }
            return out;
        }

        /** @return {@link System#nanoTime()} converted to seconds */
        private static double monotonicSeconds() {
            return System.nanoTime() / 1_000_000_000.0;
        }
    }
}