/**
* Copyright (C) 2012 FuseSource, Inc.
* http://fusesource.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.fusesource.hawtdispatch.transport;
import org.fusesource.hawtdispatch.Task;
import javax.net.ssl.*;
import java.io.EOFException;
import java.io.IOException;
import java.net.Socket;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*;
import static javax.net.ssl.SSLEngineResult.Status.*;
/**
* An SSL Transport for secure communications.
*
* @author <a href="http://hiramchirino.com">Hiram Chirino</a>
*/
public class SslTransport extends TcpTransport implements SecuredSession {

    /**
     * Maps a URI scheme to an {@link SSLContext} protocol algorithm name.
     * Valid algorithm names are listed at:
     * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
     *
     * @param scheme the URI scheme, e.g. {@code "tls"}, {@code "tlsv1.2"},
     *               {@code "ssl"}, {@code "sslv3"}; may be {@code null}
     * @return the protocol name ({@code "TLS"}, {@code "TLSv<suffix>"}, {@code "SSL"},
     *         {@code "SSLv<suffix>"}), or {@code null} if the scheme is {@code null}
     *         or is not an SSL/TLS scheme
     */
    public static String protocol(String scheme) {
        if (scheme == null) {
            // URI.getScheme() returns null for relative URIs; treat that as "not SSL".
            return null;
        }
        if (scheme.equals("tls")) {
            return "TLS";
        } else if (scheme.startsWith("tlsv")) {
            return "TLSv" + scheme.substring(4);
        } else if (scheme.equals("ssl")) {
            return "SSL";
        } else if (scheme.startsWith("sslv")) {
            return "SSLv" + scheme.substring(4);
        }
        return null;
    }

    /**
     * Client-certificate policy applied to server-mode engines created in
     * {@link #connected(SocketChannel)}.
     */
    enum ClientAuth {
        WANT, NEED, NONE
    }

    private ClientAuth clientAuth = ClientAuth.WANT;
    /** Comma-separated substrings; any enabled suite containing one of them gets disabled. */
    private String disabledCypherSuites = null;
    /** Comma-separated cipher suite names to enable; null/blank keeps the engine defaults. */
    private String enabledCipherSuites = null;

    private SSLContext sslContext;
    private SSLEngine engine;

    /** Encrypted bytes read from the network; kept flipped (read mode) between unwraps. */
    private ByteBuffer readBuffer;
    /** True when readBuffer has been compacted and is waiting for more network bytes. */
    private boolean readUnderflow;
    /** Encrypted bytes produced by wrap() that still have to be written to the network. */
    private ByteBuffer writeBuffer;
    /** True while writeBuffer is flipped and being drained to the network. */
    private boolean writeFlushing;
    /** Decrypted bytes that did not fit into the caller's buffer on an earlier read. */
    private ByteBuffer readOverflowBuffer;

    private final SSLChannel ssl_channel = new SSLChannel();

    /**
     * Sets the context used to create this transport's {@link SSLEngine}.
     *
     * @param ctx the SSL context
     */
    public void setSSLContext(SSLContext ctx) {
        this.sslContext = ctx;
    }

    /**
     * Creates an SslTransport for the given URI when its scheme names an SSL/TLS protocol.
     * <p>
     * NOTE(review): the context comes straight from {@link SSLContext#getInstance(String)} and
     * is not {@code init}-ed here; presumably callers replace it via
     * {@link #setSSLContext(SSLContext)} or initialize it elsewhere — confirm.
     *
     * @param uri the transport URI
     * @return a new transport, or {@code null} if the scheme is not an SSL/TLS scheme
     * @throws Exception if no {@link SSLContext} is available for the mapped protocol
     */
    public static SslTransport createTransport(URI uri) throws Exception {
        String protocol = protocol(uri.getScheme());
        if (protocol != null) {
            SslTransport rc = new SslTransport();
            rc.setSSLContext(SSLContext.getInstance(protocol));
            return rc;
        }
        return null;
    }

    /**
     * Byte channel view of the decrypted stream: writes are encrypted through
     * {@code secure_write} and reads are decrypted through {@code secure_read}.
     * Open/close state is delegated to the underlying socket channel.
     */
    public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {

        public int write(ByteBuffer plain) throws IOException {
            return secure_write(plain);
        }

        public int read(ByteBuffer plain) throws IOException {
            return secure_read(plain);
        }

        public boolean isOpen() {
            return getSocketChannel().isOpen();
        }

        public void close() throws IOException {
            getSocketChannel().close();
        }

        /**
         * Writes the buffers in order, stopping at the first buffer that could not be
         * fully written.
         */
        public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
            if (offset + length > srcs.length || length < 0 || offset < 0) {
                throw new IndexOutOfBoundsException();
            }
            long rc = 0;
            for (int i = 0; i < length; i++) {
                ByteBuffer src = srcs[offset + i];
                if (src.hasRemaining()) {
                    rc += write(src);
                }
                if (src.hasRemaining()) {
                    return rc;
                }
            }
            return rc;
        }

