/* This file is part of VoltDB.
* Copyright (C) 2008-2017 VoltDB Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with VoltDB. If not, see <http://www.gnu.org/licenses/>.
*/
package org.voltcore.utils.ssl;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import javax.net.ssl.SSLEngine;
import org.voltcore.network.CipherExecutor;
import org.voltcore.network.TLSException;
import io.netty_voltpatches.buffer.ByteBuf;
import io.netty_voltpatches.buffer.CompositeByteBuf;
import io.netty_voltpatches.buffer.Unpooled;
public class TLSMessagingChannel extends MessagingChannel {
private final SSLEngine m_engine;
private final CipherExecutor m_ce;
private final SSLBufferDecrypter m_decrypter;
private final SSLBufferEncrypter m_encrypter;
public TLSMessagingChannel(SocketChannel socketChannel, SSLEngine engine) {
super(socketChannel);
m_engine = engine;
m_ce = CipherExecutor.valueOf(engine);
m_decrypter = new SSLBufferDecrypter(m_engine);
m_encrypter = new SSLBufferEncrypter(m_engine);
}
/**
* this values may change if a TLS session renegotiates its cipher suite
*/
private int packetBufferSize() {
return m_engine.getSession().getPacketBufferSize();
}
/**
* this values may change if a TLS session renegotiates its cipher suite
*/
private int applicationBufferSize() {
return m_engine.getSession().getApplicationBufferSize();
}
private final static int NOT_AVAILABLE = -1;
private int validateLength(int sz) throws IOException {
if (sz < 1 || sz > (1<<25)) {
throw new IOException("Invalid message length header value: " + sz
+ ". It must be between 1 and " + (1<<25));
}
return sz;
}
@Override
public ByteBuffer readMessage() throws IOException {
final int appsz = applicationBufferSize();
ByteBuf readbuf = m_ce.allocator().ioBuffer(packetBufferSize());
CompositeByteBuf msgbb = Unpooled.compositeBuffer();
try {
ByteBuf clear = m_ce.allocator().buffer(appsz).writerIndex(appsz);
ByteBuffer src, dst;
do {
readbuf.clear();
if (!m_decrypter.readTLSFrame(m_socketChannel, readbuf)) {
return null;
}
src = readbuf.nioBuffer();
dst = clear.nioBuffer();
} while (m_decrypter.tlsunwrap(src, dst) == 0);
msgbb.addComponent(true, clear.writerIndex(dst.limit()));
int needed = msgbb.readableBytes() >= 4 ? validateLength(msgbb.readInt()) : NOT_AVAILABLE;
while (msgbb.readableBytes() < (needed == NOT_AVAILABLE ? 4 : needed)) {
clear = m_ce.allocator().buffer(appsz).writerIndex(appsz);
do {
readbuf.clear();
if (!m_decrypter.readTLSFrame(m_socketChannel, readbuf)) {
return null;
}
src = readbuf.nioBuffer();
dst = clear.nioBuffer();
} while (m_decrypter.tlsunwrap(src, dst) == 0);
msgbb.addComponent(true, clear.writerIndex(dst.limit()));
if (needed == NOT_AVAILABLE && msgbb.readableBytes() >= 4) {
needed = validateLength(msgbb.readInt());
}
}
ByteBuffer retbb = ByteBuffer.allocate(needed);
msgbb.readBytes(retbb);
msgbb.discardReadComponents();
assert !msgbb.isReadable() : "read from unblocked channel that received multiple messages?";
return (ByteBuffer)retbb.flip();
} finally {
readbuf.release();
msgbb.release();
}
}
@Override
public int writeMessage(ByteBuffer message) throws IOException {
if (!message.hasRemaining()) {
return 0;
}
CompositeByteBuf outbuf = Unpooled.compositeBuffer();
ByteBuf msg = Unpooled.wrappedBuffer(message);
final int needed = CipherExecutor.framesFor(msg.readableBytes());
for (int have = 0; have < needed; ++have) {
final int slicesz = Math.min(CipherExecutor.FRAME_SIZE, msg.readableBytes());
ByteBuf clear = msg.readSlice(slicesz).writerIndex(slicesz);
ByteBuf encr = m_ce.allocator().ioBuffer(packetBufferSize()).writerIndex(packetBufferSize());
ByteBuffer src = clear.nioBuffer();
ByteBuffer dst = encr.nioBuffer();
try {
m_encrypter.tlswrap(src, dst);
} catch (TLSException e) {
outbuf.release();
encr.release();
throw new IOException("failed to encrypt tls frame", e);
}
assert !src.hasRemaining() : "encryption wrap did not consume the whole source buffer";
encr.writerIndex(dst.limit());
outbuf.addComponent(true, encr);
}
int bytesWritten = 0;
try {
while (outbuf.isReadable()) {
bytesWritten += outbuf.readBytes(m_socketChannel, outbuf.readableBytes());
}
} catch (IOException e) {
throw e;
} finally {
outbuf.release();
}
message.position(message.position() + msg.readerIndex());
return bytesWritten;
}
}