/* * JBoss, Home of Professional Open Source. * Copyright 2014 Red Hat, Inc., and individual contributors * as indicated by the @author tags. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.undertow.protocols.ssl; import io.undertow.UndertowLogger; import java.io.ByteArrayOutputStream; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; import java.security.MessageDigest; import java.util.List; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; /** * SSLEngine wrapper that provides some super hacky ALPN support on JDK8. * * Even though this is a nasty hack that relies on JDK internals it is still preferable to modifying the boot class path. * * It is expected to work with all JDK8 versions, however this cannot be guaranteed if the SSL internals are changed * in an incompatible way. * * This class will go away once JDK8 is no longer in use. 
*
 * @author Stuart Douglas
 */
public class ALPNHackSSLEngine extends SSLEngine {

    /**
     * {@code true} when every reflective lookup in the static initializer succeeded and the hack has not been
     * switched off with the {@code io.undertow.disable-jdk8-alpn} system property.
     */
    public static final boolean ENABLED;

    // Reflective handles into sun.security.ssl internals. All of them are null when the lookups failed.
    private static final Field HANDSHAKER;
    private static final Field HANDSHAKER_PROTOCOL_VERSION;
    private static final Field HANDSHAKE_HASH;
    private static final Field HANDSHAKE_HASH_VERSION;
    private static final Method HANDSHAKE_HASH_UPDATE;
    private static final Method HANDSHAKE_HASH_PROTOCOL_DETERMINED;
    private static final Field HANDSHAKE_HASH_DATA;
    private static final Field HANDSHAKE_HASH_FIN_MD;
    private static final Class<?> SSL_ENGINE_IMPL_CLASS;

    static {
        boolean enabled = true;
        Field handshaker;
        Field handshakeHash;
        Field handshakeHashVersion;
        Field handshakeHashData;
        Field handshakeHashFinMd;
        Field protocolVersion;
        Method handshakeHashUpdate;
        Method handshakeHashProtocolDetermined;
        Class<?> sslEngineImplClass;
        try {
            Class<?> protocolVersionClass = Class.forName("sun.security.ssl.ProtocolVersion", true, ClassLoader.getSystemClassLoader());
            sslEngineImplClass = Class.forName("sun.security.ssl.SSLEngineImpl", true, ClassLoader.getSystemClassLoader());
            handshaker = sslEngineImplClass.getDeclaredField("handshaker");
            handshaker.setAccessible(true);
            handshakeHash = handshaker.getType().getDeclaredField("handshakeHash");
            handshakeHash.setAccessible(true);
            protocolVersion = handshaker.getType().getDeclaredField("protocolVersion");
            protocolVersion.setAccessible(true);
            handshakeHashVersion = handshakeHash.getType().getDeclaredField("version");
            handshakeHashVersion.setAccessible(true);
            handshakeHashUpdate = handshakeHash.getType().getDeclaredMethod("update", byte[].class, int.class, int.class);
            handshakeHashUpdate.setAccessible(true);
            handshakeHashProtocolDetermined = handshakeHash.getType().getDeclaredMethod("protocolDetermined", protocolVersionClass);
            handshakeHashProtocolDetermined.setAccessible(true);
            handshakeHashData = handshakeHash.getType().getDeclaredField("data");
            handshakeHashData.setAccessible(true);
            handshakeHashFinMd = handshakeHash.getType().getDeclaredField("finMD");
            handshakeHashFinMd.setAccessible(true);
        } catch (Exception e) {
            // Any missing class/member (or an access failure) means the JDK internals do not match: disable the hack.
            UndertowLogger.ROOT_LOGGER.debug("JDK8 ALPN Hack failed ", e);
            enabled = false;
            handshaker = null;
            handshakeHash = null;
            handshakeHashVersion = null;
            handshakeHashUpdate = null;
            handshakeHashProtocolDetermined = null;
            handshakeHashData = null;
            handshakeHashFinMd = null;
            protocolVersion = null;
            sslEngineImplClass = null;
        }
        ENABLED = enabled && !Boolean.getBoolean("io.undertow.disable-jdk8-alpn");
        HANDSHAKER = handshaker;
        HANDSHAKE_HASH = handshakeHash;
        HANDSHAKE_HASH_PROTOCOL_DETERMINED = handshakeHashProtocolDetermined;
        HANDSHAKE_HASH_VERSION = handshakeHashVersion;
        HANDSHAKE_HASH_UPDATE = handshakeHashUpdate;
        HANDSHAKE_HASH_DATA = handshakeHashData;
        HANDSHAKE_HASH_FIN_MD = handshakeHashFinMd;
        HANDSHAKER_PROTOCOL_VERSION = protocolVersion;
        SSL_ENGINE_IMPL_CLASS = sslEngineImplClass;
    }

