/* This file is part of VoltDB.
 * Copyright (C) 2008-2017 VoltDB Inc.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with VoltDB. If not, see <http://www.gnu.org/licenses/>.
 */

package org.voltcore.network;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import javax.net.ssl.SSLEngine;

import org.voltcore.utils.CoreUtils;
import org.voltcore.utils.FlexibleSemaphore;
import org.voltcore.utils.ssl.SSLBufferDecrypter;

import com.google_voltpatches.common.util.concurrent.ListenableFuture;

import io.netty_voltpatches.buffer.ByteBuf;
import io.netty_voltpatches.buffer.CompositeByteBuf;
import io.netty_voltpatches.buffer.Unpooled;
import io.netty_voltpatches.util.IllegalReferenceCountException;

/**
 * A {@link VoltPort} whose traffic is protected by TLS.
 *
 * <p>Inbound bytes are cut into whole TLS frames by {@link #run()} and handed
 * to the {@link DecryptionGateway}, which decrypts them on the
 * {@link CipherExecutor} and re-assembles length-prefixed messages into
 * {@code m_decrypted}. Outbound bytes go through a {@link TLSNIOWriteStream}.
 *
 * <p>{@code m_inFlight} starts with one permit; one permit is taken away for
 * every frame offered to the gateway and given back when that frame has been
 * dealt with, so "exactly one permit available" means "no decrypt pending".
 */
public class TLSVoltPort extends VoltPort {
    /** Number of bytes in a TLS frame header. */
    public final static int TLS_HEADER_SIZE = 5;

    private final SSLEngine m_sslEngine;
    private final SSLBufferDecrypter m_decrypter;
    /** Failures raised by decrypt tasks, surfaced by {@link #checkForGatewayExceptions()}. */
    private final ConcurrentLinkedDeque<ExecutionException> m_exceptions = new ConcurrentLinkedDeque<>();
    /** Fully assembled, decrypted messages waiting to be handed to the input handler. */
    private final ConcurrentLinkedDeque<ByteBuffer> m_decrypted = new ConcurrentLinkedDeque<>();
    /** Tracks frames offered to the gateway that have not been processed yet. */
    private final FlexibleSemaphore m_inFlight = new FlexibleSemaphore(1);
    private final CipherExecutor m_ce;
    private final DecryptionGateway m_dcryptgw;

    /**
     * Creates a TLS port.
     *
     * @param network        passed to the {@link VoltPort} constructor
     * @param handler        passed to the {@link VoltPort} constructor
     * @param remoteAddress  passed to the {@link VoltPort} constructor
     * @param pool           passed to the {@link VoltPort} constructor
     * @param sslEngine      engine used for decryption and handed to the write stream
     * @param cipherExecutor executor (and buffer allocator) for decrypt tasks
     */
    public TLSVoltPort(VoltNetwork network, InputHandler handler,
            InetSocketAddress remoteAddress, NetworkDBBPool pool,
            SSLEngine sslEngine, CipherExecutor cipherExecutor) {
        super(network, handler, remoteAddress, pool);
        m_ce = cipherExecutor;
        m_sslEngine = sslEngine;
        m_decrypter = new SSLBufferDecrypter(sslEngine);
        m_dcryptgw = new DecryptionGateway();
    }

    /**
     * Returns the application buffer size of the current TLS session.
     * This value may change if a TLS session renegotiates its cipher suite,
     * so it is read from the session on every call.
     */
    private int applicationBufferSize() {
        return m_sslEngine.getSession().getApplicationBufferSize();
    }

    /**
     * Binds this port to its selection key and creates the read stream and
     * the TLS-aware write stream.
     *
     * @param key the selection key whose channel this port will use
     */
    @Override
    protected void setKey(SelectionKey key) {
        m_selectionKey = key;
        m_channel = (SocketChannel)key.channel();
        m_readStream = new NIOReadStream();
        m_writeStream = new TLSNIOWriteStream(
                this,
                m_handler.offBackPressure(),
                m_handler.onBackPressure(),
                m_handler.writestreamMonitor(),
                m_sslEngine,
                m_ce);
        m_interestOps = key.interestOps();
    }

