/***************************************************************************
* Copyright (C) 2010 by Fabrizio Montesi <famontesi@gmail.com> *
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU Library General Public License as *
* published by the Free Software Foundation; either version 2 of the *
* License, or (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU Library General Public *
* License along with this program; if not, write to the *
* Free Software Foundation, Inc., *
* 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
* *
* For details about the authors of this software, see the AUTHORS file. *
***************************************************************************/
package jolie.net.ssl;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.nio.ByteBuffer;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
import jolie.net.CommMessage;
import jolie.net.protocols.CommProtocol;
import jolie.net.protocols.SequentialCommProtocol;
import jolie.runtime.Value;
import jolie.runtime.VariablePath;
/**
 * Convenience class that supports the implementation
 * of SSL-based protocols by wrapping an existing protocol.
 * @author Fabrizio Montesi
 * 2010: complete rewrite
 */
public class SSLProtocol extends SequentialCommProtocol
{
	private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocate( 0 );
	private static final int INITIAL_BUFFER_SIZE = 32768;
	// Clear bytes cached by SSLOutputStream before they are encrypted
	// (16384 is the maximum plaintext length of a single SSL/TLS record).
	private static final int MAX_SSL_CONTENT_SIZE = 16384;

	private final boolean isClient;
	private boolean firstTime;
	// The protocol whose send/recv are executed on top of this SSL channel.
	private final CommProtocol wrappedProtocol;
	private SSLEngine sslEngine = null;
	// Raw (encrypted) streams of the current send/recv invocation.
	private OutputStream outputStream;
	private InputStream inputStream;

	/*
	 * Decrypted data received from the other party.
	 * Between calls it is kept in "drain" mode: the bytes in
	 * [position, limit) have not been handed to the wrapped protocol yet.
	 */
	private ByteBuffer clearInputBuffer = ByteBuffer.allocate( INITIAL_BUFFER_SIZE );

	/**
	 * Input stream handed to the wrapped protocol: each byte is taken
	 * from the decrypted data received on this SSL connection.
	 */
	private class SSLInputStream extends InputStream
	{
		@Override
		public int read()
			throws IOException
		{
			return SSLProtocol.this.read();
		}
	}

	/**
	 * Output stream handed to the wrapped protocol: clear bytes are cached
	 * and encrypted in chunks of at most MAX_SSL_CONTENT_SIZE bytes.
	 */
	private class SSLOutputStream extends OutputStream
	{
		private final ByteArrayOutputStream internalStreamBuffer = new ByteArrayOutputStream();

		@Override
		public void write( int b )
			throws IOException
		{
			internalStreamBuffer.write( b );
			if ( internalStreamBuffer.size() >= MAX_SSL_CONTENT_SIZE ) {
				writeCache();
			}
		}

		/**
		 * Encrypts the cached bytes and writes them to the raw output stream,
		 * without flushing it.
		 */
		public void writeCache()
			throws IOException
		{
			SSLProtocol.this.write( ByteBuffer.wrap( internalStreamBuffer.toByteArray() ) );
			internalStreamBuffer.reset();
		}

		@Override
		public void flush()
			throws IOException
		{
			writeCache();
			SSLProtocol.this.flushOutputStream();
		}
	}

	/**
	 * Encrypted output of a wrap operation together with the engine's report.
	 */
	private static class SSLResult
	{
		private ByteBuffer buffer;
		private SSLEngineResult log = null;

		public SSLResult( int capacity )
		{
			buffer = ByteBuffer.allocate( capacity );
		}

		/**
		 * Replaces the buffer with an empty one INITIAL_BUFFER_SIZE bytes larger
		 * (the previous content is discarded).
		 */
		public void enlargeBuffer()
		{
			buffer = ByteBuffer.allocate( buffer.capacity() + INITIAL_BUFFER_SIZE );
		}
	}

	public String name()
	{
		return wrappedProtocol.name() + "s";
	}

