/* * JLibs: Common Utilities for Java * Copyright (C) 2009 Santhosh Kumar T <santhosh.tekuri@gmail.com> * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public * License as published by the Free Software Foundation; either * version 2.1 of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. */ package jlibs.nio; import jlibs.nio.http.expr.Bean; import jlibs.nio.http.expr.UnresolvedException; import jlibs.nio.http.expr.ValueMap; import jlibs.nio.util.Buffers; import jlibs.nio.util.NIOUtil; import javax.net.ssl.*; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.FileChannel; import java.security.cert.X509Certificate; import static java.nio.channels.SelectionKey.OP_READ; import static java.nio.channels.SelectionKey.OP_WRITE; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*; import static javax.net.ssl.SSLEngineResult.Status.*; import static jlibs.nio.Debugger.IO; import static jlibs.nio.Debugger.println; /** * @author Santhosh Kumar Tekuri */ public final class SSLSocket implements Transport, Bean{ private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); private final Input peerIn; private final Output peerOut; private final SSLEngine engine; private final int packetBufferSize; private ByteBuffer peerReadBuffer; private ByteBuffer peerWriteBuffer; private ByteBuffer appReadBuffers[]; private int appReadBuffersOffset; private final ByteBuffer buffersArray1[] = { EMPTY_BUFFER }; private final Buffers appWriteBuffers = new Buffers(buffersArray1, 0, 1); public SSLSocket(Input in, Output out, SSLEngine engine) throws IOException{ peerIn = in; transportIn = peerIn.channel().transport; transportIn.peekIn = this; peerOut = out; transportOut = peerOut.channel().transport; transportOut.peekOut = this; this.engine = engine; SSLSession session = engine.getSession(); packetBufferSize = session.getPacketBufferSize(); peerWriteBuffer = Reactor.current().allocator.allocate(2*packetBufferSize); peerWriteBuffer.position(peerWriteBuffer.limit()); ByteBuffer appReadBuffer = Reactor.current().allocator.allocate(session.getApplicationBufferSize()); appReadBuffer.position(appReadBuffer.limit()); appReadBuffers = new ByteBuffer[]{ null, appReadBuffer }; appReadBuffersOffset = 1; if(IO){ println("useClientMode: "+engine.getUseClientMode()+ " applicationBufferSize: "+session.getApplicationBufferSize()+ " packetBufferSize: "+session.getPacketBufferSize() + " handshakeStatus: "+engine.getHandshakeStatus()); } engine.beginHandshake(); selfInterests = engine.getHandshakeStatus()==NEED_UNWRAP ? OP_READ : OP_WRITE; } public SSLSession getSession(){ return engine.getSession(); } private int selfInterests; private long appWrote, appRead; private boolean unwrapUnderflow; @Trace(condition=IO, args="$1") private void run(SSLEngineResult.HandshakeStatus handshakeStatus) throws IOException{ assert handshakeStatus==engine.getHandshakeStatus() || engine.getHandshakeStatus()==NOT_HANDSHAKING; selfInterests = 0; appRead = appWrote = 0; while(!engine.isOutboundDone()){ switch(handshakeStatus){ case NEED_TASK: Runnable task; while((task=engine.getDelegatedTask())!=null) task.run(); handshakeStatus = engine.getHandshakeStatus(); break; case NEED_WRAP: if(peerWriteBuffer==null){ if(peerReadBuffer.hasRemaining()){ peerWriteBuffer = Reactor.current().allocator.allocate(2*packetBufferSize); peerWriteBuffer.position(peerWriteBuffer.limit()); }else{ peerWriteBuffer = peerReadBuffer; peerReadBuffer = null; } }else if(peerWriteBuffer.hasRemaining() && (peerWriteBuffer.capacity()-peerWriteBuffer.limit())<packetBufferSize){ do{ if(peerOut.write(peerWriteBuffer)==0){ if(engine.getHandshakeStatus()==NEED_UNWRAP) selfInterests |= OP_WRITE; return; } }while(peerWriteBuffer.hasRemaining()); } if(peerWriteBuffer.hasRemaining()){ peerWriteBuffer.position(peerWriteBuffer.limit()); peerWriteBuffer.limit(peerWriteBuffer.capacity()); }else peerWriteBuffer.clear(); try{ SSLEngineResult result = engine.wrap(appWriteBuffers.array, appWriteBuffers.offset, appWriteBuffers.length, peerWriteBuffer); if(IO) println(result); assert result.getStatus()!=BUFFER_UNDERFLOW; assert result.getStatus()==OK || (result.getStatus()==CLOSED && engine.isOutboundDone()); appWrote += result.bytesConsumed(); }finally{ peerWriteBuffer.flip(); } handshakeStatus = engine.getHandshakeStatus(); break; case NEED_UNWRAP: if(!writePendingToPeer()) return; if(peerReadBuffer==null){ peerReadBuffer = peerWriteBuffer; peerWriteBuffer = null; } while(true){ if(!peerReadBuffer.hasRemaining() || unwrapUnderflow){ NIOUtil.compact(peerReadBuffer); try{ int read = peerIn.read(peerReadBuffer); if(read==0){ if(engine.getHandshakeStatus()==NEED_UNWRAP) selfInterests |= OP_READ; return; }else if(read==-1){ try{ engine.closeInbound(); }catch(SSLException ignore){ // ignore.printStackTrace(); } break; }else unwrapUnderflow = false; }finally{ peerReadBuffer.flip(); } } ByteBuffer appReadBuffer = appReadBuffers[appReadBuffers.length-1]; if(appReadBuffer.hasRemaining()) return; appReadBuffer.clear(); try{ SSLEngineResult result = engine.unwrap(peerReadBuffer, appReadBuffers, appReadBuffersOffset, appReadBuffers.length-appReadBuffersOffset); if(IO) println(result); if(result.getStatus()==BUFFER_UNDERFLOW) unwrapUnderflow = true; else{ assert result.getStatus()!=BUFFER_OVERFLOW; assert result.getStatus()==OK || (result.getStatus()==CLOSED && engine.isInboundDone()); appRead += result.bytesProduced(); if(appRead>0){ if(isOpen()) return; else appReadBuffer.position(appReadBuffer.limit()); } break; } }finally{ appReadBuffer.flip(); } } handshakeStatus = engine.getHandshakeStatus(); break; case FINISHED: case NOT_HANDSHAKING: if(open){ if(appRead==0 && appReadBuffersOffset!=appReadBuffers.length-1) handshakeStatus = NEED_UNWRAP; else if(appWrote==0 && appWriteBuffers.peekLast()!=EMPTY_BUFFER) handshakeStatus = NEED_WRAP; else return; }else{ engine.closeOutbound(); handshakeStatus = engine.getHandshakeStatus(); } } if(IO) println("handshakeStatus = "+handshakeStatus); } } private boolean writePendingToPeer() throws IOException{ if(peerWriteBuffer!=null && peerWriteBuffer.hasRemaining()){ do{ if(peerOut.write(peerWriteBuffer)==0){ selfInterests |= OP_WRITE; return false; } }while(peerWriteBuffer.hasRemaining()); } return true; } /*-------------------------------------------------[ App Read ]---------------------------------------------------*/ @Override public void addReadInterest(){ if(transportIn.peekIn==this) transportIn.peekInInterested = true; if(appReadBuffers[appReadBuffers.length-1].hasRemaining() || engine.isInboundDone() || (peerReadBuffer!=null && peerReadBuffer.hasRemaining() && !unwrapUnderflow)) transportIn.wakeupReader(); else{ if(selfInterests==0) peerIn.addReadInterest(); else{ if((selfInterests&OP_READ)!=0) peerIn.addReadInterest(); if((selfInterests&OP_WRITE)!=0) peerOut.addWriteInterest(); } } } @Override public int read(ByteBuffer dst) throws IOException{ if(!isOpen()) throw new ClosedChannelException(); if(selfInterests!=0){ run(engine.getHandshakeStatus()); if(selfInterests!=0) return 0; } ByteBuffer appReadBuffer = appReadBuffers[appReadBuffers.length-1]; if(appReadBuffer.hasRemaining()) return NIOUtil.copy(appReadBuffer, dst); if(engine.isInboundDone()){ eof = true; return -1; } if(!dst.hasRemaining()) return 0; appReadBuffers[--appReadBuffersOffset] = dst; try{ run(NEED_UNWRAP); if(appRead==0 && engine.isInboundDone()){ eof = true; return -1; }else return (int)appRead; }finally{ appReadBuffers[appReadBuffersOffset++] = null; } } @Override public long read(ByteBuffer[] dsts) throws IOException{ return read(dsts, 0, dsts.length); } @Override public long read(ByteBuffer[] dsts, int offset, int length) throws IOException{ if(!isOpen()) throw new ClosedChannelException(); if(selfInterests!=0){ run(engine.getHandshakeStatus()); if(selfInterests!=0) return 0; } ByteBuffer appReadBuffer = appReadBuffers[appReadBuffers.length-1]; if(appReadBuffer.hasRemaining()) return NIOUtil.copy(appReadBuffer, dsts, offset, length); if(engine.isInboundDone()){ eof = true; return -1; } while(length>0){ if(dsts[offset].hasRemaining()) break; ++offset; --length; } if(length==0) return 0; if(appReadBuffers.length<length+1){ appReadBuffers = new ByteBuffer[length+1]; appReadBuffersOffset = length; appReadBuffers[length] = appReadBuffer; } appReadBuffersOffset -= length; System.arraycopy(dsts, offset, appReadBuffers, appReadBuffersOffset, length); try{ run(NEED_UNWRAP); if(appRead==0 && engine.isInboundDone()){ eof = true; return -1; }else return appRead; }finally{ for(int i=0; i<length; i++) appReadBuffers[appReadBuffersOffset++] = null; } } private boolean eof; @Override public boolean eof(){ return eof; } @Override public long available(){ return appReadBuffers[appReadBuffers.length-1].remaining(); } @Override public long transferTo(long position, long count, FileChannel target) throws IOException{ ByteBuffer appReadBuffer = appReadBuffers[appReadBuffers.length-1]; if(appReadBuffer.hasRemaining()) return NIOUtil.transfer(appReadBuffer, target, position, count); return target.transferFrom(this, position, count); } /*-------------------------------------------------[ App Write ]---------------------------------------------------*/ @Override public void addWriteInterest(){ if(transportOut.peekOut==this) transportOut.peekOutInterested = true; if(engine.isOutboundDone()) transportOut.wakeupWriter(); else{ if(selfInterests==0) peerOut.addWriteInterest(); else{ if((selfInterests&OP_READ)!=0) peerIn.addReadInterest(); if((selfInterests&OP_WRITE)!=0) peerOut.addWriteInterest(); } } } @Override public int write(ByteBuffer src) throws IOException{ if(!isOpen()) throw new ClosedChannelException(); if(selfInterests!=0){ run(engine.getHandshakeStatus()); if(selfInterests!=0) return 0; } if(engine.isOutboundDone()) throw new IOException("outboundDone"); if(!src.hasRemaining()) return 0; appWriteBuffers.array[0] = src; try{ run(NEED_WRAP); return (int)appWrote; }finally{ appWriteBuffers.array[0] = EMPTY_BUFFER; } } @Override public long write(ByteBuffer[] srcs) throws IOException{ return write(srcs, 0, srcs.length); } @Override public long write(ByteBuffer[] srcs, int offset, int length) throws IOException{ if(!isOpen()) throw new ClosedChannelException(); if(selfInterests!=0){ run(engine.getHandshakeStatus()); if(selfInterests!=0) return 0; } if(engine.isOutboundDone()) throw new IOException("outboundDone"); while(length>0){ if(srcs[offset].hasRemaining()) break; ++offset; --length; } if(length==0) return 0; appWriteBuffers.array = srcs; appWriteBuffers.offset = offset; appWriteBuffers.length = length; try{ run(NEED_WRAP); return (int)appWrote; }finally{ appWriteBuffers.array = buffersArray1; appWriteBuffers.offset = 0; appWriteBuffers.length = 1; } } @Override public boolean flush() throws IOException{ if(peerIn.isOpen() && peerOut.isOpen()){ if(selfInterests!=0){ run(engine.getHandshakeStatus()); if(selfInterests!=0) return false; } if(!writePendingToPeer()) return false; if(!open){ if(peerReadBuffer!=null){ Reactor.current().allocator.free(peerReadBuffer); peerReadBuffer = null; } if(peerWriteBuffer!=null){ Reactor.current().allocator.free(peerWriteBuffer); peerWriteBuffer = null; } Reactor.current().allocator.free(appReadBuffers[appReadBuffers.length-1]); appReadBuffers[appReadBuffers.length-1] = null; try{ peerIn.close(); }finally{ peerOut.close(); } } } return peerOut.flush(); } @Override public long transferFrom(FileChannel src, long position, long count) throws IOException{ return src.transferTo(position, count, this); } /*-------------------------------------------------[ Close ]---------------------------------------------------*/ private boolean open = true; @Override public boolean isOpen(){ return open; } @Override public void close() throws IOException{ if(open){ open = false; ByteBuffer appReadBuffer = appReadBuffers[appReadBuffers.length-1]; if(appReadBuffer.hasRemaining()) appReadBuffer.position(appReadBuffer.limit()); run(engine.getHandshakeStatus()); } } /*-------------------------------------------------[ Transport-Misc ]---------------------------------------------------*/ private final Socket transportIn, transportOut; @Override public Input.Listener getInputListener(){ return transportIn.getInputListener(); } @Override public void setInputListener(Input.Listener listener){ transportIn.setInputListener(listener); } @Override public Output.Listener getOutputListener(){ return transportOut.getOutputListener(); } @Override public void setOutputListener(Output.Listener listener){ transportOut.setOutputListener(listener); } @Override public void wakeupReader(){ transportIn.wakeupReader(); } @Override public void wakeupWriter(){ transportOut.wakeupWriter(); } @Override public Input detachInput(){ return this; } @Override public Output detachOutput(){ return this; } @Override public NBStream channel(){ return transportIn.channel(); //todo } @Override public String toString(){ return "SSLSocket"; } /*-------------------------------------------------[ Bean ]---------------------------------------------------*/ @Override @SuppressWarnings("StringEquality") public Object getField(String name) throws UnresolvedException{ if(name=="protocol") return getSession().getProtocol(); else if(name=="cipher") return getSession().getCipherSuite(); else if(name=="local") return new CertificateBean((X509Certificate)getSession().getLocalCertificates()[0]); else if(name=="peer"){ try{ return new CertificateBean((X509Certificate)getSession().getPeerCertificates()[0]); }catch(SSLPeerUnverifiedException ex){ return null; } }else throw new UnresolvedException(name); } public static class CertificateBean implements Bean{ public final X509Certificate cert; public CertificateBean(X509Certificate cert){ this.cert = cert; } @Override @SuppressWarnings("StringEquality") public Object getField(String name) throws UnresolvedException{ if(name=="sdn") return new DistinguishedName(cert.getSubjectX500Principal().getName()); else if(name=="idn") return new DistinguishedName(cert.getIssuerX500Principal().getName()); return null; } @Override public String toString(){ return cert.toString(); } } public static class DistinguishedName implements ValueMap{ public final String dn; public DistinguishedName(String dn){ this.dn = dn; } @Override public Object getValue(String name){ for(String attr: dn.split(",")){ String str[] = attr.split("="); if(str[0].equals(name)) return str[1]; } return null; } @Override public String toString(){ return dn; } } }