#!/usr/bin/env python3
"""
Minimal raw HTTP/1.1 upstream for CVE-2024-53271 reproduction.

Two modes, selected by the MODE environment variable:
  MODE=trigger  -> for every request, emit a non-101 1xx intermediate
                   response (102 Processing) followed by a final
                   200 OK with Content-Length: 0. This is the wire
                   pattern that drives Envoy's BALSA parser to call
                   MessageDone() twice.
  MODE=plain    -> for every request, emit only a final 200 OK with
                   Content-Length: 0 (no 1xx). Used for the liveness
                   probe: the proxy path must be healthy absent the
                   trigger.

The server speaks raw bytes over a TCP socket (no framework) so the
exact 1xx-then-final wire sequence is under our control, and supports
HTTP/1.1 keep-alive so many requests can ride one connection.
"""
import os
import socket
import sys
import threading

# Response mode, taken from the MODE environment variable (default
# "trigger"); surrounding whitespace is stripped and the value lowercased.
MODE = os.environ.get("MODE", "trigger").strip().lower()
# TCP port to listen on, from the PORT environment variable (default 8080).
PORT = int(os.environ.get("PORT", "8080"))
HOST = "0.0.0.0"  # internal to the docker bridge network only

# Header-less 102 Processing interim response, exactly as put on the wire.
ONExX_102 = b"HTTP/1.1 102 Processing\r\n\r\n"
# Final response: 200 OK declaring an explicit zero-length body.
FINAL_200 = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"


def read_one_request(conn_file):
    """Read one HTTP/1.1 request head from a buffered binary reader.

    Lines are consumed up to and including the blank line that ends the
    header block. Empty lines that arrive *before* the request line are
    skipped instead of being mistaken for a complete (empty) request;
    otherwise the caller would answer a request that was never sent and
    the response sequence on the connection would drift out of step.

    Args:
        conn_file: Binary file-like object exposing ``readline()``.

    Returns:
        The raw request head (request line, header lines and the
        terminating blank line) as bytes, or ``b""`` if EOF is reached
        before a head has been completed.

    Note:
        A request body, if any, is NOT consumed and stays in
        ``conn_file``. That is acceptable here because the probe
        requests are bodyless (HEAD / GET).
    """
    head_lines = []
    while True:
        line = conn_file.readline()
        if not line:
            return b""  # peer closed (possibly mid-request)
        is_blank = line in (b"\r\n", b"\n")
        if is_blank and not head_lines:
            # Stray empty line ahead of the request line: ignore it.
            continue
        head_lines.append(line)
        if is_blank:
            break  # blank line after at least one line ends the head
    return b"".join(head_lines)


def handle(conn, addr):
    """Serve one accepted connection until the peer closes it.

    For every request head read from the connection, write the response
    sequence selected by ``MODE``: in "trigger" mode a 102 Processing
    interim response followed by the final 200 OK; in any other mode
    only the final 200 OK. The loop keeps going so that several requests
    can ride the same keep-alive connection.

    Args:
        conn: Connected socket returned by ``accept()``.
        addr: Peer address returned by ``accept()``; not used.

    Socket-level errors (``OSError``, which includes ``BrokenPipeError``
    and ``ConnectionResetError``) end the connection silently; this is a
    deliberate best-effort policy for a test upstream.
    """
    try:
        # Both the socket and its buffered reader are closed on exit. The
        # reader must be closed explicitly: while a makefile() object is
        # still open, socket.close() only marks the socket closed and the
        # underlying descriptor stays open.
        with conn, conn.makefile("rb") as conn_file:
            while read_one_request(conn_file):
                if MODE == "trigger":
                    # 1xx intermediate first, as its own write, then the
                    # final response below.
                    conn.sendall(ONExX_102)
                conn.sendall(FINAL_200)
    except OSError:
        # Peer reset / broken pipe / error while closing: nothing to do.
        pass


def main():
    """Listen on HOST:PORT and serve each connection on a daemon thread.

    Announces the active mode and listening address on stderr once the
    socket is listening, then accepts connections forever, handing each
    one to ``handle`` in its own daemon thread.
    """
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Allow quick restarts without waiting for old sockets to time out.
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    listener.bind((HOST, PORT))
    listener.listen(128)

    print(
        f"upstream MODE={MODE} listening on {HOST}:{PORT}",
        file=sys.stderr,
        flush=True,
    )

    while True:
        client, peer = listener.accept()
        worker = threading.Thread(
            target=handle,
            args=(client, peer),
            daemon=True,
        )
        worker.start()

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