    private final SSLEngine delegate;

    //ALPN Hack specific variables
    /** Set once the first inbound hello has been inspected for ALPN data. */
    private boolean unwrapHelloSeen = false;
    /** Set once the delegate has produced its first outbound bytes. */
    private boolean ourHelloSent = false;
    private ALPNHackServerByteArrayOutputStream alpnHackServerByteArrayOutputStream;
    private ALPNHackClientByteArrayOutputStream alpnHackClientByteArrayOutputStream;
    private List<String> applicationProtocols;
    private String selectedApplicationProtocol;
    /** Rewritten handshake bytes that did not fit into the caller's buffer; emitted by the next {@code wrap} call. */
    private ByteBuffer bufferedWrapData;

    public ALPNHackSSLEngine(SSLEngine delegate) {
        this.delegate = delegate;
    }

    /**
     * Returns whether the hack can be applied to the given engine.
     *
     * @param engine the engine that would be wrapped
     * @return {@code true} if {@link #ENABLED} and the engine is a {@code sun.security.ssl.SSLEngineImpl}
     */
    public static boolean isEnabled(SSLEngine engine) {
        if (!ENABLED) {
            return false;
        }
        return SSL_ENGINE_IMPL_CLASS.isAssignableFrom(engine.getClass());
    }

    /**
     * Wraps outbound data through the delegate and rewrites the first hello it produces.
     *
     * <p>Client mode with application protocols configured: the produced ClientHello is passed through
     * {@code ALPNHackClientHelloExplorer.rewriteClientHello}. Server mode with a selected protocol: the produced
     * records are rebuilt around the server hello held by the hash output stream.
     *
     * <p>Rewritten bytes that do not fit between the buffer's original position and limit are held back. They are
     * emitted by subsequent calls before the delegate is asked to produce anything else.
     */
    @Override
    public SSLEngineResult wrap(ByteBuffer[] byteBuffers, int i, int i1, ByteBuffer byteBuffer) throws SSLException {
        if (bufferedWrapData != null) {
            return flushBufferedWrapData(byteBuffer);
        }
        int pos = byteBuffer.position();
        int limit = byteBuffer.limit();
        SSLEngineResult res = delegate.wrap(byteBuffers, i, i1, byteBuffer);
        if (!ourHelloSent && res.bytesProduced() > 0) {
            if (delegate.getUseClientMode() && applicationProtocols != null && !applicationProtocols.isEmpty()) {
                ourHelloSent = true;
                res = rewriteOutgoingClientHello(res, byteBuffer, pos, limit);
            } else if (!getUseClientMode()) {
                res = rewriteOutgoingServerHello(res, byteBuffer, pos, limit);
            }
        }
        if (res.bytesProduced() > 0) {
            ourHelloSent = true;
        }
        return res;
    }