    /**
     * Kills the port: discards everything queued in the gateway, gives
     * pending decrypt tasks a bounded amount of time (one second per missing
     * permit) to finish, then forces the in-flight semaphore back to exactly
     * one available permit.
     */
    @Override
    void die() {
        super.die();
        m_dcryptgw.die();
        int waitFor = 1 - m_inFlight.availablePermits();
        for (int i = 0; i < waitFor; ++i) {
            try {
                if (m_inFlight.tryAcquire(1, TimeUnit.SECONDS)) {
                    m_inFlight.release();
                    break;
                }
            } catch (InterruptedException e) {
                // stop waiting; the permits are reset unconditionally below
                break;
            }
        }
        // whatever the outcome of the wait, leave exactly one permit available
        m_inFlight.drainPermits();
        m_inFlight.release();
    }

    /**
     * Throws the oldest failure recorded by a decrypt task, if any.
     *
     * @throws IOException the I/O cause of the recorded failure, or a new
     *         {@code IOException} wrapping its cause when there is none
     */
    private void checkForGatewayExceptions() throws IOException {
        ExecutionException ee = m_exceptions.poll();
        if (ee != null) {
            IOException ioe = TLSException.ioCause(ee.getCause());
            if (ioe == null) {
                ioe = new IOException("decrypt task failed", ee.getCause());
            }
            throw ioe;
        }
    }

    /**
     * Blocks until no decrypt task is pending, i.e. until the in-flight
     * semaphore has its single permit available again.
     *
     * @throws IOException if a decrypt task reported a failure while waiting,
     *         or if the waiting thread was interrupted
     */
    private void waitForPendingDecrypts() throws IOException {
        boolean acquired;
        do {
            int waitFor = 1 - m_inFlight.availablePermits();
            // One or more available permits means nothing is pending. Testing
            // for "== 0" only would spin forever should the semaphore ever be
            // over-released: waitFor would be negative, the inner loop would
            // never execute, and acquired would never become true.
            acquired = waitFor <= 0;
            for (int i = 0; i < waitFor && !acquired; ++i) {
                checkForGatewayExceptions();
                try {
                    acquired = m_inFlight.tryAcquire(1, TimeUnit.SECONDS);
                    if (acquired) {
                        // only probing: hand the permit straight back
                        m_inFlight.release();
                    }
                } catch (InterruptedException e) {
                    throw new IOException("interrupted while waiting for pending decrypts", e);
                }
            }
        } while (!acquired);
    }

    /**
     * Returns a one-line, human readable snapshot of this port's state.
     */
    String dumpState() {
        return new StringBuilder(256).append("TLSVoltPort[")
                .append("availableBytes=").append(readStream().dataAvailable())
                .append(", gateway=").append(m_dcryptgw.dumpState())
                .append(", decrypted.isEmpty()= ").append(m_decrypted.isEmpty())
                .append(", exceptions.isEmpty()= ").append(m_exceptions.isEmpty())
                .append(", inFlight=").append(m_inFlight.availablePermits())
                .append("]").toString();
    }

    /**
     * Delegates to the write stream's {@code waitForPendingEncrypts()}.
     *
     * @throws IOException propagated from the write stream
     */
    private void waitForPendingEncrypts() throws IOException {
        ((TLSNIOWriteStream)m_writeStream).waitForPendingEncrypts();
    }

    private final static int MAX_READ = CipherExecutor.FRAME_SIZE << 1; //32 KB
    private final static int NOT_AVAILABLE = -1;

    /**
     * Size (header included) of the TLS frame currently being awaited from
     * the network, or {@link #NOT_AVAILABLE} when no partial frame is pending.
     */
    private int m_needed = NOT_AVAILABLE;

    /**
     * Computes how many bytes the next network read may pull in.
     *
     * @return 0 when the handler allows no reads; {@code MAX_READ} when no
     *         partial frame is pending; otherwise the number of bytes still
     *         missing to complete the pending frame (0 if already buffered)
     */
    private final int getMaxRead() {
        return m_handler.getMaxRead() == 0 ? 0 // in back pressure
                : m_needed == NOT_AVAILABLE ? MAX_READ :
                    readStream().dataAvailable() > m_needed ? 0 :
                        m_needed - readStream().dataAvailable();
    }

