/* dCache - http://www.dcache.org/
 *
 * Copyright (C) 2014-2015 Deutsches Elektronen-Synchrotron
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package org.dcache.gsi;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import static org.dcache.util.ByteUnit.MiB;

/**
 * SSLEngine decorator that implements legacy GSI framing.
 *
 * <p>In the framed format, every token exchanged with the peer is preceded by
 * a 4-byte big-endian length prefix.
 *
 * <p>The class auto-detects whether the client is using the framing format and
 * responds in kind. The first inbound bytes are inspected: if they look like an
 * SSLv3/TLS record header or an SSLv2 hello, calls go straight to the wrapped
 * {@link ServerGsiEngine}; otherwise length-prefixed framing is applied in both
 * directions.
 */
public class GsiFrameEngine extends ForwardingSSLEngine {
    /** Initial (empty) content of a frame engine's reassembly buffer. */
    private static final ByteBuffer EMPTY = ByteBuffer.allocate(0);

    /** Largest accepted frame length; larger values are rejected as malformed. */
    private static final int MAX_LEN = MiB.toBytes(32);

    /** Size in bytes of the frame length prefix. */
    private static final int HEADER_LENGTH = 4;

    private final ServerGsiEngine gsiEngine;

    /** Engine that calls are forwarded to; replaced once the framing mode is detected. */
    private SSLEngine currentDelegate;

    public GsiFrameEngine(ServerGsiEngine delegate) {
        gsiEngine = delegate;
        currentDelegate = new FrameDetectingEngine();
    }

    /**
     * Determines if a given header is an SSLv3 packet (has an SSL header) or a
     * backward compatible version of TLS using the same header format.
     *
     * <p>The first byte must be a record content type in the range 20..26.
     * The second byte must be major version 3, or the version bytes must be 2.0.
     *
     * @param header at least the first three bytes of the inbound data
     * @return true if the header is an SSLv3 header. False, otherwise.
     */
    private static boolean isSSLv3Packet(byte[] header) {
        return header[0] >= 20 && header[0] <= 26
                && (header[1] == 3 || (header[1] == 2 && header[2] == 0));
    }

    /**
     * Determines if a given header is an SSLv2 client or server hello packet.
     *
     * <p>The high bit of the first byte must be set, and the message type
     * (third byte) must be 1 (client hello) or 4 (server hello).
     *
     * @param header at least the first three bytes of the inbound data
     * @return true if the header is such an SSLv2 client or server hello
     *         packet. False, otherwise.
     */
    private static boolean isSSLv2HelloPacket(byte[] header) {
        return ((header[0] & 0x80) != 0 && (header[2] == 1 || header[2] == 4));
    }

    /**
     * Reads the big-endian int at the current position of {@code buffer}.
     * The buffer's position, mark and byte order are left untouched.
     */
    private static int peekIntBigEndian(ByteBuffer buffer) {
        // A duplicate has independent position/mark/order but shares the content.
        return buffer.duplicate().order(ByteOrder.BIG_ENDIAN).getInt();
    }

    /**
     * Writes {@code value} as a big-endian int at the current position of
     * {@code buffer}, advancing it by four bytes.
     * The buffer's byte order is restored afterwards.
     */
    private static void putIntBigEndian(ByteBuffer buffer, int value) {
        ByteOrder originalOrder = buffer.order();
        try {
            buffer.order(ByteOrder.BIG_ENDIAN).putInt(value);
        } finally {
            buffer.order(originalOrder);
        }
    }

    @Override
    protected SSLEngine delegate() {
        return currentDelegate;
    }

    /**
     * Initial delegate.
     *
     * Peeks at the first inbound bytes to decide whether the peer uses GSI
     * framing. It then installs the appropriate engine as
     * {@code currentDelegate} and hands the unwrap over to it.
     */
    private class FrameDetectingEngine extends ForwardingSSLEngine {
        private final byte[] header = new byte[HEADER_LENGTH];

        @Override
        protected SSLEngine delegate() {
            return gsiEngine;
        }

        @Override
        public SSLEngineResult wrap(ByteBuffer[] srcs, int offset, int length, ByteBuffer dst)
                throws SSLException {
            throw new SSLException("Cannot wrap during frame detecting phase.");
        }