        public long write(ByteBuffer[] srcs) throws IOException {
            return write(srcs, 0, srcs.length);
        }

        /**
         * Reads into the buffers in order, stopping at the first buffer that could not
         * be completely filled.
         */
        public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
            if (offset + length > dsts.length || length < 0 || offset < 0) {
                throw new IndexOutOfBoundsException();
            }
            long rc = 0;
            for (int i = 0; i < length; i++) {
                ByteBuffer dst = dsts[offset + i];
                if (dst.hasRemaining()) {
                    rc += read(dst);
                }
                if (dst.hasRemaining()) {
                    return rc;
                }
            }
            return rc;
        }

        public long read(ByteBuffer[] dsts) throws IOException {
            return read(dsts, 0, dsts.length);
        }

        /** @return the underlying socket, or {@code null} if there is no channel yet */
        public Socket socket() {
            SocketChannel c = channel;
            if (c == null) {
                return null;
            }
            return c.socket();
        }
    }

    /** @return the engine's current session, or {@code null} before an engine exists */
    public SSLSession getSSLSession() {
        return engine == null ? null : engine.getSession();
    }

    /**
     * @return the peer's X.509 certificates, or {@code null} if there is no engine yet or
     *         the peer has not been verified
     */
    public X509Certificate[] getPeerX509Certificates() {
        if (engine == null) {
            return null;
        }
        try {
            ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
            for (Certificate c : engine.getSession().getPeerCertificates()) {
                if (c instanceof X509Certificate) {
                    rc.add((X509Certificate) c);
                }
            }
            return rc.toArray(new X509Certificate[rc.size()]);
        } catch (SSLPeerUnverifiedException e) {
            return null;
        }
    }

    /**
     * Outbound connection: creates a client-mode engine bound to the remote host/port
     * (used by JSSE as a session-resumption hint) before connecting.
     */
    @Override
    public void connecting(URI remoteLocation, URI localLocation) throws Exception {
        assert engine == null;
        engine = sslContext.createSSLEngine(remoteLocation.getHost(), remoteLocation.getPort());
        engine.setUseClientMode(true);
        super.connecting(remoteLocation, localLocation);
    }

    /**
     * Called once the socket is connected. If no engine exists yet (connecting() was not
     * called, i.e. this side accepted the connection) a server-mode engine is created and
     * the client-auth policy applied. Cipher suite configuration is then applied to the engine.
     */
    @Override
    public void connected(SocketChannel channel) throws Exception {
        if (engine == null) {
            engine = sslContext.createSSLEngine();
            engine.setUseClientMode(false);
            switch (clientAuth) {
                case WANT: engine.setWantClientAuth(true); break;
                case NEED: engine.setNeedClientAuth(true); break;
                case NONE: engine.setWantClientAuth(false); break;
            }
        }
        configureCipherSuites();
        super.connected(channel);
    }

    /**
     * Applies {@code enabledCipherSuites} and then {@code disabledCypherSuites} to the engine.
     * <p>
     * When no suites are explicitly enabled, the engine's default enabled suites are kept.
     * Enabling every <em>supported</em> suite would also turn on anonymous (no
     * authentication) and NULL (no encryption) suites that JSSE leaves disabled by default.
     */
    private void configureCipherSuites() {
        if (enabledCipherSuites != null) {
            String[] requested = splitOnCommas(enabledCipherSuites);
            // A blank setting enables nothing useful; treat it as "not configured".
            if (requested.length > 0) {
                engine.setEnabledCipherSuites(requested);
            }
        }
        if (disabledCypherSuites != null) {
            String[] disabledList = splitOnCommas(disabledCypherSuites);
            ArrayList<String> enabled = new ArrayList<String>();
            for (String suite : engine.getEnabledCipherSuites()) {
                boolean add = true;
                for (String disabled : disabledList) {
                    // Substring match, so e.g. "RC4" disables every RC4-based suite.
                    if (suite.contains(disabled)) {
                        add = false;
                        break;
                    }
                }
                if (add) {
                    enabled.add(suite);
                }
            }
            engine.setEnabledCipherSuites(enabled.toArray(new String[enabled.size()]));
        }
    }

    /**
     * Splits a comma-separated list, trimming each entry and dropping blank entries
     * (a blank entry would otherwise match every suite via {@code contains("")}).
     */
    private String[] splitOnCommas(String value) {
        ArrayList<String> rc = new ArrayList<String>();
        for (String x : value.split(",")) {
            String trimmed = x.trim();
            if (trimmed.length() > 0) {
                rc.add(trimmed);
            }
        }
        return rc.toArray(new String[rc.size()]);
    }

    /** Allocates the encrypted read/write buffers sized to the session's packet buffer size. */
    @Override
    protected void initializeChannel() throws Exception {
        super.initializeChannel();
        SSLSession session = engine.getSession();
        readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
        // Start in read mode with nothing buffered; the first unwrap underflows and triggers a read.
        readBuffer.flip();
        writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
    }

    /** Starts the SSL handshake as soon as the transport is connected. */
    @Override
    protected void onConnected() throws IOException {
        super.onConnected();
        engine.beginHandshake();
        handshake();
    }

    /** While a handshake is in progress, drive the handshake instead of flushing app data. */
    @Override
    public void flush() {
        if (engine.getHandshakeStatus() != NOT_HANDSHAKING) {
            handshake();
        } else {
            super.flush();
        }
    }

    /** While a handshake is in progress, drive the handshake instead of draining app data. */
    @Override
    public void drainInbound() {
        if (engine.getHandshakeStatus() != NOT_HANDSHAKING) {
            handshake();
        } else {
            super.drainInbound();
        }
    }

    /**
     * Writes pending encrypted bytes from {@code writeBuffer} to the underlying channel.
     * Calls {@code resumeWrite()} while bytes are pending and {@code suspendWrite()} once
     * the buffer is drained.
     *
     * @return true if fully flushed.
     * @throws IOException if the underlying write fails
     */
    protected boolean transportFlush() throws IOException {
        while (true) {
            if (writeFlushing) {
                super.getWriteChannel().write(writeBuffer);
                if (!writeBuffer.hasRemaining()) {
                    writeBuffer.clear();
                    writeFlushing = false;
                    suspendWrite();
                    return true;
                } else {
                    return false;
                }
            } else {
                if (writeBuffer.position() != 0) {
                    // Switch writeBuffer to drain mode, then loop to write it out.
                    writeBuffer.flip();
                    writeFlushing = true;
                    resumeWrite();
                } else {
                    return true;
                }
            }
        }
    }

    /**
     * Encrypts {@code plain} with {@code engine.wrap} into {@code writeBuffer} and flushes it.
     * Called with an empty buffer from {@link #handshake()} to emit handshake records.
     *
     * @return number of plain bytes consumed
     */
    private int secure_write(ByteBuffer plain) throws IOException {
        if (!transportFlush()) {
            // can't write anymore until the write_secured_buffer gets fully flushed out..
            return 0;
        }
        int rc = 0;
        // Loop while exactly one holds: plain has data, or the engine needs to wrap.
        // With an empty buffer this emits handshake records; with data and NEED_WRAP it stops,
        // leaving the handshake to be driven via flush()/handshake().
        while (plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_WRAP) {
            SSLEngineResult result = engine.wrap(plain, writeBuffer);
            assert result.getStatus() != BUFFER_OVERFLOW;
            rc += result.bytesConsumed();
            if (!transportFlush() || result.getStatus() == CLOSED) {
                break;
            }
        }
        if (plain.remaining() == 0 && engine.getHandshakeStatus() != NOT_HANDSHAKING) {
            dispatchQueue.execute(new Task() {
                public void run() {
                    handshake();
                }
            });
        }
        return rc;
    }

    /**
     * Decrypts network bytes into {@code plain}, first draining any overflow left by a
     * previous call. Called with an empty buffer from {@link #handshake()} to consume
     * handshake records.
     *
     * @return number of plain bytes produced, or -1 if the peer closed (socket EOF or SSL
     *         close) before this call produced any bytes
     */
    private int secure_read(ByteBuffer plain) throws IOException {
        int rc = 0;
        // Loop while exactly one holds: plain has room, or the engine needs to unwrap.
        while (plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP) {
            if (readOverflowBuffer != null) {
                if (plain.hasRemaining()) {
                    // lets drain the overflow buffer before trying to suck down anymore
                    // network bytes.
                    int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
                    plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
                    readOverflowBuffer.position(readOverflowBuffer.position() + size);
                    if (!readOverflowBuffer.hasRemaining()) {
                        readOverflowBuffer = null;
                    }
                    rc += size;
                } else {
                    return rc;
                }
            } else if (readUnderflow) {
                int count = super.getReadChannel().read(readBuffer);
                if (count == -1) { // peer closed socket.
                    if (rc == 0) {
                        return -1;
                    } else {
                        return rc;
                    }
                }
                if (count == 0) { // no data available right now.
                    return rc;
                }
                // read in some more data, perhaps now we can unwrap.
                readUnderflow = false;
                readBuffer.flip();
            } else {
                SSLEngineResult result = engine.unwrap(readBuffer, plain);
                rc += result.bytesProduced();
                if (result.getStatus() == BUFFER_OVERFLOW) {
                    // Caller's buffer too small: unwrap into a full-size side buffer instead.
                    readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
                    result = engine.unwrap(readBuffer, readOverflowBuffer);
                    if (readOverflowBuffer.position() == 0) {
                        readOverflowBuffer = null;
                    } else {
                        readOverflowBuffer.flip();
                    }
                }
                switch (result.getStatus()) {
                    case CLOSED:
                        if (rc == 0) {
                            engine.closeInbound();
                            return -1;
                        } else {
                            return rc;
                        }
                    case OK:
                        if (engine.getHandshakeStatus() != NOT_HANDSHAKING) {
                            dispatchQueue.execute(new Task() {
                                public void run() {
                                    handshake();
                                }
                            });
                        }
                        break;
                    case BUFFER_UNDERFLOW:
                        // Need more network bytes: switch readBuffer back to fill mode.
                        readBuffer.compact();
                        readUnderflow = true;
                        break;
                    case BUFFER_OVERFLOW:
                        throw new AssertionError("Unexpected case.");
                }
            }
        }
        return rc;
    }

    /**
     * Advances the SSL handshake one step based on {@code engine.getHandshakeStatus()}:
     * runs a delegated task on {@code blockingExecutor} (re-entering on
     * {@code dispatchQueue} afterwards), or wraps/unwraps handshake records. I/O failures
     * are reported via {@code onTransportFailure}. Once the engine reports
     * NOT_HANDSHAKING, outbound draining is signalled and inbound data is drained.
     */
    public void handshake() {
        try {
            if (!transportFlush()) {
                return;
            }
            switch (engine.getHandshakeStatus()) {
                case NEED_TASK:
                    final Runnable task = engine.getDelegatedTask();
                    if (task != null) {
                        blockingExecutor.execute(new Task() {
                            public void run() {
                                task.run();
                                dispatchQueue.execute(new Task() {
                                    public void run() {
                                        if (isConnected()) {
                                            handshake();
                                        }
                                    }
                                });
                            }
                        });
                    }
                    break;
                case NEED_WRAP:
                    secure_write(ByteBuffer.allocate(0));
                    break;
                case NEED_UNWRAP:
                    if (secure_read(ByteBuffer.allocate(0)) == -1) {
                        throw new EOFException("Peer disconnected during ssl handshake");
                    }
                    break;
                case FINISHED:
                case NOT_HANDSHAKING:
                    break;
                default:
                    System.err.println("Unexpected ssl engine handshake status: " + engine.getHandshakeStatus());
                    break;
            }
        } catch (IOException e) {
            onTransportFailure(e);
        } finally {
            if (engine.getHandshakeStatus() == NOT_HANDSHAKING) {
                drainOutboundSource.merge(1);
                // super.drainInbound(): our override would just call handshake() again.
                super.drainInbound();
            }
        }
    }

    /** @return the decrypting channel view of this transport */
    public ReadableByteChannel getReadChannel() {
        return ssl_channel;
    }

    /** @return the encrypting channel view of this transport */
    public WritableByteChannel getWriteChannel() {
        return ssl_channel;
    }

    /** @return the client-auth policy name: {@code "WANT"}, {@code "NEED"} or {@code "NONE"} */
    public String getClientAuth() {
        return clientAuth.name();
    }

    /**
     * Sets the client-auth policy used by server-mode engines.
     *
     * @param clientAuth {@code "want"}, {@code "need"} or {@code "none"} (case-insensitive)
     * @throws IllegalArgumentException if the value is not one of the above
     */
    public void setClientAuth(String clientAuth) {
        // Locale.ROOT: case mapping must not depend on the JVM default locale.
        this.clientAuth = ClientAuth.valueOf(clientAuth.toUpperCase(java.util.Locale.ROOT));
    }

    public String getDisabledCypherSuites() {
        return disabledCypherSuites;
    }

    public String getEnabledCypherSuites() {
        return enabledCipherSuites;
    }

    /** @param disabledCypherSuites comma-separated substrings of cipher suites to disable */
    public void setDisabledCypherSuites(String disabledCypherSuites) {
        this.disabledCypherSuites = disabledCypherSuites;
    }

    /** @param enabledCypherSuites comma-separated cipher suite names to enable */
    public void setEnabledCypherSuites(String enabledCypherSuites) {
        this.enabledCipherSuites = enabledCypherSuites;
    }
}