	/**
	 * @param configurationPath path of the protocol configuration (holds the "ssl" parameters)
	 * @param uri the location this protocol is used for (not used directly here)
	 * @param wrappedProtocol the protocol to run on top of SSL
	 * @param isClient true to act as the SSL client, false to act as the server
	 */
	public SSLProtocol(
		VariablePath configurationPath,
		URI uri,
		CommProtocol wrappedProtocol,
		boolean isClient
	) {
		super( configurationPath );
		this.wrappedProtocol = wrappedProtocol;
		this.isClient = isClient;
		firstTime = true;
		// Start in drain mode with nothing to read.
		clearInputBuffer.limit( 0 );
	}

	/**
	 * Encrypts (part of) the source buffer with the SSL engine, enlarging
	 * the destination until the engine no longer reports BUFFER_OVERFLOW.
	 *
	 * @return the result, whose buffer is flipped and ready to be read
	 * @throws IOException if the engine reports that the connection is closed
	 */
	private SSLResult wrap( ByteBuffer source )
		throws IOException
	{
		SSLResult result = new SSLResult( source.capacity() );
		result.log = sslEngine.wrap( source, result.buffer );
		while ( result.log.getStatus() == Status.BUFFER_OVERFLOW ) {
			result.enlargeBuffer();
			result.log = sslEngine.wrap( source, result.buffer );
		}
		if ( result.log.getStatus() == Status.CLOSED ) {
			throw new IOException( "Remote party closed SSL connection" );
		}
		result.buffer.flip();
		return result;
	}

	/**
	 * Returns the string value of ssl.parameterName, or defaultValue if it is not set.
	 */
	private String getSSLStringParameter( String parameterName, String defaultValue )
	{
		if ( hasParameter( "ssl" ) ) {
			Value sslParams = getParameterFirstValue( "ssl" );
			if ( sslParams.hasChildren( parameterName ) ) {
				return sslParams.getFirstChild( parameterName ).strValue();
			}
		}
		return defaultValue;
	}

	/**
	 * Returns the integer value of ssl.parameterName, or defaultValue if it is not set.
	 */
	private int getSSLIntegerParameter( String parameterName, int defaultValue )
	{
		if ( hasParameter( "ssl" ) ) {
			Value sslParams = getParameterFirstValue( "ssl" );
			if ( sslParams.hasChildren( parameterName ) ) {
				return sslParams.getFirstChild( parameterName ).intValue();
			}
		}
		return defaultValue;
	}

	/**
	 * Loads a key store from the given file, closing the file afterwards.
	 * When fileName is null an empty key store is initialised instead.
	 */
	private static void loadKeyStore( KeyStore store, String fileName, char[] password )
		throws IOException, NoSuchAlgorithmException, CertificateException
	{
		if ( fileName == null ) {
			store.load( null, null );
			return;
		}
		InputStream stream = new FileInputStream( fileName );
		try {
			store.load( stream, password );
		} finally {
			stream.close();
		}
	}

	/**
	 * Creates and configures the SSL engine from the "ssl" parameters:
	 * protocol, keyStoreFormat, trustStoreFormat, keyStore, keyStorePassword,
	 * trustStore, trustStorePassword and (server side only) wantClientAuth.
	 *
	 * @throws IOException if a compulsory parameter is missing or the
	 *         security configuration cannot be loaded
	 */
	private void init()
		throws IOException
	{
		// Set default parameters.
		// SSLv3 (the former default) is insecure and disabled in current JDKs.
		String protocol = getSSLStringParameter( "protocol", "TLSv1.2" ),
			keyStoreFormat = getSSLStringParameter( "keyStoreFormat", "JKS" ),
			trustStoreFormat = getSSLStringParameter( "trustStoreFormat", "JKS" ),
			keyStoreFile = getSSLStringParameter( "keyStore", null ),
			keyStorePassword = getSSLStringParameter( "keyStorePassword", null ),
			trustStoreFile = getSSLStringParameter( "trustStore", System.getProperty( "java.home" ) + "/lib/security/cacerts" ),
			trustStorePassword = getSSLStringParameter( "trustStorePassword", null );
		if ( keyStoreFile == null && !isClient ) {
			throw new IOException( "Compulsory parameter needed for server mode: ssl.keyStore" );
		}
		try {
			SSLContext context = SSLContext.getInstance( protocol );
			KeyStore ks = KeyStore.getInstance( keyStoreFormat );
			KeyStore ts = KeyStore.getInstance( trustStoreFormat );

			char[] keyPassphrase = ( keyStorePassword == null ) ? null : keyStorePassword.toCharArray();
			loadKeyStore( ks, keyStoreFile, keyPassphrase );
			KeyManagerFactory kmf = KeyManagerFactory.getInstance( "SunX509" );
			kmf.init( ks, keyPassphrase );

			char[] trustPassphrase = ( trustStorePassword == null ) ? null : trustStorePassword.toCharArray();
			loadKeyStore( ts, trustStoreFile, trustPassphrase );
			TrustManagerFactory tmf = TrustManagerFactory.getInstance( "SunX509" );
			tmf.init( ts );

			context.init( kmf.getKeyManagers(), tmf.getTrustManagers(), null );
			sslEngine = context.createSSLEngine();
			sslEngine.setEnabledProtocols( new String[] { protocol } );
			sslEngine.setUseClientMode( isClient );
			if ( !isClient ) {
				sslEngine.setWantClientAuth( getSSLIntegerParameter( "wantClientAuth", 1 ) > 0 );
			}
		} catch ( NoSuchAlgorithmException e ) {
			throw new IOException( e );
		} catch ( KeyManagementException e ) {
			throw new IOException( e );
		} catch ( KeyStoreException e ) {
			throw new IOException( e );
		} catch ( UnrecoverableKeyException e ) {
			throw new IOException( e );
		} catch ( CertificateException e ) {
			throw new IOException( e );
		}
	}