    /**
     * One service pass of the port: reads from the network, slices complete
     * TLS frames and offers them to the decryption gateway, delivers already
     * decrypted messages to the handler, and drains encrypted output. The
     * pass is repeated for as long as {@code m_signal} was raised meanwhile.
     *
     * @throws IOException on network failures or when a decrypt task failed
     */
    @Override
    public void run() throws IOException {
        try {
            do {
                checkForGatewayExceptions();
                /*
                 * Have the read stream fill from the network
                 */
                if (readyForRead()) {
                    final int maxRead = getMaxRead();
                    if (maxRead > 0) {
                        int read = fillReadStream(maxRead);
                        if (read > 0) {
                            ByteBuf frameHeader = Unpooled.wrappedBuffer(new byte[TLS_HEADER_SIZE]);
                            while (readStream().dataAvailable() >= TLS_HEADER_SIZE) {
                                NIOReadStream rdstrm = readStream();
                                rdstrm.peekBytes(frameHeader.array());
                                // bytes 3-4 of the header hold the frame payload length
                                // NOTE(review): getShort is signed; a header with the high
                                // length bit set yields a negative size -- confirm that
                                // getSlice rejects it
                                m_needed = frameHeader.getShort(3) + TLS_HEADER_SIZE;
                                if (rdstrm.dataAvailable() < m_needed) break;
                                m_dcryptgw.offer(rdstrm.getSlice(m_needed));
                                m_needed = NOT_AVAILABLE;
                            }
                        }
                    }
                }

                if (m_network.isStopping() || m_isShuttingDown) {
                    waitForPendingDecrypts();
                }

                ByteBuffer message = null;
                while ((message = m_decrypted.poll()) != null) {
                    ++m_messagesRead;
                    m_handler.handleMessage(message, this);
                }

                /*
                 * On readiness selection, optimistically assume that write will succeed,
                 * in the common case it will
                 */
                drainEncryptedStream();
                /*
                 * some encrypt or decrypt task may have finished while this port is running
                 * so enabling write interest would have been muted. Signal is there to
                 * reconsider finished decrypt or encrypt tasks.
                 */
            } while (m_signal.compareAndSet(true, false));
        } finally {
            synchronized(m_lock) {
                assert(m_running == true);
                m_running = false;
            }
        }
    }

    /**
     * Serializes queued writes, pushes encrypted bytes to the channel and,
     * once the write stream is empty, disables write selection. When the port
     * is shutting down and nothing is left to write, the channel is closed
     * and the port unregistered.
     *
     * @throws IOException on channel write or close failures
     */
    private void drainEncryptedStream() throws IOException {
        TLSNIOWriteStream writeStream = (TLSNIOWriteStream)m_writeStream;

        writeStream.serializeQueuedWrites(m_pool /* unused and ignored */);
        if (m_network.isStopping()) {
            waitForPendingEncrypts();
        }
        synchronized (writeStream) {
            if (!writeStream.isEmpty()) {
                writeStream.drainTo(m_channel);
            }
            if (writeStream.isEmpty()) {
                disableWriteSelection();
                if (m_isShuttingDown) {
                    m_channel.close();
                    unregistered();
                }
            }
        }
    }

    /**
     * If this port is running, calls to enableWriteSelection are
     * ignored by the super class. This signaling artifact tells
     * this port's run() method to keep polling for finished encrypt
     * and decrypt tasks.
     */
    private final AtomicBoolean m_signal = new AtomicBoolean(false);

    /**
     * Raises {@code m_signal} so a concurrently executing {@link #run()}
     * makes another pass, then delegates to the super class.
     */
    @Override
    protected void enableWriteSelection() {
        m_signal.set(true);
        super.enableWriteSelection();
    }

    /**
     * Waits for pending decrypt and encrypt tasks (logging, not propagating,
     * any drain fault), releases the gateway's assembly buffer and then
     * delegates to the super class.
     */
    @Override
    void unregistered() {
        try {
            waitForPendingDecrypts();
        } catch (IOException e) {
            networkLog.warn("unregistered port had an decryption task drain fault", e);
        }
        try {
            waitForPendingEncrypts();
        } catch (IOException e) {
            networkLog.warn("unregistered port had an encryption task drain fault", e);
        }
        m_dcryptgw.releaseDecryptedBuffer();
        super.unregistered();
    }

    /**
     * Listener attached to every submitted decrypt task. When the task ended
     * with an uncaught exception it gives back the task's in-flight permit
     * and records the failure for {@link #checkForGatewayExceptions()}.
     */
    class ExceptionListener implements Runnable {
        private final ListenableFuture<?> m_fut;

        private ExceptionListener(ListenableFuture<?> fut) {
            m_fut = fut;
        }

