/*
 *  Copyright 2017 TWO SIGMA OPEN SOURCE, LLC
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package com.twosigma.jupyter.socket;

import com.twosigma.beaker.jupyter.msg.JupyterMessages;
import com.twosigma.jupyter.Config;
import com.twosigma.jupyter.KernelFunctionality;
import com.twosigma.jupyter.KernelSockets;
import com.twosigma.jupyter.SocketCloseAction;
import com.twosigma.jupyter.handler.Handler;
import com.twosigma.jupyter.message.MessageSerializer;
import com.twosigma.jupyter.message.Header;
import com.twosigma.jupyter.message.Message;
import com.twosigma.jupyter.security.HashedMessageAuthenticationCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zeromq.ZFrame;
import org.zeromq.ZMQ;
import org.zeromq.ZMsg;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;

import static com.twosigma.beaker.jupyter.msg.JupyterMessages.SHUTDOWN_REPLY;
import static com.twosigma.beaker.jupyter.msg.JupyterMessages.SHUTDOWN_REQUEST;
import static java.util.Arrays.asList;
import static com.twosigma.jupyter.message.MessageSerializer.toJson;

/**
 * ZeroMQ-backed {@link KernelSockets}: binds the heartbeat, iopub, control, stdin and shell
 * sockets described by a {@link Config}, polls the four inbound ones in {@link #run()} and
 * dispatches whatever arrives. Outgoing messages are signed with the configured HMAC key and
 * incoming ones are verified against it.
 */
public class KernelSocketsZMQ extends KernelSockets {

  public static final Logger logger = LoggerFactory.getLogger(KernelSocketsZMQ.class);

  /** Frame that separates the routing identities from the signed message parts. */
  public static final String DELIM = "<IDS|MSG>";

  // Poller slots. These rely on the poller handing out indices in registration order, so
  // they must stay in sync with the register(...) calls in configureSockets.
  private static final int CONTROL_POLL_INDEX = 0;
  private static final int HEARTBEAT_POLL_INDEX = 1;
  private static final int SHELL_POLL_INDEX = 2;
  private static final int STDIN_POLL_INDEX = 3;

  private final KernelFunctionality kernel;
  private final SocketCloseAction closeAction;
  private final HashedMessageAuthenticationCode hmac;
  private final ZMQ.Context context;

  private ZMQ.Socket heartbeatSocket;
  private ZMQ.Socket controlSocket;
  private ZMQ.Socket shellSocket;
  private ZMQ.Socket iopubSocket;
  private ZMQ.Socket stdinSocket;
  private ZMQ.Poller sockets;

  // Only written and read from the polling thread (see handleControlMsg / run).
  private boolean shutdownSystem = false;

  /**
   * Creates the ZeroMQ context and binds every kernel socket.
   *
   * @param kernel        source of the handlers used for shell messages
   * @param configuration transport, host, ports and HMAC key of the connection
   * @param closeAction   callback invoked once when the polling loop terminates
   */
  public KernelSocketsZMQ(KernelFunctionality kernel, Config configuration, SocketCloseAction closeAction) {
    this.closeAction = closeAction;
    this.kernel = kernel;
    this.hmac = new HashedMessageAuthenticationCode(configuration.getKey());
    this.context = ZMQ.context(1);
    configureSockets(configuration);
  }

  /** Binds all five sockets and registers the four inbound ones with the poller. */
  private void configureSockets(Config configuration) {
    final String connection = configuration.getTransport() + "://" + configuration.getHost();

    heartbeatSocket = getNewSocket(ZMQ.REP, configuration.getHeartbeat(), connection, context);
    iopubSocket = getNewSocket(ZMQ.PUB, configuration.getIopub(), connection, context);
    controlSocket = getNewSocket(ZMQ.ROUTER, configuration.getControl(), connection, context);
    stdinSocket = getNewSocket(ZMQ.ROUTER, configuration.getStdin(), connection, context);
    shellSocket = getNewSocket(ZMQ.ROUTER, configuration.getShell(), connection, context);

    // Registration order defines the *_POLL_INDEX constants above.
    sockets = new ZMQ.Poller(4);
    sockets.register(controlSocket, ZMQ.Poller.POLLIN);
    sockets.register(heartbeatSocket, ZMQ.Poller.POLLIN);
    sockets.register(shellSocket, ZMQ.Poller.POLLIN);
    sockets.register(stdinSocket, ZMQ.Poller.POLLIN);
  }

  /** Sends {@code message} on the iopub socket. */
  public void publish(Message message) {
    sendMsg(this.iopubSocket, message);
  }

  /** Sends {@code message} on the shell socket. */
  public void send(Message message) {
    sendMsg(this.shellSocket, message);
  }