    /** Copies as much of {@link #bufferedWrapData} into {@code dst} as fits, keeping any remainder for later. */
    private SSLEngineResult flushBufferedWrapData(ByteBuffer dst) {
        int room = dst.remaining();
        if (room == 0) {
            // nothing fits: report overflow instead of letting ByteBuffer.put throw
            return new SSLEngineResult(SSLEngineResult.Status.BUFFER_OVERFLOW, SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, 0);
        }
        int produced = Math.min(room, bufferedWrapData.remaining());
        int fullLimit = bufferedWrapData.limit();
        bufferedWrapData.limit(bufferedWrapData.position() + produced);
        dst.put(bufferedWrapData);
        bufferedWrapData.limit(fullLimit);
        if (!bufferedWrapData.hasRemaining()) {
            bufferedWrapData = null;
        }
        return new SSLEngineResult(SSLEngineResult.Status.OK, SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, produced);
    }

    /**
     * Client mode: hooks the handshake hash and replaces the ClientHello the delegate just wrote with the
     * rewritten one. If the hash cannot be hooked, the hello is sent untouched, because rewriting it without
     * fixing up the hash would break the handshake.
     */
    private SSLEngineResult rewriteOutgoingClientHello(SSLEngineResult res, ByteBuffer dst, int pos, int limit) throws SSLException {
        alpnHackClientByteArrayOutputStream = replaceClientByteOutput(delegate);
        if (alpnHackClientByteArrayOutputStream == null) {
            return res;
        }
        ByteBuffer produced = producedRegion(dst, pos);
        byte[] data = new byte[produced.remaining()];
        produced.get(data);
        byte[] newData = ALPNHackClientHelloExplorer.rewriteClientHello(data, applicationProtocols);
        if (newData == null) {
            return res;
        }
        // strip the 5 byte TLS record header, the hash only covers the handshake message
        byte[] clientHelloMessage = new byte[newData.length - 5];
        System.arraycopy(newData, 5, clientHelloMessage, 0, clientHelloMessage.length);
        alpnHackClientByteArrayOutputStream.setSentClientHello(clientHelloMessage);
        return replaceProducedData(res, dst, pos, limit, ByteBuffer.wrap(newData));
    }

    /**
     * Server mode: once a protocol has been selected and the hash output stream installed, rebuilds the produced
     * records around the new server hello (it will be part of the first TLS plaintext record).
     */
    private SSLEngineResult rewriteOutgoingServerHello(SSLEngineResult res, ByteBuffer dst, int pos, int limit) throws SSLException {
        if (selectedApplicationProtocol == null || alpnHackServerByteArrayOutputStream == null) {
            return res;
        }
        byte[] newServerHello = alpnHackServerByteArrayOutputStream.getServerHello();
        if (newServerHello == null) {
            return res;
        }
        List<ByteBuffer> records = ALPNHackServerHelloExplorer.extractRecords(producedRegion(dst, pos));
        ByteBuffer newData = ALPNHackServerHelloExplorer.createNewOutputRecords(newServerHello, records);
        return replaceProducedData(res, dst, pos, limit, newData);
    }

    /** Returns an independent view of the bytes written to {@code dst} from {@code start} up to its current position. */
    private static ByteBuffer producedRegion(ByteBuffer dst, int start) {
        ByteBuffer produced = dst.duplicate();
        produced.limit(produced.position());
        produced.position(start);
        return produced;
    }

    /**
     * Discards what the delegate wrote to {@code dst} and writes {@code replacement} instead, within the caller's
     * original position/limit. Whatever does not fit is kept in {@link #bufferedWrapData}.
     *
     * @return a copy of {@code res} reporting the number of bytes actually written
     */
    private SSLEngineResult replaceProducedData(SSLEngineResult res, ByteBuffer dst, int pos, int limit, ByteBuffer replacement) {
        dst.position(pos); //erase the data
        dst.limit(limit);
        int produced;
        if (replacement.remaining() > dst.remaining()) {
            int fullLimit = replacement.limit();
            replacement.limit(replacement.position() + dst.remaining());
            produced = replacement.remaining();
            dst.put(replacement);
            replacement.limit(fullLimit);
            bufferedWrapData = replacement;
        } else {
            produced = replacement.remaining();
            dst.put(replacement);
        }
        SSLEngineResult.HandshakeStatus handshakeStatus = res.getHandshakeStatus();
        if (bufferedWrapData != null && handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
            // the peer cannot answer until it has the rest of our flight, so make the caller wrap again to flush it
            handshakeStatus = SSLEngineResult.HandshakeStatus.NEED_WRAP;
        }
        return new SSLEngineResult(res.getStatus(), handshakeStatus, res.bytesConsumed(), produced);
    }