        @Override
        public void run() {
            if (isDead()) return;
            try {
                m_fut.get();
            } catch (InterruptedException notPossible) {
                // presumably get() never blocks here because listeners are
                // invoked once the future has completed -- TODO confirm
            } catch (ExecutionException e) {
                m_inFlight.release();
                networkLog.error("unexpect fault occurred in decrypt task", e.getCause());
                m_exceptions.offer(e);
            }
        }
    }

    /**
     * Construct used to serialize all the decryption tasks for this port.
     * It takes a view of the incoming queued buffers (that may span two BBContainers)
     * and decrypts them. It uses the assembler to gather all frames that comprise
     * a frame spanning message, otherwise it will enqueue decrypted messages to
     * the m_decrypted queue.
     */
    class DecryptionGateway implements Runnable {

        /** Scratch space used when a frame is backed by more than one NIO buffer. */
        private final byte [] m_overlap = new byte[CipherExecutor.FRAME_SIZE + 2048];
        /** Encrypted frames awaiting decryption, in arrival order. */
        private final ConcurrentLinkedDeque<NIOReadStream.Slice> m_q = new ConcurrentLinkedDeque<>();
        /** Accumulates decrypted bytes until whole messages can be cut out. */
        private final CompositeByteBuf m_msgbb = Unpooled.compositeBuffer();

        /** Length of the message being assembled, or NOT_AVAILABLE before its length prefix is read. */
        private volatile int m_needed = NOT_AVAILABLE;

        /**
         * Queues an encrypted frame, schedules this gateway if the queue was
         * empty, and takes one permit away from the in-flight semaphore. On a
         * dead port the frame is discarded instead.
         *
         * @param slice the encrypted TLS frame
         * @throws IOException declared by the signature; not thrown by the
         *         statements visible in this method
         */
        synchronized void offer(NIOReadStream.Slice slice) throws IOException {
            if (isDead()) {
                slice.markConsumed().discard();
                return;
            }
            final boolean wasEmpty = m_q.isEmpty();

            m_q.offer(slice);
            if (wasEmpty) {
                submitSelf();
            }
            m_inFlight.reducePermits(1);
        }

        /**
         * Discards every queued frame and releases the assembly buffer.
         */
        synchronized void die() {
            NIOReadStream.Slice slice = null;
            while ((slice=m_q.poll()) != null) {
                slice.markConsumed().discard();
            }
            releaseDecryptedBuffer();
        }

        /**
         * @return true when no encrypted frame is queued
         */
        synchronized boolean isEmpty() {
            return m_q.isEmpty();
        }

        /**
         * Returns a one-line, human readable snapshot of the gateway state.
         */
        String dumpState() {
            return new StringBuilder(256).append("DecryptionGateway[isEmpty()=").append(isEmpty())
                    .append(", isDead()=").append(isDead())
                    .append(", msgbb=").append(m_msgbb)
                    .append("]").toString();
        }

        /**
         * Checks a message length prefix against the accepted range
         * {@code [1, MAX_MESSAGE_LENGTH]}.
         *
         * <p>No assertion is made on the value: an out-of-range length is an
         * input error that must be reported through the returned exception,
         * also when the JVM runs with assertions enabled.
         *
         * @param msgLength the length read from the decrypted stream
         * @return a {@link BadMessageLength} describing the violation, or
         *         {@code null} when the length is acceptable
         */
        private final IOException validateMessageLength(int msgLength) {
            IOException ioe = null;
            if (msgLength < 1) {
                ioe = new BadMessageLength(
                        "Next message length is " + msgLength + " which is less than 1 and is nonsense");
            }
            if (msgLength > MAX_MESSAGE_LENGTH) {
                ioe = new BadMessageLength(
                        "Next message length is " + msgLength + " which is greater then the hard coded " +
                        "max of " + MAX_MESSAGE_LENGTH + ". Break up the work into smaller chunks (2 megabytes is reasonable) " +
                        "and send as multiple messages or stored procedure invocations");
            }
            return ioe;
        }

        /**
         * Releases the assembly buffer if it still holds a reference; a
         * concurrent release that wins the race is tolerated.
         */
        void releaseDecryptedBuffer() {
            if (m_msgbb.refCnt() > 0) {
                try {
                    m_msgbb.release();
                } catch (IllegalReferenceCountException ignoreIt) {
                    // already released elsewhere between the check and the call
                }
            }
        }