  /**
   * Serializes, signs and sends {@code message} as a multipart ZeroMQ message:
   * identities, delimiter, signature, header, parent header, metadata, content.
   * Synchronized so that concurrent senders cannot interleave their frames.
   */
  private synchronized void sendMsg(ZMQ.Socket socket, Message message) {
    String header = toJson(message.getHeader());
    String parent = toJson(message.getParentHeader());
    String meta = toJson(message.getMetadata());
    String content = toJson(message.getContent());
    String digest = hmac.sign(Arrays.asList(header, parent, meta, content));

    ZMsg newZmsg = new ZMsg();
    message.getIdentities().forEach(newZmsg::add);
    newZmsg.add(DELIM);
    // Encode explicitly as UTF-8 so the bytes on the wire do not depend on the platform charset.
    newZmsg.add(digest.getBytes(StandardCharsets.UTF_8));
    newZmsg.add(header.getBytes(StandardCharsets.UTF_8));
    newZmsg.add(parent.getBytes(StandardCharsets.UTF_8));
    newZmsg.add(meta.getBytes(StandardCharsets.UTF_8));
    newZmsg.add(content.getBytes(StandardCharsets.UTF_8));
    newZmsg.send(socket);
  }

  /**
   * Receives one multipart message from {@code socket}, checks its delimiter and signature
   * and converts it to a {@link Message}.
   *
   * @throws RuntimeException if nothing could be received, a frame is missing, the delimiter
   *                          is wrong or the signature does not match
   */
  private Message readMessage(ZMQ.Socket socket) {
    ZMsg zmsg = null;
    Message message = new Message();
    try {
      zmsg = ZMsg.recvMsg(socket);
      if (zmsg == null) {
        // recvMsg yields null when the receive was interrupted; fail with a clear message
        // instead of a NullPointerException further down.
        throw new RuntimeException("No message could be received from socket.");
      }
      ZFrame[] parts = new ZFrame[zmsg.size()];
      zmsg.toArray(parts);

      byte[] uuid = frameAt(parts, MessageParts.UUID).getData();
      byte[] header = frameAt(parts, MessageParts.HEADER).getData();
      byte[] parent = frameAt(parts, MessageParts.PARENT).getData();
      byte[] metadata = frameAt(parts, MessageParts.METADATA).getData();
      byte[] content = frameAt(parts, MessageParts.CONTENT).getData();
      byte[] expectedSig = frameAt(parts, MessageParts.HMAC).getData();

      verifyDelim(frameAt(parts, MessageParts.DELIM));
      verifySignatures(expectedSig, header, parent, metadata, content);

      if (uuid != null) {
        message.getIdentities().add(uuid);
      }
      message.setHeader(parse(header, Header.class));
      message.setParentHeader(parse(parent, Header.class));
      message.setMetadata(parse(metadata, LinkedHashMap.class));
      message.setContent(parse(content, LinkedHashMap.class));
    } finally {
      if (zmsg != null) {
        zmsg.destroy();
      }
    }
    return message;
  }

  /** Returns the frame at {@code index}, rejecting messages that are too short to contain it. */
  private static ZFrame frameAt(ZFrame[] parts, int index) {
    if (index < 0 || index >= parts.length) {
      throw new RuntimeException(
              "Malformed message: no frame at index " + index + ", received " + parts.length + " frame(s).");
    }
    return parts[index];
  }

  /**
   * Polls the inbound sockets and dispatches one ready socket per iteration until the thread
   * is interrupted or a shutdown request has been handled; the sockets are always closed on
   * exit.
   */
  @Override
  public void run() {
    try {
      // The shutdown flag is part of the loop condition so that the loop ends right after
      // the shutdown reply is sent instead of blocking in poll() for another message.
      while (!this.isInterrupted() && !this.isShutdown()) {
        sockets.poll();
        if (isControlMsg()) {
          handleControlMsg();
        } else if (isHeartbeatMsg()) {
          handleHeartbeat();
        } else if (isShellMsg()) {
          handleShell();
        } else if (isStdinMsg()) {
          handleStdIn();
        }
      }
    } finally {
      close();
    }
  }

  /** Reads one frame from the stdin socket and logs it. */
  private void handleStdIn() {
    byte[] buffer = stdinSocket.recv();
    if (buffer == null) {
      logger.debug("Stdin socket was readable but no data could be received");
      return;
    }
    logger.info("Stdin: {}", new String(buffer, StandardCharsets.UTF_8));
  }

  /** Reads a shell message and passes it to the handler registered for its type, if any. */
  private void handleShell() {
    Message message = readMessage(shellSocket);
    Handler<Message> handler = kernel.getHandler(message.type());
    if (handler != null) {
      handler.handle(message);
    }
  }