	/**
	 * Drives the SSL handshake until the engine no longer needs it.
	 * On the first call the engine is created and the handshake is started.
	 */
	private void handshake()
		throws IOException, SSLException
	{
		if ( firstTime ) {
			init();
			sslEngine.beginHandshake();
			firstTime = false;
		}
		SSLResult result;
		Runnable runnable;
		while ( sslEngine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING && sslEngine.getHandshakeStatus() != HandshakeStatus.FINISHED ) {
			switch ( sslEngine.getHandshakeStatus() ) {
			case NEED_TASK:
				while ( (runnable = sslEngine.getDelegatedTask()) != null ) {
					runnable.run();
				}
				break;
			case NEED_WRAP:
				result = wrap( EMPTY_BYTE_BUFFER );
				if ( result.log.bytesProduced() > 0 ) { // need to send result to other side
					outputStream.write( result.buffer.array(), 0, result.buffer.limit() );
					outputStream.flush();
				}
				break;
			case NEED_UNWRAP:
				unwrapFromInputStream( true );
				break;
			default:
				break;
			}
		}
	}

	/**
	 * Sends a message by running the wrapped protocol over encrypting streams.
	 */
	public void send( OutputStream ostream, CommMessage message, InputStream istream )
		throws IOException
	{
		outputStream = ostream;
		inputStream = istream;
		if ( firstTime ) {
			wrappedProtocol.setChannel( this.channel() );
		}
		SSLOutputStream sslOutputStream = new SSLOutputStream();
		InputStream sslInputStream = new SSLInputStream();
		wrappedProtocol.send( sslOutputStream, message, sslInputStream );
		sslOutputStream.writeCache();
	}

	/**
	 * Returns the next decrypted byte as an unsigned value in [0, 255],
	 * as required by InputStream.read().
	 */
	private int read()
		throws IOException
	{
		handshakeIfNeeded();
		while ( !clearInputBuffer.hasRemaining() ) {
			// A record may carry no application data (e.g. empty records or
			// renegotiation messages): keep unwrapping until clear bytes arrive.
			unwrapFromInputStream( false );
			handshakeIfNeeded();
		}
		// Mask to avoid returning negative values, which callers would take as EOF.
		return clearInputBuffer.get() & 0xFF;
	}

	/**
	 * Performs (or continues) the handshake if the engine has not been created
	 * yet or reports that it is still handshaking.
	 */
	private void handshakeIfNeeded()
		throws IOException
	{
		if ( sslEngine == null ||
			( sslEngine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING &&
			sslEngine.getHandshakeStatus() != HandshakeStatus.FINISHED )
		) {
			handshake();
		}
	}

