package org.cryptomator.launcher; import java.io.Closeable; import java.io.IOException; import java.net.InetAddress; import java.net.ServerSocket; import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.nio.channels.WritableByteChannel; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.rmi.ConnectException; import java.rmi.NotBoundException; import java.rmi.Remote; import java.rmi.RemoteException; import java.rmi.registry.LocateRegistry; import java.rmi.registry.Registry; import java.rmi.server.RMIClientSocketFactory; import java.rmi.server.RMIServerSocketFactory; import java.rmi.server.RMISocketFactory; import java.rmi.server.UnicastRemoteObject; import org.apache.commons.lang3.SystemUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.io.MoreFiles; /** * First running application on a machine opens a server socket. Further processes will connect as clients. */ abstract class InterProcessCommunicator implements InterProcessCommunicationProtocol, Closeable { private static final Logger LOG = LoggerFactory.getLogger(InterProcessCommunicator.class); private static final String RMI_NAME = "Cryptomator"; public abstract boolean isServer(); /** * @param endpoint The server-side communication endpoint. * @return Either a client or a server communicator. * @throws IOException In case of communication errors. */ public static InterProcessCommunicator start(InterProcessCommunicationProtocol endpoint) throws IOException { return start(getIpcPortPath(), endpoint); } // visible for testing static InterProcessCommunicator start(Path portFilePath, InterProcessCommunicationProtocol endpoint) throws IOException { // try to connect to existing server: int port = readPort(portFilePath); LOG.debug("Connecting to running process on TCP port {}...", port); try { ClientCommunicator client = new ClientCommunicator(port); LOG.trace("Connected to running process."); return client; } catch (ConnectException | NotBoundException e) { LOG.debug("Did not find running process."); // continue } // spawn a new server: LOG.trace("Spawning new server..."); ServerCommunicator server = new ServerCommunicator(endpoint); writePort(portFilePath, server.getPort()); LOG.debug("Server listening on port {}.", server.getPort()); return server; } private static Path getIpcPortPath() { final String settingsPathProperty = System.getProperty("cryptomator.ipcPortPath"); if (settingsPathProperty == null) { LOG.warn("System property cryptomator.ipcPortPath not set."); return Paths.get("ipcPort.tmp"); } else { return Paths.get(replaceHomeDir(settingsPathProperty)); } } private static String replaceHomeDir(String path) { if (path.startsWith("~/")) { return SystemUtils.USER_HOME + path.substring(1); } else { return path; } } public static class ClientCommunicator extends InterProcessCommunicator { private final IpcProtocolRemote remote; private ClientCommunicator(int port) throws ConnectException, NotBoundException, RemoteException { if (port == 0) { throw new ConnectException("Can not connect to port 0."); } Registry registry = LocateRegistry.getRegistry(port); this.remote = (IpcProtocolRemote) registry.lookup(RMI_NAME); } @Override public void handleLaunchArgs(String[] args) { try { remote.handleLaunchArgs(args); } catch (RemoteException e) { throw new RuntimeException(e); } } @Override public boolean isServer() { return false; } @Override public void close() { // no-op } } public static class ServerCommunicator extends InterProcessCommunicator { private final ServerSocket socket; private final Registry registry; private final IpcProtocolRemoteImpl remote; private ServerCommunicator(InterProcessCommunicationProtocol delegate) throws IOException { this.socket = new ServerSocket(0, Byte.MAX_VALUE, InetAddress.getLocalHost()); RMIClientSocketFactory csf = RMISocketFactory.getDefaultSocketFactory(); SingletonServerSocketFactory ssf = new SingletonServerSocketFactory(socket); this.registry = LocateRegistry.createRegistry(0, csf, ssf); this.remote = new IpcProtocolRemoteImpl(delegate); UnicastRemoteObject.exportObject(remote, 0); registry.rebind(RMI_NAME, remote); } @Override public void handleLaunchArgs(String[] args) { throw new UnsupportedOperationException("Server doesn't invoke methods."); } @Override public boolean isServer() { return true; } private int getPort() { return socket.getLocalPort(); } @Override public void close() { try { registry.unbind(RMI_NAME); UnicastRemoteObject.unexportObject(remote, true); socket.close(); LOG.debug("Server shut down."); } catch (NotBoundException | IOException e) { LOG.warn("Failed to close IPC Server.", e); } } } private static interface IpcProtocolRemote extends Remote { void handleLaunchArgs(String[] args) throws RemoteException; } private static class IpcProtocolRemoteImpl implements IpcProtocolRemote { private final InterProcessCommunicationProtocol delegate; protected IpcProtocolRemoteImpl(InterProcessCommunicationProtocol delegate) throws RemoteException { this.delegate = delegate; } @Override public void handleLaunchArgs(String[] args) { delegate.handleLaunchArgs(args); } } /** * Always returns the same pre-constructed server socket. */ private static class SingletonServerSocketFactory implements RMIServerSocketFactory { private final ServerSocket socket; public SingletonServerSocketFactory(ServerSocket socket) { this.socket = socket; } @Override public synchronized ServerSocket createServerSocket(int port) throws IOException { if (port != 0) { throw new IllegalArgumentException("This factory doesn't support specific ports."); } return this.socket; } } private static int readPort(Path path) throws IOException { if (Files.notExists(path)) { return 0; } ByteBuffer buf = ByteBuffer.allocate(Integer.BYTES); try (ReadableByteChannel ch = Files.newByteChannel(path, StandardOpenOption.READ)) { if (ch.read(buf) == Integer.BYTES) { buf.flip(); return buf.getInt(); } else { return 0; } } } private static void writePort(Path path, int port) throws IOException { ByteBuffer buf = ByteBuffer.allocate(Integer.BYTES); buf.putInt(port); buf.flip(); MoreFiles.createParentDirectories(path); try (WritableByteChannel ch = Files.newByteChannel(path, StandardOpenOption.WRITE, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING)) { if (ch.write(buf) != Integer.BYTES) { throw new IOException("Did not write expected number of bytes."); } } } }