/* dCache - http://www.dcache.org/
 *
 * Copyright (C) 2014 Deutsches Elektronen-Synchrotron
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.dcache.gsi;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;

import java.nio.ByteBuffer;

/**
 * SSLEngine decorator that provides limited capability to inject additional
 * communication with the client hidden from the caller of the SSLEngine.
 *
 * <p>The class is tailored for implementing GSI credential delegation while
 * separating GSI logic from the low level SSLEngine wrapping and unwrapping
 * protocol.
 *
 * <p>The engine is in one of three states:
 * <ul>
 *   <li>{@code PASSTHROUGH} - all calls are forwarded to the delegate unchanged;</li>
 *   <li>{@code SEND} - {@link #wrap} encrypts the injected {@code out} buffer instead
 *       of the caller's data;</li>
 *   <li>{@code RECEIVE} - {@link #unwrap} decrypts into the internal {@code in} buffer
 *       and hands it to the registered {@link Callback}.</li>
 * </ul>
 * While in {@code SEND} or {@code RECEIVE} the injected exchange is presented to the
 * caller as handshake activity through the reported handshake status.
 */
public class InterceptingSSLEngine extends ForwardingSSLEngine {

    /** Zero-length source used to drive the delegate's wrap while waiting to receive. */
    private static final ByteBuffer EMPTY = ByteBuffer.allocate(0);

    private final SSLEngine delegate;

    private enum State {
        SEND, RECEIVE, PASSTHROUGH
    }

    private State state = State.PASSTHROUGH;

    /** Callback to invoke with the next received data; {@code null} if none is pending. */
    private Callback callback;

    /** Injected data still to be sent; only meaningful in the SEND state. */
    private ByteBuffer out;

    /** Buffer receiving injected inbound data; only meaningful in the RECEIVE state. */
    private ByteBuffer in;

    public InterceptingSSLEngine(SSLEngine delegate) {
        this.delegate = delegate;
    }

    @Override
    protected SSLEngine delegate() {
        return delegate;
    }

    /**
     * Sets up the engine to receive one SSL frame worth of data and call the callback when done.
     *
     * <p>The receive buffer is allocated with the session's application buffer size. It is
     * handed to the callback exactly as the delegate's unwrap left it (it is not flipped).
     *
     * @param callback invoked with the received data
     */
    public void receive(Callback callback) {
        this.state = State.RECEIVE;
        this.callback = callback;
        int size = getSession().getApplicationBufferSize();
        this.in = ByteBuffer.allocate(size);
    }

    /**
     * Sets the engine up to send the data in {@code out}.
     *
     * @param out the data to send; consumed by subsequent calls to {@link #wrap}
     */
    public void send(ByteBuffer out) {
        this.out = out;
        this.state = State.SEND;
    }

    /**
     * Sets the engine up to send the data in {@code out} and then receive one SSL frame worth of
     * data and call the callback.
     *
     * @param out the data to send first
     * @param callback invoked with the data received after {@code out} has been sent
     */
    public void sendThenReceive(ByteBuffer out, Callback callback) {
        // Order matters: receive() registers the callback and buffer, send() then
        // overrides the state so that sending happens first.
        receive(callback);
        send(out);
    }

    /**
     * Wraps either the injected data (SEND), nothing (RECEIVE) or the caller's data
     * (PASSTHROUGH).
     *
     * <p>In the SEND and RECEIVE states none of the caller's {@code srcs} are consumed; the
     * returned result therefore reports zero bytes consumed and the number of bytes the
     * delegate wrote to {@code dst}.
     */
    @Override
    public SSLEngineResult wrap(ByteBuffer[] srcs, int offset, int length, ByteBuffer dst)
            throws SSLException {
        SSLEngineResult result;
        switch (state) {
            case SEND:
                result = delegate().wrap(out, dst);
                if (result.getStatus() != SSLEngineResult.Status.OK) {
                    return result;
                }
                if (isDelegateHandshaking()) {
                    return result;
                }
                if (out.hasRemaining()) {
                    // More injected data to send; ask the caller to wrap again.
                    return ok(SSLEngineResult.HandshakeStatus.NEED_WRAP,
                            0, result.bytesProduced());
                }
                out = null;
                if (callback != null) {
                    // sendThenReceive: continue with the receive half.
                    state = State.RECEIVE;
                    return ok(SSLEngineResult.HandshakeStatus.NEED_UNWRAP,
                            0, result.bytesProduced());
                }
                state = State.PASSTHROUGH;
                // Nothing was taken from srcs, but the delegate did write to dst, so
                // report the produced bytes like the other branches do.
                return ok(SSLEngineResult.HandshakeStatus.FINISHED,
                        0, result.bytesProduced());

            case RECEIVE:
                result = delegate().wrap(EMPTY, dst);
                if (result.getStatus() != SSLEngineResult.Status.OK) {
                    return result;
                }
                if (isDelegateHandshaking()) {
                    return result;
                }
                return ok(SSLEngineResult.HandshakeStatus.NEED_UNWRAP,
                        0, result.bytesProduced());

            default:
                return delegate().wrap(srcs, offset, length, dst);
        }
    }

