/* This file is part of VoltDB. * Copyright (C) 2008-2017 VoltDB Inc. * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with VoltDB. If not, see <http://www.gnu.org/licenses/>. */ package org.voltcore.network; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.GatheringByteChannel; import java.util.ArrayList; import java.util.Deque; import java.util.Iterator; import java.util.List; import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLEngine; import org.voltcore.utils.CoreUtils; import org.voltcore.utils.DeferredSerialization; import org.voltcore.utils.EstTime; import org.voltcore.utils.FlexibleSemaphore; import org.voltcore.utils.ssl.SSLBufferEncrypter; import com.google_voltpatches.common.collect.ImmutableList; import com.google_voltpatches.common.util.concurrent.ListenableFuture; import io.netty_voltpatches.buffer.ByteBuf; import io.netty_voltpatches.buffer.CompositeByteBuf; import io.netty_voltpatches.buffer.Unpooled; public class TLSNIOWriteStream extends NIOWriteStream { private final ConcurrentLinkedDeque<ExecutionException> m_exceptions = new ConcurrentLinkedDeque<>(); private final ConcurrentLinkedDeque<EncryptFrame> m_encrypted = new ConcurrentLinkedDeque<>(); private final FlexibleSemaphore m_inFlight = new FlexibleSemaphore(1); private final CompositeByteBuf m_outbuf; private final CipherExecutor m_ce; private final SSLEngine m_sslEngine; private final SSLBufferEncrypter m_encrypter; private final EncryptionGateway m_ecryptgw = new EncryptionGateway(); private int m_queuedBytes = 0; public TLSNIOWriteStream(VoltPort port, Runnable offBackPressureCallback, Runnable onBackPressureCallback, QueueMonitor monitor, SSLEngine engine, CipherExecutor cipherExecutor) { super(port, offBackPressureCallback, onBackPressureCallback, monitor); m_sslEngine = engine; m_ce = cipherExecutor; m_outbuf = Unpooled.compositeBuffer(); m_encrypter = new SSLBufferEncrypter(engine); } /** * this values may change if a TLS session renegotiates its cipher suite */ private int applicationBufferSize() { return m_sslEngine.getSession().getApplicationBufferSize(); } /** * this values may change if a TLS session renegotiates its cipher suite */ private int packetBufferSize() { return m_sslEngine.getSession().getPacketBufferSize(); } @Override int serializeQueuedWrites(NetworkDBBPool pool) throws IOException { checkForGatewayExceptions(); final int frameMax = Math.min(CipherExecutor.FRAME_SIZE, applicationBufferSize()); int processedWrites = 0; final Deque<DeferredSerialization> oldlist = getQueuedWrites(); if (oldlist.isEmpty()) return 0; ByteBuf accum = m_ce.allocator().buffer(frameMax).clear(); DeferredSerialization ds = null; int bytesQueued = 0; int frameMsgs = 0; while ((ds = oldlist.poll()) != null) { ++processedWrites; final int serializedSize = ds.getSerializedSize(); if (serializedSize == DeferredSerialization.EMPTY_MESSAGE_LENGTH) continue; // pack as messages you can inside a TLS frame before you send it to // the encryption gateway if (serializedSize > frameMax) { // frames may contain only one or more whole messages, or only // partial parts of one message. a message may not contain whole // messages and an incomplete partial fragment of one if (accum.writerIndex() > 0) { m_ecryptgw.offer(new EncryptFrame(accum, frameMsgs)); frameMsgs = 0; bytesQueued += accum.writerIndex(); accum = m_ce.allocator().buffer(frameMax).clear(); } ByteBuf big = m_ce.allocator().buffer(serializedSize).writerIndex(serializedSize); ByteBuffer jbb = big.nioBuffer(); ds.serialize(jbb); checkSloppySerialization(jbb, ds); bytesQueued += big.writerIndex(); m_ecryptgw.offer(new EncryptFrame(big, 1)); frameMsgs = 0; continue; } else if (accum.writerIndex() + serializedSize > frameMax) { m_ecryptgw.offer(new EncryptFrame(accum, frameMsgs)); frameMsgs = 0; bytesQueued += accum.writerIndex(); accum = m_ce.allocator().buffer(frameMax).clear(); } ByteBuf packet = accum.slice(accum.writerIndex(), serializedSize); ByteBuffer jbb = packet.nioBuffer(); ds.serialize(jbb); checkSloppySerialization(jbb, ds); accum.writerIndex(accum.writerIndex()+serializedSize); ++frameMsgs; } if (accum.writerIndex() > 0) { m_ecryptgw.offer(new EncryptFrame(accum, frameMsgs)); bytesQueued += accum.writerIndex(); } else { accum.release(); } updateQueued(bytesQueued, true); return processedWrites; } @Override public void updateQueued(int queued, boolean noBackpressureSignal) { super.updateQueued(queued, noBackpressureSignal); m_queuedBytes += queued; } void waitForPendingEncrypts() throws IOException { boolean acquired; do { int waitFor = 1 - m_inFlight.availablePermits(); acquired = waitFor == 0; for (int i = 0; i < waitFor && !acquired; ++i) { checkForGatewayExceptions(); try { acquired = m_inFlight.tryAcquire(1, TimeUnit.SECONDS); if (acquired) { m_inFlight.release(); } } catch (InterruptedException e) { throw new IOException("interrupted while waiting for pending encrypts", e); } } } while (!acquired); } private final List<EncryptFrame> m_partial = new ArrayList<>(); private volatile int m_partialSize = 0; static final class EncryptLedger { final int delta; final int bytes; EncryptLedger(int aDelta, int aBytes) { delta = aDelta; bytes = aBytes; } } /** * Gather all the frames that comprise a whole Volt Message */ private EncryptLedger addFramesForCompleteMessage() { boolean added = false; EncryptFrame frame = null; int bytes = 0; int delta = 0; while (!added && (frame = m_encrypted.poll()) != null) { if (!frame.isLast()) { synchronized(m_partial) { m_partial.add(frame); ++m_partialSize; } continue; } final int partialSize = m_partialSize; if (partialSize > 0) { assert frame.chunks == partialSize + 1 : "partial frame buildup has wrong number of preceeding pieces"; synchronized(m_partial) { for (EncryptFrame frm: m_partial) { m_outbuf.addComponent(true, frm.frame); bytes += frm.frame.readableBytes(); delta += frm.delta; } m_partial.clear(); m_partialSize = 0; } } m_outbuf.addComponent(true, frame.frame); bytes += frame.frame.readableBytes(); delta += frame.delta; m_messagesInOutBuf += frame.msgs; added = true; } return added ? new EncryptLedger(delta, bytes) : null; } @Override synchronized public boolean isEmpty() { return m_queuedWrites.isEmpty() && m_ecryptgw.isEmpty() && m_encrypted.isEmpty() && m_partialSize == 0 && !m_outbuf.isReadable(); } private void checkForGatewayExceptions() throws IOException { ExecutionException ee = m_exceptions.poll(); if (ee != null) { IOException ioe = TLSException.ioCause(ee.getCause()); if (ioe == null) { ioe = new IOException("encrypt task failed", ee.getCause()); } throw ioe; } } private int m_messagesInOutBuf = 0; @Override int drainTo(final GatheringByteChannel channel) throws IOException { int written = 0; int delta = 0; try { long rc = 0; do { checkForGatewayExceptions(); EncryptLedger queued = null; // add to output buffer frames that contain whole messages while ((queued=addFramesForCompleteMessage()) != null) { delta += queued.delta; } rc = m_outbuf.readBytes(channel, m_outbuf.readableBytes()); m_outbuf.discardReadComponents(); written += rc; if (m_outbuf.isReadable()) { if (!m_hadBackPressure) { backpressureStarted(); } } else if (rc > 0) { m_messagesWritten += m_messagesInOutBuf; m_messagesInOutBuf = 0; } } while (rc > 0); } finally { if ( m_outbuf.numComponents() <= 1 && m_hadBackPressure && m_queuedWrites.size() <= m_maxQueuedWritesBeforeBackpressure ) { backpressureEnded(); } if (written > 0 && !isEmpty()) { m_lastPendingWriteTime = EstTime.currentTimeMillis(); } else { m_lastPendingWriteTime = -1L; } if (written > 0) { updateQueued(delta-written, false); m_bytesWritten += written; } else if (delta > 0) { updateQueued(delta, false); } } return written; } String dumpState() { return new StringBuilder(256).append("TLSNIOWriteStream[") .append("isEmpty()=").append(isEmpty()) .append(", encrypted.isEmpty()=").append(m_encrypted.isEmpty()) .append(", exceptions.isEmpty()=").append(m_exceptions.isEmpty()) .append(", gateway=").append(m_ecryptgw.dumpState()) .append(", inFligth=").append(m_inFlight.availablePermits()) .append(", outbuf.readableBytes()=").append(m_outbuf.readableBytes()) .append("]").toString(); } @Override public synchronized int getOutstandingMessageCount() { return m_encrypted.size() + m_queuedWrites.size() + m_partialSize + m_outbuf.numComponents(); } @Override synchronized void shutdown() { m_isShutdown = true; try { DeferredSerialization ds = null; while ((ds = m_queuedWrites.poll()) != null) { ds.cancel(); } int waitFor = 1 - Math.min(m_inFlight.availablePermits(), -4); for (int i = 0; i < waitFor; ++i) { try { if (m_inFlight.tryAcquire(1, TimeUnit.SECONDS)) { m_inFlight.release(); break; } } catch (InterruptedException e) { break; } } m_ecryptgw.die(); EncryptFrame frame = null; while ((frame = m_encrypted.poll()) != null) { frame.frame.release(); } for (EncryptFrame ef: m_partial) { ef.frame.release(); } m_partial.clear(); m_outbuf.release(); // we have to use ledger because we have no idea how much encrypt delta // corresponds to what is left in the output buffer final int unqueue = -m_queuedBytes; updateQueued(unqueue, false); } finally { m_inFlight.drainPermits(); m_inFlight.release(); } } /** * Construct used to serialize all the encryption tasks for this stream. * it takes an encryption request offer, divides it into chunks that * can be handled wholly by SSLEngine wrap, and queues all the * encrypted frames to the m_encrypted queue. All faults are queued * to the m_exceptions queue */ class EncryptionGateway implements Runnable { private final ConcurrentLinkedDeque<EncryptFrame> m_q = new ConcurrentLinkedDeque<>(); private final int COALESCE_THRESHOLD = CipherExecutor.FRAME_SIZE - 4096; synchronized void offer(EncryptFrame frame) throws IOException { final boolean wasEmpty = m_q.isEmpty(); List<EncryptFrame> chunks = frame.chunked( Math.min(CipherExecutor.FRAME_SIZE, applicationBufferSize())); m_q.addAll(chunks); m_inFlight.reducePermits(chunks.size()); if (wasEmpty) { submitSelf(); } } /** * Encryption takes time and the likelihood that some or more encrypt frames * are queued up behind the recently completed frame encryption is high. This * takes queued up small frames and coalesces them into a bigger frame */ private void coalesceEncryptFrames() { EncryptFrame head = m_q.peek(); if (head == null || head.chunks > 1 || head.bb.readableBytes() > COALESCE_THRESHOLD) { return; } m_q.poll(); ByteBuf bb = head.bb; int msgs = head.msgs; int released = 0; head = m_q.peek(); while (head != null && head.chunks == 1 && head.bb.readableBytes() <= bb.writableBytes()) { m_q.poll(); bb.writeBytes(head.bb, head.bb.readableBytes()); head.bb.release(); ++released; msgs += head.msgs; head = m_q.peek(); } m_q.push(new EncryptFrame(bb, 0, msgs)); if (released > 0) { m_inFlight.release(released); } } synchronized int die() { int toUnqueue = 0; EncryptFrame ef = null; while ((ef = m_q.poll()) != null) { toUnqueue += ef.frame.readableBytes(); if (ef.isLast()) { ef.bb.release(); } } return toUnqueue; } String dumpState() { return new StringBuilder(256).append("EncryptionGateway[") .append("q.isEmpty()=").append(m_q.isEmpty()) .append(", partialSize=").append(m_partialSize) .append("]").toString(); } public Iterator<EncryptFrame> iterator() { return ImmutableList.copyOf(m_q).iterator(); } @Override public void run() { EncryptFrame frame = m_q.peek(); if (frame == null) return; ByteBuffer src = frame.frame.nioBuffer(); ByteBuf encr = m_ce.allocator().ioBuffer(packetBufferSize()).writerIndex(packetBufferSize()); ByteBuffer dest = encr.nioBuffer(); try { m_encrypter.tlswrap(src, dest); } catch (TLSException e) { m_inFlight.release(); encr.release(); m_exceptions.offer(new ExecutionException("failed to encrypt frame", e)); networkLog.error("failed to encrypt frame", e); m_port.enableWriteSelection(); return; } assert !src.hasRemaining() : "encryption wrap did not consume the whole source buffer"; int delta = dest.limit() - frame.frame.readableBytes(); encr.writerIndex(dest.limit()); if (!m_isShutdown) { m_encrypted.offer(frame.encrypted(delta, encr)); /* * All interactions with write stream must be protected * with a lock to ensure that interests ops are consistent with * the state of writes queued to the stream. This prevent * lost queued writes where the write is queued * but the write interest op is not set. */ if (frame.isLast()) { m_port.enableWriteSelection(); } } else { encr.release(); return; } synchronized(this) { m_q.poll(); if (frame.isLast()) { frame.bb.release(); } m_inFlight.release(); coalesceEncryptFrames(); if (m_q.peek() != null && !m_isShutdown) { submitSelf(); } } } boolean isEmpty() { return m_q.isEmpty(); } void submitSelf() { ListenableFuture<?> fut = m_ce.submit(this); fut.addListener(new ExceptionListener(fut), CoreUtils.LISTENINGSAMETHREADEXECUTOR); } } class ExceptionListener implements Runnable { private final ListenableFuture<?> m_fut; private ExceptionListener(ListenableFuture<?> fut) { m_fut = fut; } @Override public void run() { if (!m_isShutdown) return; try { m_fut.get(); } catch (InterruptedException notPossible) { } catch (ExecutionException e) { m_inFlight.release(); networkLog.error("unexpect fault occurred in encrypt task", e.getCause()); m_exceptions.offer(e); } } } }