    /**
     * Unwraps inbound data through the delegate, inspecting the first peer hello for ALPN.
     *
     * <p>Server mode: the ClientHello is explored, and the first of our application protocols that the client
     * offered is selected. After the delegate has processed it, the handshake hash output stream is replaced so the
     * server hello can later be rewritten.
     *
     * <p>Client mode: if the data starts with a TLS 1.2 handshake record, the ALPN extension is removed from the
     * ServerHello before the delegate sees it. The negotiated protocol is recorded, and the original hello is
     * handed to the hash output stream.
     *
     * <p>Returns {@code BUFFER_UNDERFLOW} when the hello is not yet complete.
     */
    @Override
    public SSLEngineResult unwrap(ByteBuffer dataToUnwrap, ByteBuffer[] byteBuffers, int i, int i1) throws SSLException {
        if (!unwrapHelloSeen) {
            if (!delegate.getUseClientMode() && applicationProtocols != null) {
                try {
                    List<String> result = ALPNHackClientHelloExplorer.exploreClientHello(dataToUnwrap.duplicate());
                    if (result != null) {
                        // server preference order wins
                        for (String protocol : applicationProtocols) {
                            if (result.contains(protocol)) {
                                selectedApplicationProtocol = protocol;
                                break;
                            }
                        }
                    }
                    unwrapHelloSeen = true;
                } catch (BufferUnderflowException e) {
                    return new SSLEngineResult(SSLEngineResult.Status.BUFFER_UNDERFLOW, SSLEngineResult.HandshakeStatus.NEED_UNWRAP, 0, 0);
                }
            } else if (delegate.getUseClientMode() && alpnHackClientByteArrayOutputStream != null) {
                if (!dataToUnwrap.hasRemaining()) {
                    return delegate.unwrap(dataToUnwrap, byteBuffers, i, i1);
                }
                try {
                    ByteBuffer dup = dataToUnwrap.duplicate();
                    int type = dup.get();
                    int major = dup.get();
                    int minor = dup.get();
                    if (type == 22 && major == 3 && minor == 3) {
                        //we only care about TLS 1.2
                        //split up the records, there may be multiple when doing a fast session resume
                        List<ByteBuffer> records = ALPNHackServerHelloExplorer.extractRecords(dataToUnwrap.duplicate());
                        ByteBuffer firstRecord = records.get(0); //this will be the handshake record
                        final AtomicReference<String> alpnResult = new AtomicReference<>();
                        ByteBuffer dupFirst = firstRecord.duplicate();
                        dupFirst.position(firstRecord.position() + 5); // skip the TLS record header
                        ByteBuffer firstLessFraming = dupFirst.duplicate();
                        byte[] result = ALPNHackServerHelloExplorer.removeAlpnExtensionsFromServerHello(dupFirst, alpnResult);
                        firstLessFraming.limit(dupFirst.position());
                        unwrapHelloSeen = true;
                        if (result != null) {
                            selectedApplicationProtocol = alpnResult.get();
                            int newFirstRecordLength = result.length + dupFirst.remaining();
                            byte[] newFirstRecord = new byte[newFirstRecordLength];
                            System.arraycopy(result, 0, newFirstRecord, 0, result.length);
                            dupFirst.get(newFirstRecord, result.length, dupFirst.remaining());
                            byte[] originalFirstRecord = new byte[firstLessFraming.remaining()];
                            firstLessFraming.get(originalFirstRecord);
                            ByteBuffer newData = ALPNHackServerHelloExplorer.createNewOutputRecords(newFirstRecord, records);
                            // replace the inbound data in place with the ALPN-free records
                            dataToUnwrap.clear();
                            dataToUnwrap.put(newData);
                            dataToUnwrap.flip();
                            alpnHackClientByteArrayOutputStream.setReceivedServerHello(originalFirstRecord);
                        }
                    }
                } catch (BufferUnderflowException e) {
                    return new SSLEngineResult(SSLEngineResult.Status.BUFFER_UNDERFLOW, SSLEngineResult.HandshakeStatus.NEED_UNWRAP, 0, 0);
                }
            }
        }
        SSLEngineResult res = delegate.unwrap(dataToUnwrap, byteBuffers, i, i1);
        if (!delegate.getUseClientMode() && selectedApplicationProtocol != null && alpnHackServerByteArrayOutputStream == null) {
            alpnHackServerByteArrayOutputStream = replaceServerByteOutput(delegate, selectedApplicationProtocol);
        }
        return res;
    }