        @Override
        public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts, int offset, int length)
                throws SSLException {
            if (src.remaining() < HEADER_LENGTH) {
                return new SSLEngineResult(SSLEngineResult.Status.BUFFER_UNDERFLOW,
                        getHandshakeStatus(), 0, 0);
            }

            // Peek without consuming from src and without clobbering the caller's mark.
            src.duplicate().get(header);

            if (isSSLv3Packet(header) || isSSLv2HelloPacket(header)) {
                currentDelegate = gsiEngine;
            } else {
                currentDelegate = new FrameEngine();
            }
            return currentDelegate.unwrap(src, dsts, offset, length);
        }
    }

    /**
     * Delegate used when the peer speaks the framed protocol.
     *
     * Outbound data from the GSI engine is prefixed with its length. Inbound
     * frames are stripped of their prefix and accumulated in {@code buffer}
     * before being fed to the GSI engine.
     */
    private class FrameEngine extends ForwardingSSLEngine {
        /** De-framed inbound bytes not yet consumed by the GSI engine. */
        private ByteBuffer buffer = EMPTY;

        private final SSLSession session = new Session();

        @Override
        protected SSLEngine delegate() {
            return gsiEngine;
        }

        /**
         * Wraps application data and writes it to {@code dst} as a single frame:
         * a 4-byte big-endian length followed by the bytes produced by the GSI engine.
         */
        @Override
        public SSLEngineResult wrap(ByteBuffer[] srcs, int offset, int length, ByteBuffer dst)
                throws SSLException {
            if (dst.remaining() < HEADER_LENGTH) {
                // Not even room for the length prefix.
                return new SSLEngineResult(SSLEngineResult.Status.BUFFER_OVERFLOW,
                        getHandshakeStatus(), 0, 0);
            }

            // Reserve room for the length prefix in dst.
            ByteBuffer tmp = ByteBuffer.allocate(dst.remaining() - HEADER_LENGTH);
            SSLEngineResult result = delegate().wrap(srcs, offset, length, tmp);
            if (result.bytesProduced() == 0) {
                return result;
            }

            // Switch tmp to reading so that exactly the produced bytes are copied.
            tmp.flip();
            putIntBigEndian(dst, result.bytesProduced());
            dst.put(tmp);
            return new SSLEngineResult(result.getStatus(), result.getHandshakeStatus(),
                    result.bytesConsumed(), HEADER_LENGTH + result.bytesProduced());
        }

        /**
         * Extracts at most one complete frame from {@code src} into the internal buffer.
         * It then repeatedly unwraps that buffer with the GSI engine.
         *
         * The loop stops when the engine reports a non-OK status or returns OK
         * without consuming or producing anything. The latter presumably means
         * it is waiting for a wrap or a delegated task, and calling it again
         * would spin forever.
         */
        @Override
        public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts, int offset, int length)
                throws SSLException {
            int bytesConsumed = read(src);
            int bytesProduced = 0;
            SSLEngineResult result;
            do {
                result = delegate().unwrap(buffer, dsts, offset, length);
                bytesProduced += result.bytesProduced();
            } while (result.getStatus() == SSLEngineResult.Status.OK
                    && (result.bytesConsumed() > 0 || result.bytesProduced() > 0));
            return new SSLEngineResult(result.getStatus(), result.getHandshakeStatus(),
                    bytesConsumed, bytesProduced);
        }

        /**
         * Moves one complete frame (length prefix plus payload) from {@code src}
         * into the internal buffer, appending to any bytes already buffered.
         *
         * @return number of bytes consumed from {@code src}: the frame length plus
         *         the 4-byte prefix, or 0 if {@code src} does not yet hold a complete frame
         * @throws SSLException if the frame length is negative or exceeds {@code MAX_LEN}
         */
        private int read(ByteBuffer src) throws SSLException {
            if (src.remaining() < HEADER_LENGTH) {
                return 0;
            }

            int len = peekIntBigEndian(src);
            if (len > MAX_LEN) {
                closeOutbound();
                throw new SSLException("Token length " + len + " > " + MAX_LEN);
            } else if (len < 0) {
                closeOutbound();
                throw new SSLException("Token length " + len + " < 0");
            }

            if (src.remaining() - HEADER_LENGTH < len) {
                // Incomplete frame: leave src untouched until more data arrives.
                return 0;
            }

            src.position(src.position() + HEADER_LENGTH);
            int existingBytes = buffer.remaining();
            byte[] newBuffer = new byte[existingBytes + len];
            buffer.get(newBuffer, 0, existingBytes);
            // Take only this frame's payload; any following frame stays in src.
            src.get(newBuffer, existingBytes, len);
            buffer = ByteBuffer.wrap(newBuffer);
            return HEADER_LENGTH + len;
        }

        @Override
        public SSLSession getSession() {
            return session;
        }

        /**
         * Session view that accounts for the frame length prefix in the packet size.
         */
        private class Session extends ForwardingSSLSession {
            @Override
            protected SSLSession delegate() {
                return FrameEngine.super.getSession();
            }

            @Override
            public int getPacketBufferSize() {
                // Each outbound packet carries an extra 4-byte length prefix.
                return super.getPacketBufferSize() + HEADER_LENGTH;
            }
        }
    }
}