  /** Echoes the received heartbeat payload back to the sender. */
  private void handleHeartbeat() {
    byte[] buffer = heartbeatSocket.recv(0);
    if (buffer == null) {
      logger.debug("Heartbeat socket was readable but no data could be received");
      return;
    }
    heartbeatSocket.send(buffer);
  }

  /**
   * Reads a control message; a shutdown request is answered with a shutdown reply that echoes
   * the request content, after which the polling loop is asked to stop. Other message types
   * are ignored.
   */
  private void handleControlMsg() {
    Message message = readMessage(controlSocket);
    JupyterMessages type = message.getHeader().getTypeEnum();
    // Constant-first comparison: a message whose type could not be resolved is simply ignored.
    if (SHUTDOWN_REQUEST.equals(type)) {
      Message reply = new Message();
      reply.setHeader(new Header(SHUTDOWN_REPLY, message.getHeader().getSession()));
      reply.setParentHeader(message.getHeader());
      reply.setContent(message.getContent());
      sendMsg(controlSocket, reply);
      shutdown();
    }
  }

  /** Creates a socket of the given ZeroMQ type and binds it to {@code connection:port}. */
  private ZMQ.Socket getNewSocket(int type, int port, String connection, ZMQ.Context context) {
    ZMQ.Socket socket = context.socket(type);
    socket.bind(connection + ":" + port);
    return socket;
  }

  /** Runs the close action and then releases the sockets, even if the close action fails. */
  private void close() {
    try {
      closeAction.close();
    } finally {
      closeSockets();
    }
  }

  /**
   * Closes every socket and then the context. This is best effort: a failure is logged and
   * does not prevent the remaining resources from being closed.
   */
  private void closeSockets() {
    closeSocket("shell", shellSocket);
    closeSocket("control", controlSocket);
    closeSocket("iopub", iopubSocket);
    closeSocket("stdin", stdinSocket);
    closeSocket("heartbeat", heartbeatSocket);
    try {
      context.close();
    } catch (Exception e) {
      logger.warn("Failed to close ZMQ context", e);
    }
  }

  /** Closes a single socket if it was created, logging instead of propagating any failure. */
  private void closeSocket(String name, ZMQ.Socket socket) {
    if (socket == null) {
      return;
    }
    try {
      socket.close();
    } catch (Exception e) {
      logger.warn("Failed to close {} socket", name, e);
    }
  }

  /**
   * Recomputes the signature over the four message parts and compares it with the one that
   * was received.
   *
   * @throws RuntimeException if the signatures differ
   */
  private void verifySignatures(byte[] expectedSig, byte[] header, byte[] parent, byte[] metadata, byte[] content) {
    String actualSig = hmac.signBytes(new ArrayList<>(asList(header, parent, metadata, content)));
    // MessageDigest.isEqual compares in constant time, so the check does not leak how many
    // leading characters of a forged signature were correct.
    boolean matches = actualSig != null
            && MessageDigest.isEqual(expectedSig, actualSig.getBytes(StandardCharsets.UTF_8));
    if (!matches) {
      throw new RuntimeException("Signatures do not match.");
    }
  }

  /**
   * Checks that {@code zframe} holds the {@link #DELIM} marker.
   *
   * @return the decoded delimiter
   * @throws RuntimeException if the frame contains anything else
   */
  private String verifyDelim(ZFrame zframe) {
    String delim = new String(zframe.getData(), StandardCharsets.UTF_8);
    if (!DELIM.equals(delim)) {
      throw new RuntimeException("Delimiter <IDS|MSG> not found");
    }
    return delim;
  }

  private boolean isStdinMsg() {
    return sockets.pollin(STDIN_POLL_INDEX);
  }

  private boolean isShellMsg() {
    return sockets.pollin(SHELL_POLL_INDEX);
  }

  private boolean isHeartbeatMsg() {
    return sockets.pollin(HEARTBEAT_POLL_INDEX);
  }

  private boolean isControlMsg() {
    return sockets.pollin(CONTROL_POLL_INDEX);
  }

  /** Asks the polling loop to stop after the current iteration. */
  private void shutdown() {
    this.shutdownSystem = true;
  }

  private boolean isShutdown() {
    return this.shutdownSystem;
  }

  /** Decodes {@code bytes} as UTF-8 JSON into {@code theClass}; {@code null} stays {@code null}. */
  private <T> T parse(byte[] bytes, Class<T> theClass) {
    return bytes != null ? MessageSerializer.parse(new String(bytes, StandardCharsets.UTF_8), theClass) : null;
  }
}