    // The remaining SSLEngine methods delegate directly to the wrapped engine.

    @Override
    public Runnable getDelegatedTask() {
        return delegate.getDelegatedTask();
    }

    @Override
    public void closeInbound() throws SSLException {
        delegate.closeInbound();
    }

    @Override
    public boolean isInboundDone() {
        return delegate.isInboundDone();
    }

    @Override
    public void closeOutbound() {
        delegate.closeOutbound();
    }

    @Override
    public boolean isOutboundDone() {
        return delegate.isOutboundDone();
    }

    @Override
    public String[] getSupportedCipherSuites() {
        return delegate.getSupportedCipherSuites();
    }

    @Override
    public String[] getEnabledCipherSuites() {
        return delegate.getEnabledCipherSuites();
    }

    @Override
    public void setEnabledCipherSuites(String[] strings) {
        delegate.setEnabledCipherSuites(strings);
    }

    @Override
    public String[] getSupportedProtocols() {
        return delegate.getSupportedProtocols();
    }

    @Override
    public String[] getEnabledProtocols() {
        return delegate.getEnabledProtocols();
    }

    @Override
    public void setEnabledProtocols(String[] strings) {
        delegate.setEnabledProtocols(strings);
    }

    @Override
    public SSLSession getSession() {
        return delegate.getSession();
    }

    @Override
    public void beginHandshake() throws SSLException {
        delegate.beginHandshake();
    }

    @Override
    public SSLEngineResult.HandshakeStatus getHandshakeStatus() {
        return delegate.getHandshakeStatus();
    }

    @Override
    public void setUseClientMode(boolean b) {
        delegate.setUseClientMode(b);
    }

    @Override
    public boolean getUseClientMode() {
        return delegate.getUseClientMode();
    }

    @Override
    public void setNeedClientAuth(boolean b) {
        delegate.setNeedClientAuth(b);
    }

    @Override
    public boolean getNeedClientAuth() {
        return delegate.getNeedClientAuth();
    }

    @Override
    public void setWantClientAuth(boolean b) {
        delegate.setWantClientAuth(b);
    }

    @Override
    public boolean getWantClientAuth() {
        return delegate.getWantClientAuth();
    }

    @Override
    public void setEnableSessionCreation(boolean b) {
        delegate.setEnableSessionCreation(b);
    }

    @Override
    public boolean getEnableSessionCreation() {
        return delegate.getEnableSessionCreation();
    }

    /**
     * JDK8 ALPN hack support method.
     *
     * These methods will be removed once JDK8 ALPN support is no longer required
     * @param applicationProtocols the protocols to advertise (client) or accept (server), in order of preference;
     *                             when a server selects, it picks the first of these that the client offered
     */
    public void setApplicationProtocols(List<String> applicationProtocols) {
        this.applicationProtocols = applicationProtocols;
    }

