/**
* Copyright (C) 2003 Alexander Kout
* Originally from the jFxp project (http://jfxp.sourceforge.net/).
* Copied with permission June 11, 2012 by Femi Omojola (fomojola@ideasynthesis.com).
*/
package org.java_websocket;
import java.io.IOException;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
/**
 * Wraps a {@link SocketChannel} with an {@link SSLEngine}: outgoing payload is
 * passed through {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} before it is
 * written to the socket and incoming bytes are passed through
 * {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} before they are handed to
 * the caller. Only the portions of the SocketChannel interface that are
 * relevant for this purpose are implemented.
 */
public class SSLSocketChannel2 implements ByteChannel, WrappedByteChannel {
    /**
     * This object is used to feed the {@link SSLEngine}'s wrap and unwrap
     * methods during the handshake phase.
     **/
    protected static ByteBuffer emptybuffer = ByteBuffer.allocate(0);

    /** executes the delegated tasks handed out by the {@link SSLEngine} */
    protected ExecutorService exec;

    /** delegated tasks which were submitted to {@link #exec} and are not known to be done yet */
    protected List<Future<?>> tasks;

    /** raw payload incoming */
    protected ByteBuffer inData;
    /** encrypted data outgoing */
    protected ByteBuffer outCrypt;
    /** encrypted data incoming */
    protected ByteBuffer inCrypt;

    /** the underlying channel */
    protected SocketChannel socketChannel;
    /** used to set interestOP SelectionKey.OP_WRITE for the underlying channel */
    protected SelectionKey selectionKey;

    protected SSLEngine sslEngine;
    /** result of the most recent {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} call */
    protected SSLEngineResult readEngineResult;
    /** result of the most recent {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} call */
    protected SSLEngineResult writeEngineResult;

    /**
     * Should be used to count the buffer allocations. But because of #190 where
     * HandshakeStatus.FINISHED is not properly returned by nio wrap/unwrap this
     * variable is used to check whether {@link #createBuffers(SSLSession)}
     * needs to be called.
     **/
    protected int bufferallocations = 0;

    /**
     * Creates the wrapper, allocates the buffers and kicks off the handshake.
     *
     * @param channel   the underlying channel, must not be null
     * @param sslEngine the engine used to wrap and unwrap, must not be null
     * @param exec      executor for the engine's delegated tasks, must not be null
     * @param key       may be null; otherwise OP_WRITE is added to its interest set
     * @throws IllegalArgumentException if channel, sslEngine or exec is null
     * @throws IOException              if writing the first handshake bytes or processing the handshake fails
     */
    public SSLSocketChannel2(SocketChannel channel, SSLEngine sslEngine,
            ExecutorService exec, SelectionKey key) throws IOException {
        if (channel == null || sslEngine == null || exec == null) {
            throw new IllegalArgumentException("parameter must not be null");
        }

        this.socketChannel = channel;
        this.sslEngine = sslEngine;
        this.exec = exec;

        // init both results to prevent NPEs before the first wrap/unwrap happened
        readEngineResult = writeEngineResult = new SSLEngineResult(
                Status.BUFFER_UNDERFLOW, sslEngine.getHandshakeStatus(), 0, 0);

        tasks = new ArrayList<Future<?>>(3);
        if (key != null) {
            key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
            this.selectionKey = key;
        }
        createBuffers(sslEngine.getSession());
        // kick off handshake
        socketChannel.write(wrap(emptybuffer)); // initializes writeEngineResult
        processHandshake();
    }

    /**
     * Closes the outbound side of the engine, invalidates the session, tries to
     * send the resulting engine output and closes the underlying channel. The
     * executor is shut down afterwards.
     * <p>
     * The underlying channel is closed and the executor is shut down even if
     * sending the final engine output fails.
     */
    @Override
    public void close() throws IOException {
        sslEngine.closeOutbound();
        sslEngine.getSession().invalidate();
        try {
            try {
                if (socketChannel.isOpen()) {
                    // FIXME what if not all bytes can be written
                    socketChannel.write(wrap(emptybuffer));
                }
            } finally {
                // must run even if the write above threw, otherwise the socket leaks
                socketChannel.close();
            }
        } finally {
            exec.shutdownNow();
        }
    }

