package lab;

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;

import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/**
 * Back-end lab application for CVE-2026-24880 (Tomcat chunk-extension request smuggling).
 *
 * Routes (all served from the ROOT context "/"):
 *
 *   GET /public/nonce
 *       Returns the current per-boot nonce (contents of NONCE_DIR/boot_nonce, trimmed) in
 *       plain text. This is the ONLY sanctioned, value-agnostic way for a client to learn
 *       the nonce. Reachable as an ordinary (outer) request through the front-end proxy.
 *
 *   GET /internal/arrival?nonce=XXXX
 *       The smuggle target. Appends "ARRIVAL <nonce>\n" to the back-end arrival log
 *       (NONCE_DIR/arrivals.log). A missing nonce is recorded as "(none)". Line-breaking
 *       and other control characters in the nonce are replaced with '?', so each request
 *       produces exactly one log line and cannot forge additional entries. The front-end
 *       proxy is configured to REFUSE forwarding any outer request whose target begins
 *       with /internal, so the only way to reach this endpoint is to smuggle it inside a
 *       chunk extension that a vulnerable ChunkedInputFilter absorbs and re-parses as a
 *       second request.
 *
 *   POST /public/ingest  (or any other path mapped to this servlet)
 *       The benign outer request the exploit aims at. Drains the request body (which is
 *       where the smuggled inner request rides, inside a chunk extension) and returns 200.
 *       POSTs to /internal/* are drained the same way and never touch the arrival log.
 *
 *   Any other GET under /public/* or /internal/* returns 404.
 *
 * NONCE_DIR is read from the environment and defaults to "/nonce" when unset or empty.
 *
 * The servlet never reads the smuggled bytes itself; the smuggling is purely a Tomcat
 * connector-level effect. The arrival log is written only by the /internal/arrival route,
 * i.e. only when a second request is genuinely parsed and dispatched by Tomcat.
 */
@WebServlet(urlPatterns = {"/public/*", "/internal/*"})
public class LabServlet extends HttpServlet {

    private static final long serialVersionUID = 1L;

    /** Content type shared by every plain-text response this servlet writes. */
    private static final String TEXT_PLAIN_UTF8 = "text/plain; charset=utf-8";

    /**
     * Returns the directory holding {@code boot_nonce} and {@code arrivals.log}.
     *
     * @return the NONCE_DIR environment variable as a path, or "/nonce" if unset or empty
     */
    private static Path nonceDir() {
        String d = System.getenv("NONCE_DIR");
        if (d == null || d.isEmpty()) {
            d = "/nonce";
        }
        return Paths.get(d);
    }

    /**
     * Reads the per-boot nonce from NONCE_DIR/boot_nonce.
     *
     * @return the file contents decoded as UTF-8 with surrounding whitespace trimmed
     * @throws IOException if the nonce file cannot be read (e.g. it does not exist yet)
     */
    private static String currentNonce() throws IOException {
        Path p = nonceDir().resolve("boot_nonce");
        byte[] b = Files.readAllBytes(p);
        return new String(b, StandardCharsets.UTF_8).trim();
    }

    /**
     * Makes an untrusted value safe to embed in a single arrivals.log line.
     *
     * <p>Every ISO control character (including CR and LF) and the Unicode line/paragraph
     * separators U+2028/U+2029 are replaced with '?', so a crafted value cannot break the
     * line and forge extra "ARRIVAL ..." records.
     *
     * @param value untrusted text, never null
     * @return {@code value} with line-breaking and control characters replaced
     */
    private static String sanitizeForLog(String value) {
        StringBuilder sb = new StringBuilder(value.length());
        for (int i = 0; i < value.length(); i++) {
            char c = value.charAt(i);
            boolean unsafe = Character.isISOControl(c) || c == 0x2028 || c == 0x2029;
            sb.append(unsafe ? '?' : c);
        }
        return sb.toString();
    }

    /**
     * Serves GET /public/nonce, GET /internal/arrival, and 404 for anything else.
     */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp)
            throws ServletException, IOException {
        // getRequestURI() is the path as sent by the client (not URL-decoded), so the prefix
        // checks below compare against the raw path rather than Tomcat's decoded mapping.
        // NOTE(review): matching is prefix-based, e.g. "/public/nonce-x" also serves the nonce.
        String uri = req.getRequestURI();
        if (uri.startsWith("/public/nonce")) {
            // Read before touching the response so a read failure is not preceded by a 200.
            String nonce = currentNonce();
            resp.setStatus(200);
            resp.setContentType(TEXT_PLAIN_UTF8);
            PrintWriter w = resp.getWriter();
            w.print(nonce);
            w.flush();
            return;
        }
        if (uri.startsWith("/internal/arrival")) {
            String nonce = req.getParameter("nonce");
            if (nonce == null) {
                nonce = "(none)";
            }
            Path log = nonceDir().resolve("arrivals.log");
            String line = "ARRIVAL " + sanitizeForLog(nonce) + "\n";
            // Serialize appends from concurrent requests within this JVM.
            synchronized (LabServlet.class) {
                Files.write(log, line.getBytes(StandardCharsets.UTF_8),
                        StandardOpenOption.CREATE, StandardOpenOption.APPEND);
            }
            resp.setStatus(200);
            resp.setContentType(TEXT_PLAIN_UTF8);
            resp.getWriter().print("recorded\n");
            return;
        }
        resp.setStatus(404);
        resp.setContentType(TEXT_PLAIN_UTF8);
        resp.getWriter().print("not found\n");
    }

    /**
     * Benign outer endpoint: drains whatever body arrives and acknowledges with 200.
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp)
            throws ServletException, IOException {
        // The chunked body that carries the smuggled request rides here; read it to the end
        // so the connector's chunk parsing runs over the entire body.
        InputStream in = req.getInputStream();
        byte[] buf = new byte[4096];
        while (in.read(buf) != -1) {
            // discard
        }
        resp.setStatus(200);
        resp.setContentType(TEXT_PLAIN_UTF8);
        resp.getWriter().print("outer-ok\n");
    }
}