    /**
     * JDK8 ALPN hack support method.
     *
     * These methods will be removed once JDK8 ALPN support is no longer required
     * @return the configured application protocols, or {@code null} if none were set
     */
    public List<String> getApplicationProtocols() {
        return applicationProtocols;
    }

    /**
     * JDK8 ALPN hack support method.
     *
     * These methods will be removed once JDK8 ALPN support is no longer required
     * @return the negotiated application protocol, or {@code null} if none has been selected (yet)
     */
    public String getSelectedApplicationProtocol() {
        return selectedApplicationProtocol;
    }

    /**
     * Replaces the handshake hash's internal data stream with an {@link ALPNHackServerByteArrayOutputStream} seeded
     * with the bytes hashed so far.
     *
     * @return the installed stream, or {@code null} (logged at debug) if the reflective access failed
     */
    static ALPNHackServerByteArrayOutputStream replaceServerByteOutput(SSLEngine sslEngine, String selectedAlpnProtocol) {
        try {
            Object handshaker = HANDSHAKER.get(sslEngine);
            Object hash = HANDSHAKE_HASH.get(handshaker);
            ByteArrayOutputStream existing = (ByteArrayOutputStream) HANDSHAKE_HASH_DATA.get(hash);
            ALPNHackServerByteArrayOutputStream out = new ALPNHackServerByteArrayOutputStream(sslEngine, existing.toByteArray(), selectedAlpnProtocol);
            HANDSHAKE_HASH_DATA.set(hash, out);
            return out;
        } catch (Exception e) {
            UndertowLogger.ROOT_LOGGER.debug("Failed to replace hash output stream ", e);
            return null;
        }
    }

    /**
     * Replaces the handshake hash's internal data stream with a fresh {@link ALPNHackClientByteArrayOutputStream}.
     *
     * @return the installed stream, or {@code null} (logged at debug) if the reflective access failed
     */
    static ALPNHackClientByteArrayOutputStream replaceClientByteOutput(SSLEngine sslEngine) {
        try {
            Object handshaker = HANDSHAKER.get(sslEngine);
            Object hash = HANDSHAKE_HASH.get(handshaker);
            ALPNHackClientByteArrayOutputStream out = new ALPNHackClientByteArrayOutputStream(sslEngine);
            HANDSHAKE_HASH_DATA.set(hash, out);
            return out;
        } catch (Exception e) {
            UndertowLogger.ROOT_LOGGER.debug("Failed to replace hash output stream ", e);
            return null;
        }
    }

    /**
     * Rebuilds the engine's handshake hash from scratch.
     *
     * <p>The method resets {@code data}, resets the hash version to -1 and re-runs {@code protocolDetermined} with
     * the handshaker's protocol version. It then resets the {@code finMD} digest and feeds every array in
     * {@code hashBytes} through the hash's {@code update} method.
     *
     * @throws RuntimeException wrapping any reflective failure
     */
    static void regenerateHashes(SSLEngine sslEngineToHack, ByteArrayOutputStream data, byte[]... hashBytes) {
        //hack up the SSL engine internal state
        try {
            Object handshaker = HANDSHAKER.get(sslEngineToHack);
            Object hash = HANDSHAKE_HASH.get(handshaker);
            data.reset();
            Object protocolVersion = HANDSHAKER_PROTOCOL_VERSION.get(handshaker);
            HANDSHAKE_HASH_VERSION.set(hash, -1);
            HANDSHAKE_HASH_PROTOCOL_DETERMINED.invoke(hash, protocolVersion);
            MessageDigest digest = (MessageDigest) HANDSHAKE_HASH_FIN_MD.get(hash);
            digest.reset();
            for (byte[] b : hashBytes) {
                HANDSHAKE_HASH_UPDATE.invoke(hash, b, 0, b.length);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}