    /**
     * Unwraps into the caller's buffers (PASSTHROUGH), into the internal receive buffer
     * (RECEIVE), or - while sending - into the caller's buffers with the expectation that
     * no application data arrives.
     *
     * @throws SSLHandshakeException if application data arrives while sending, or if the
     *         received data does not fit the internal receive buffer; the delegate's
     *         outbound side is closed before the exception is thrown
     * @throws SSLException if the delegate or the callback fails
     */
    @Override
    public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts, int offset, int length)
            throws SSLException {
        SSLEngineResult result;
        switch (state) {
            case SEND:
                // Honour the caller's sub-range of dsts.
                result = delegate().unwrap(src, dsts, offset, length);
                if (result.getStatus() != SSLEngineResult.Status.OK
                        && result.getStatus() != SSLEngineResult.Status.BUFFER_UNDERFLOW) {
                    return result;
                }
                if (result.bytesProduced() != 0) {
                    delegate().closeOutbound();
                    throw new SSLHandshakeException("Received unexpected data from client during handshake.");
                }
                if (isDelegateHandshaking()) {
                    return result;
                }
                // We still have data to send; direct the caller back to wrap.
                return ok(SSLEngineResult.HandshakeStatus.NEED_WRAP,
                        result.bytesConsumed(), 0);

            case RECEIVE:
                result = delegate().unwrap(src, in);
                if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
                    delegate().closeOutbound();
                    throw new SSLHandshakeException("Received over sized data from client during handshake.");
                }
                if (result.getStatus() != SSLEngineResult.Status.OK) {
                    return result;
                }
                if (isDelegateHandshaking()) {
                    return result;
                }
                if (result.bytesProduced() == 0) {
                    // Record carried no application data; keep waiting.
                    return ok(SSLEngineResult.HandshakeStatus.NEED_UNWRAP,
                            result.bytesConsumed(), 0);
                }

                // Reset to PASSTHROUGH before invoking the callback so that the callback
                // may itself call send/receive to continue the injected exchange.
                Callback pendingCallback = this.callback;
                ByteBuffer received = this.in;
                this.callback = null;
                this.in = null;
                this.state = State.PASSTHROUGH;
                try {
                    pendingCallback.call(received);
                } catch (SSLException e) {
                    delegate().closeOutbound();
                    throw e;
                }

                // Report whatever the callback asked for next. The received bytes went
                // to the internal buffer, so none were produced for the caller.
                switch (state) {
                    case SEND:
                        return ok(SSLEngineResult.HandshakeStatus.NEED_WRAP,
                                result.bytesConsumed(), 0);
                    case RECEIVE:
                        return ok(SSLEngineResult.HandshakeStatus.NEED_UNWRAP,
                                result.bytesConsumed(), 0);
                    default:
                        return ok(SSLEngineResult.HandshakeStatus.FINISHED,
                                result.bytesConsumed(), 0);
                }

            default:
                return delegate().unwrap(src, dsts, offset, length);
        }
    }

    /**
     * Abandons any injected exchange and closes the delegate's inbound side.
     */
    @Override
    public void closeInbound() throws SSLException {
        reset();
        delegate().closeInbound();
    }

    /**
     * Abandons any injected exchange and closes the delegate's outbound side.
     */
    @Override
    public void closeOutbound() {
        reset();
        delegate().closeOutbound();
    }

    /**
     * Returns the delegate's handshake status while it is handshaking; otherwise reports
     * the injected exchange as handshake activity ({@code NEED_WRAP} while sending,
     * {@code NEED_UNWRAP} while receiving).
     */
    @Override
    public SSLEngineResult.HandshakeStatus getHandshakeStatus() {
        SSLEngineResult.HandshakeStatus handshakeStatus = delegate().getHandshakeStatus();
        if (handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
            return handshakeStatus;
        }
        switch (state) {
            case SEND:
                return SSLEngineResult.HandshakeStatus.NEED_WRAP;
            case RECEIVE:
                return SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
            default:
                return handshakeStatus;
        }
    }

    /** Returns to PASSTHROUGH and drops all references held for an injected exchange. */
    private void reset() {
        state = State.PASSTHROUGH;
        callback = null;
        out = null;
        in = null;
    }

    /** Tells whether the delegate itself is in the middle of a handshake. */
    private boolean isDelegateHandshaking() {
        return delegate().getHandshakeStatus()
                != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
    }

    /** Builds a result with status OK and the given handshake status and byte counts. */
    private static SSLEngineResult ok(SSLEngineResult.HandshakeStatus handshakeStatus,
            int bytesConsumed, int bytesProduced) {
        return new SSLEngineResult(SSLEngineResult.Status.OK, handshakeStatus,
                bytesConsumed, bytesProduced);
    }

    /**
     * Receiver of data obtained through {@link #receive} or {@link #sendThenReceive}.
     */
    public interface Callback {
        /**
         * Called with the buffer holding the received data.
         *
         * @param buffer the receive buffer as filled by the engine
         * @throws SSLException to abort; the engine closes the delegate's outbound side
         *         and propagates the exception to the caller of unwrap
         */
        void call(ByteBuffer buffer) throws SSLException;
    }
}