	/**
	 * Replaces clearInputBuffer with one INITIAL_BUFFER_SIZE bytes larger.
	 * Must be called in fill mode (pending bytes in [0, position));
	 * the new buffer is left in fill mode with the same pending bytes.
	 */
	private void enlargeClearInputBuffer()
	{
		ByteBuffer tmp = ByteBuffer.allocate( clearInputBuffer.capacity() + INITIAL_BUFFER_SIZE );
		clearInputBuffer.flip();
		tmp.put( clearInputBuffer );
		clearInputBuffer = tmp;
	}

	/**
	 * Reads one byte of encrypted data from the raw input stream.
	 *
	 * @throws EOFException if the raw stream ends in the middle of an SSL record
	 */
	private int readEncryptedByte()
		throws IOException
	{
		int b = inputStream.read();
		if ( b < 0 ) {
			throw new EOFException( "Other party closed the connection before completing an SSL record" );
		}
		return b;
	}

	/**
	 * Reads encrypted bytes from the raw input stream, one at a time, until the
	 * engine can unwrap them, and appends the resulting clear data to clearInputBuffer.
	 *
	 * @param forHandshake true if called while handshaking: stop after the first
	 *        successful unwrap
	 * @throws IOException if the other party closed the SSL connection or the raw stream ended
	 */
	private void unwrapFromInputStream( boolean forHandshake )
		throws IOException
	{
		// Switch to fill mode. Compacting moves the unread bytes to the front, so
		// the space already consumed by read() is reused instead of growing the buffer.
		clearInputBuffer.compact();
		boolean closed = false;
		try {
			ByteArrayOutputStream cryptBytes = new ByteArrayOutputStream();
			cryptBytes.write( readEncryptedByte() );
			boolean keepRun = true;
			while ( keepRun ) {
				ByteBuffer cryptBuffer = ByteBuffer.wrap( cryptBytes.toByteArray() );
				SSLEngineResult result = sslEngine.unwrap( cryptBuffer, clearInputBuffer );
				switch ( result.getStatus() ) {
				case BUFFER_OVERFLOW:
					// Nothing was consumed: make room and retry with the same bytes.
					enlargeClearInputBuffer();
					break;
				case BUFFER_UNDERFLOW:
					// Incomplete record: read a single byte more, since reading
					// further could block waiting for data that is not coming.
					cryptBytes.write( readEncryptedByte() );
					break;
				case CLOSED:
					keepRun = false;
					closed = true;
					break;
				case OK:
					if ( forHandshake || !cryptBuffer.hasRemaining() ) {
						keepRun = false;
					} else {
						// Keep only the bytes the engine has not consumed yet.
						cryptBytes = new ByteArrayOutputStream();
						cryptBytes.write( cryptBuffer.array(), cryptBuffer.position(), cryptBuffer.remaining() );
					}
					break;
				default:
					break;
				}
			}
		} finally {
			// Back to drain mode: [0, limit) holds the unread clear bytes.
			clearInputBuffer.flip();
		}
		if ( closed ) {
			throw new IOException( "Other party closed the SSL connection" );
		}
	}

	/**
	 * Encrypts all remaining bytes of b and writes them to the raw output
	 * stream, without flushing it.
	 */
	private void write( ByteBuffer b )
		throws IOException
	{
		do {
			handshakeIfNeeded();
			SSLResult wrapResult = wrap( b );
			if ( wrapResult.log.bytesProduced() > 0 ) {
				outputStream.write( wrapResult.buffer.array(), 0, wrapResult.buffer.limit() );
			} else if ( wrapResult.log.bytesConsumed() == 0 && b.hasRemaining() ) {
				// Guard against looping forever if the engine makes no progress.
				throw new IOException( "The SSL engine could not encrypt the outgoing data" );
			}
		} while ( b.hasRemaining() );
	}

	private void flushOutputStream()
		throws IOException
	{
		outputStream.flush();
	}

	/**
	 * Receives a message by running the wrapped protocol over decrypting streams.
	 */
	public CommMessage recv( InputStream istream, OutputStream ostream )
		throws IOException
	{
		outputStream = ostream;
		inputStream = istream;
		if ( firstTime ) {
			wrappedProtocol.setChannel( this.channel() );
		}
		return wrappedProtocol.recv( new SSLInputStream(), new SSLOutputStream() );
	}
}