    /** Delegates to {@link SocketChannel#configureBlocking(boolean)} of the underlying channel. */
    public SelectableChannel configureBlocking(boolean b) throws IOException {
        return socketChannel.configureBlocking(b);
    }

    /** Delegates to {@link SocketChannel#connect(SocketAddress)} of the underlying channel. */
    public boolean connect(SocketAddress remote) throws IOException {
        return socketChannel.connect(remote);
    }

    /**
     * Submits every delegated task the engine currently offers to
     * {@link #exec} and remembers the resulting futures in {@link #tasks}.
     */
    protected void consumeDelegatedTasks() {
        Runnable task;
        while ((task = sslEngine.getDelegatedTask()) != null) {
            tasks.add(exec.submit(task));
        }
    }

    /**
     * Waits for the given future to complete, retrying when the waiting thread
     * gets interrupted. The interrupt flag is restored once the future is done.
     *
     * @throws RuntimeException wrapping the {@link ExecutionException} if the task failed
     */
    private void consumeFutureUninterruptible(Future<?> f) {
        try {
            boolean interrupted = false;
            while (true) {
                try {
                    f.get();
                    break;
                } catch (InterruptedException e) {
                    // remember the interrupt and keep waiting
                    interrupted = true;
                }
            }
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * (Re)allocates {@link #inData}, {@link #outCrypt} and {@link #inCrypt}
     * when they do not exist yet or when their capacity differs from the sizes
     * reported by the session. Afterwards all three buffers are left empty in
     * read mode (position and limit are 0) and {@link #bufferallocations} is
     * incremented.
     */
    protected void createBuffers(SSLSession session) {
        int appBufferMax = session.getApplicationBufferSize();
        int netBufferMax = session.getPacketBufferSize();

        if (inData == null) {
            inData = ByteBuffer.allocate(appBufferMax);
            outCrypt = ByteBuffer.allocate(netBufferMax);
            inCrypt = ByteBuffer.allocate(netBufferMax);
        } else {
            if (inData.capacity() != appBufferMax) {
                inData = ByteBuffer.allocate(appBufferMax);
            }
            if (outCrypt.capacity() != netBufferMax) {
                outCrypt = ByteBuffer.allocate(netBufferMax);
            }
            if (inCrypt.capacity() != netBufferMax) {
                inCrypt = ByteBuffer.allocate(netBufferMax);
            }
        }
        // rewind + flip leaves position == limit == 0, i.e. "nothing to read"
        inData.rewind();
        inData.flip();
        inCrypt.rewind();
        inCrypt.flip();
        outCrypt.rewind();
        outCrypt.flip();
        bufferallocations++;
    }

    /** Delegates to {@link SocketChannel#finishConnect()} of the underlying channel. */
    public boolean finishConnect() throws IOException {
        return socketChannel.finishConnect();
    }

    @Override
    public boolean isBlocking() {
        return socketChannel.isBlocking();
    }

    /** Delegates to {@link SocketChannel#isConnected()} of the underlying channel. */
    public boolean isConnected() {
        return socketChannel.isConnected();
    }

    /**
     * @return true if the engine reports either FINISHED or NOT_HANDSHAKING
     */
    private boolean isHandShakeComplete() {
        HandshakeStatus status = sslEngine.getHandshakeStatus();
        return status == SSLEngineResult.HandshakeStatus.FINISHED
                || status == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
    }

    /** Delegates to {@link SSLEngine#isInboundDone()}. */
    public boolean isInboundDone() {
        return sslEngine.isInboundDone();
    }

    /**
     * @return true if decoded payload is left in {@link #inData}, or if
     *         {@link #inCrypt} still holds bytes and the last unwrap result
     *         was neither BUFFER_UNDERFLOW nor CLOSED
     */
    @Override
    public boolean isNeedRead() {
        return inData.hasRemaining()
                || (inCrypt.hasRemaining()
                        && readEngineResult.getStatus() != Status.BUFFER_UNDERFLOW
                        && readEngineResult.getStatus() != Status.CLOSED);
    }

    /**
     * @return true if encrypted bytes are waiting in {@link #outCrypt} or the
     *         handshake is not complete yet
     */
    @Override
    public boolean isNeedWrite() {
        // FIXME this condition can cause high cpu load during handshaking when
        // network is slow
        return outCrypt.hasRemaining() || !isHandShakeComplete();
    }

    @Override
    public boolean isOpen() {
        return socketChannel.isOpen();
    }

    /**
     * This method will do whatever necessary to process the sslengine
     * handshake. That's why it's called both from the {@link #read(ByteBuffer)}
     * and {@link #write(ByteBuffer)}
     **/
    private synchronized void processHandshake() throws IOException {
        if (sslEngine.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING) {
            // since this may be called either from a reading or a writing
            // thread and because this method is synchronized it is necessary
            // to double check if we are still handshaking.
            return;
        }
        if (!tasks.isEmpty()) {
            Iterator<Future<?>> it = tasks.iterator();
            while (it.hasNext()) {
                Future<?> f = it.next();
                if (f.isDone()) {
                    it.remove();
                } else {
                    // a delegated task is still running: wait for it when in
                    // blocking mode, then let the caller come back later
                    if (isBlocking()) {
                        consumeFutureUninterruptible(f);
                    }
                    return;
                }
            }
        }

        if (sslEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
            if (!isBlocking()
                    || readEngineResult.getStatus() == Status.BUFFER_UNDERFLOW) {
                inCrypt.compact();
                int read = socketChannel.read(inCrypt);
                if (read == -1) {
                    throw new IOException(
                            "connection closed unexpectedly by peer");
                }
                inCrypt.flip();
            }
            inData.compact();
            unwrap();
            if (readEngineResult.getHandshakeStatus() == HandshakeStatus.FINISHED) {
                createBuffers(sslEngine.getSession());
                return;
            }
        }
        consumeDelegatedTasks();
        if (tasks.isEmpty()
                || sslEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
            socketChannel.write(wrap(emptybuffer));
            if (writeEngineResult.getHandshakeStatus() == HandshakeStatus.FINISHED) {
                createBuffers(sslEngine.getSession());
                return;
            }
        }
        // this function could only leave NOT_HANDSHAKING after createBuffers
        // was called unless #190 occurs which means that nio wrap/unwrap never
        // return HandshakeStatus.FINISHED
        assert (sslEngine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING);

        // look at variable declaration why this line exists and #190. Without
        // this line buffers would not be recreated when #190 AND a rehandshake
        // occur.
        bufferallocations = 1;
    }

    /**
     * Blocks when in blocking mode until at least one byte has been decoded.<br>
     * When not in blocking mode 0 may be returned.
     *
     * @return the number of bytes read, or -1 if the underlying channel
     *         reported end of stream.
     **/
    @Override
    public int read(ByteBuffer dst) throws IOException {
        if (!dst.hasRemaining()) {
            return 0;
        }
        if (!isHandShakeComplete()) {
            if (isBlocking()) {
                while (!isHandShakeComplete()) {
                    processHandshake();
                }
            } else {
                processHandshake();
                if (!isHandShakeComplete()) {
                    return 0;
                }
            }
        }
        // assert ( bufferallocations > 1 ); //see #190
        if (bufferallocations <= 1) {
            createBuffers(sslEngine.getSession());
        }
        /*
         * 1. When "dst" is smaller than "inData" readRemaining will fill "dst"
         * with data decoded in a previous read call. 2. When "inCrypt" contains
         * more data than "inData" has remaining space, unwrap has to be called
         * one more time (readRemaining)
         */
        int purged = readRemaining(dst);
        if (purged != 0) {
            return purged;
        }

        /*
         * We only continue when we really need more data from the network.
         * That's the case if inData is empty or inCrypt holds less data than
         * necessary for decryption
         */
        assert (inData.position() == 0);
        inData.clear();

        if (!inCrypt.hasRemaining()) {
            inCrypt.clear();
        } else {
            inCrypt.compact();
        }

        if (isBlocking()
                || readEngineResult.getStatus() == Status.BUFFER_UNDERFLOW) {
            if (socketChannel.read(inCrypt) == -1) {
                return -1;
            }
        }
        inCrypt.flip();
        unwrap();

        int transfered = transfereTo(inData, dst);
        if (transfered == 0 && isBlocking()) {
            // "transfered" may be 0 when not enough bytes were received or
            // during rehandshaking
            return read(dst);
        }
        return transfered;
    }

    /**
     * Hands out payload which is still buffered from a previous read without
     * reading from the underlying channel.
     *
     * @return the number of bytes transferred into dst
     */
    @Override
    public int readMore(ByteBuffer dst) throws SSLException {
        return readRemaining(dst);
    }

    /**
     * {@link #read(ByteBuffer)} may not be able to drain all buffers (inData,
     * inCrypt). This method transfers what is left in {@link #inData} or, if
     * that is empty, unwraps what is left in {@link #inCrypt}.
     *
     * @return the number of bytes transferred into dst, possibly 0
     **/
    private int readRemaining(ByteBuffer dst) throws SSLException {
        if (inData.hasRemaining()) {
            return transfereTo(inData, dst);
        }
        // inData is fully drained at this point, make room for the next unwrap
        inData.clear();
        // test if some bytes left from last read (e.g. BUFFER_UNDERFLOW)
        if (inCrypt.hasRemaining()) {
            unwrap();
            int amount = transfereTo(inData, dst);
            if (amount > 0) {
                return amount;
            }
        }
        return 0;
    }

    /** @return the socket of the underlying channel */
    public Socket socket() {
        return socketChannel.socket();
    }

    /**
     * Copies as many bytes as possible from one buffer to the other.
     *
     * @return the number of bytes copied, i.e. the smaller of both remaining counts
     */
    private int transfereTo(ByteBuffer from, ByteBuffer to) {
        int fremain = from.remaining();
        int toremain = to.remaining();
        if (fremain > toremain) {
            // "from" holds more than "to" can take: temporarily narrow the
            // limit so that a single bulk put copies exactly toremain bytes
            int savedLimit = from.limit();
            from.limit(from.position() + toremain);
            try {
                to.put(from);
            } finally {
                from.limit(savedLimit);
            }
            return toremain;
        }
        to.put(from);
        return fremain;
    }

    /**
     * performs the unwrap operation by unwrapping from {@link #inCrypt} to
     * {@link #inData}. Unwrapping is repeated as long as the engine reports OK
     * and either produced payload in the last round or still needs to unwrap
     * for the handshake. {@link #inData} is flipped before it is returned.
     **/
    private synchronized ByteBuffer unwrap() throws SSLException {
        int rem;
        do {
            rem = inData.remaining();
            readEngineResult = sslEngine.unwrap(inCrypt, inData);
        } while (readEngineResult.getStatus() == SSLEngineResult.Status.OK
                && (rem != inData.remaining()
                        || sslEngine.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP));
        inData.flip();
        return inData;
    }

    /**
     * Wraps the given payload into {@link #outCrypt}. Bytes which are still
     * pending in {@link #outCrypt} are kept in front of the new output.
     *
     * @return {@link #outCrypt}, flipped and ready to be written to the socket
     */
    private synchronized ByteBuffer wrap(ByteBuffer b) throws SSLException {
        outCrypt.compact();
        writeEngineResult = sslEngine.wrap(b, outCrypt);
        outCrypt.flip();
        return outCrypt;
    }

    /**
     * Encrypts the given payload and writes the result to the underlying
     * channel. While the handshake is not complete the handshake is processed
     * instead and nothing of src is consumed.
     *
     * @return the number of (encrypted) bytes written to the underlying
     *         channel, or 0 while handshaking
     */
    @Override
    public int write(ByteBuffer src) throws IOException {
        if (!isHandShakeComplete()) {
            processHandshake();
            return 0;
        }
        // assert ( bufferallocations > 1 ); //see #190
        if (bufferallocations <= 1) {
            createBuffers(sslEngine.getSession());
        }
        int num = socketChannel.write(wrap(src));
        return num;
    }

    /**
     * Continues the handshake if it is not complete yet; otherwise flushes the
     * encrypted bytes which are still pending in {@link #outCrypt} to the
     * underlying channel.
     * <p>
     * The pending bytes are already encrypted, so they are written directly
     * and must not be passed through {@link #wrap(ByteBuffer)} again: wrapping
     * {@link #outCrypt} into itself would encrypt stale buffer content as new
     * payload.
     */
    @Override
    public void writeMore() throws IOException {
        if (!isHandShakeComplete()) {
            processHandshake();
            return;
        }
        if (outCrypt.hasRemaining()) {
            socketChannel.write(outCrypt);
        }
    }
}