        /**
         * Decrypts the frame at the head of the queue, appends the clear text
         * to the assembly buffer, cuts out every complete message, and
         * re-submits itself while more frames are queued.
         *
         * <p>On every path taken through this method the in-flight permit of
         * the processed frame is given back exactly once. On the two failure
         * paths (decrypt failure, bad message length) the frame is left at
         * the head of the queue, where {@link #die()} discards it.
         */
        @Override
        public void run() {
            final NIOReadStream.Slice slice = m_q.peek();
            if (slice == null) return;

            ByteBuf src = slice.bb;

            if (isDead()) {
                synchronized(this) {
                    slice.markConsumed().discard();
                    m_q.poll();
                    releaseDecryptedBuffer();
                    return;
                }
            }

            ByteBuffer [] slicebbarr = slice.bb.nioBuffers();
            // if frame overlaps two buffers then copy it to the overlap buffer
            // and use that instead for the unwrap src buffer
            if (slicebbarr.length > 1) {
                src = Unpooled.wrappedBuffer(m_overlap).clear();
                slice.bb.readBytes(src, slice.bb.readableBytes());
                slicebbarr[0] = src.nioBuffer();
            }

            final int appBuffSz = applicationBufferSize();
            ByteBuf dest = m_ce.allocator().buffer(appBuffSz).writerIndex(appBuffSz);
            ByteBuffer destjbb = dest.nioBuffer();
            int decryptedBytes = 0;
            try {
                decryptedBytes = m_decrypter.tlsunwrap(slicebbarr[0], destjbb);
            } catch (TLSException e) {
                m_inFlight.release();
                dest.release();
                m_exceptions.offer(new ExecutionException("fragment decrypt task failed", e));
                networkLog.error("fragment decrypt task failed", e);
                enableWriteSelection();
                return;
            }
            assert !slicebbarr[0].hasRemaining() : "decrypter did not wholly consume the source buffer";

            // src buffer is wholly consumed
            if (!isDead()) {
                if (decryptedBytes > 0) {
                    dest.writerIndex(destjbb.limit());
                    m_msgbb.addComponent(true, dest);
                } else {
                    // the TLS frame was consumed by the call to engines unwrap but it
                    // did not yield any content
                    dest.release();
                }

                int read = 0;
                while (m_msgbb.readableBytes() >= getNeededBytes()) {
                    if (m_needed == NOT_AVAILABLE) {
                        m_needed = m_msgbb.readInt();
                        IOException ioe = validateMessageLength(m_needed);
                        if (ioe != null) {
                            m_inFlight.release();
                            m_msgbb.release();
                            m_exceptions.offer(new ExecutionException("failed message length check", ioe));
                            networkLog.error("failed message length check", ioe);
                            enableWriteSelection();
                            // Stop here, exactly like the decrypt failure path above.
                            // Carrying on would release the in-flight permit a second
                            // time (at the end of this method, or in the exception
                            // listener once allocate() rejects a negative length) and
                            // would keep reading from the released assembly buffer.
                            return;
                        }
                        // re-evaluate the loop condition against the real message length
                        continue;
                    }
                    ByteBuffer bb = ByteBuffer.allocate(m_needed);
                    m_msgbb.readBytes(bb);
                    m_decrypted.offer((ByteBuffer)bb.flip());

                    ++read;
                    m_needed = NOT_AVAILABLE;
                }
                if (read > 0) {
                    m_msgbb.discardReadComponents();
                    // wake the port so it delivers the new messages
                    enableWriteSelection();
                }
            } else { // it isDead()
                dest.release();
                releaseDecryptedBuffer();
            }
            synchronized(this) {
                m_q.poll();
                slice.markConsumed().discard();
                m_inFlight.release();
                if (m_q.peek() != null) {
                    submitSelf();
                }
            }
        }

        /**
         * Submits this gateway to the cipher executor and attaches an
         * {@link ExceptionListener} to the resulting future.
         */
        void submitSelf() {
            ListenableFuture<?> fut = m_ce.submit(this);
            fut.addListener(new ExceptionListener(fut), CoreUtils.LISTENINGSAMETHREADEXECUTOR);
        }

        /**
         * @return 4 (the size of the length prefix) when no message length
         *         has been read yet, otherwise the pending message length
         */
        private int getNeededBytes() {
            return m_needed == NOT_AVAILABLE ? 4 : m_needed;
        }
    }

    /** The distinct exception class allows better logging of these unexpected errors. */
    static class BadMessageLength extends IOException {
        private static final long serialVersionUID = 8547352379044459911L;

        public BadMessageLength(String string) {
            super(string);
        }
    }
}