/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.server.protocol.http;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.net.ssl.SSLEngine;
import org.xnio.ChannelListener;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import org.xnio.Pool;
import org.xnio.StreamConnection;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.ssl.SslConnection;
import io.undertow.UndertowLogger;
import io.undertow.UndertowMessages;
import io.undertow.UndertowOptions;
import io.undertow.connector.ByteBufferPool;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.protocols.alpn.ALPNManager;
import io.undertow.protocols.alpn.ALPNProvider;
import io.undertow.protocols.ssl.SslConduit;
import io.undertow.protocols.ssl.UndertowXnioSsl;
import io.undertow.server.AggregateConnectorStatistics;
import io.undertow.server.ConnectorStatistics;
import io.undertow.server.DelegateOpenListener;
import io.undertow.server.HttpHandler;
import io.undertow.server.OpenListener;
import io.undertow.server.XnioByteBufferPool;
/**
 * Open listener adaptor for ALPN connections.
 * <p>
 * Not a proper open listener as such, but more a mechanism for selecting between them: each accepted TLS
 * connection has the registered protocol names set on its engine through an {@link ALPNProvider}, and once the
 * selected protocol is known (or it is clear that ALPN is not in use) the connection is handed to the
 * {@link DelegateOpenListener} registered for that protocol.
 *
 * @author Stuart Douglas
 */
public class AlpnOpenListener implements ChannelListener<StreamConnection>, OpenListener {

    /**
     * HTTP/2 required cipher. Not strictly part of ALPN but it can live here for now till we have a better solution.
     */
    public static final String REQUIRED_CIPHER = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";

    /**
     * TLS protocol version that must be enabled on the engine before ALPN is attempted.
     */
    public static final String REQUIRED_PROTOCOL = "TLSv1.2";

    private final ALPNManager alpnManager = ALPNManager.INSTANCE; //todo: configurable
    private final ByteBufferPool bufferPool;

    /** Delegate listeners keyed by protocol name. */
    private final Map<String, ListenerEntry> listeners = new HashMap<>();
    /** Registered protocol names ordered by descending weight; rebuilt by {@link #addProtocol}. */
    private String[] protocols;
    /** Protocol used when ALPN is not in use; may be {@code null}. */
    private final String fallbackProtocol;

    private volatile HttpHandler rootHandler;
    private volatile OptionMap undertowOptions;
    private volatile boolean statisticsEnabled;
    // One-shot flags so the diagnostic debug messages in handleEvent are logged at most once.
    private volatile boolean providerLogged;
    private volatile boolean alpnFailLogged;

    /**
     * Creates a listener that falls back to {@code http/1.1} handled by {@code httpListener}.
     */
    public AlpnOpenListener(Pool<ByteBuffer> bufferPool, OptionMap undertowOptions, DelegateOpenListener httpListener) {
        this(bufferPool, undertowOptions, "http/1.1", httpListener);
    }

    /**
     * Creates a listener with no fallback protocol.
     */
    public AlpnOpenListener(Pool<ByteBuffer> bufferPool, OptionMap undertowOptions) {
        this(bufferPool, undertowOptions, null, null);
    }

    /**
     * Creates a listener backed by an XNIO buffer pool, wrapped in an {@link XnioByteBufferPool}.
     */
    public AlpnOpenListener(Pool<ByteBuffer> bufferPool, OptionMap undertowOptions, String fallbackProtocol, DelegateOpenListener fallbackListener) {
        this(new XnioByteBufferPool(bufferPool), undertowOptions, fallbackProtocol, fallbackListener);
    }

    /**
     * Creates a listener that falls back to {@code http/1.1} handled by {@code httpListener}.
     */
    public AlpnOpenListener(ByteBufferPool bufferPool, OptionMap undertowOptions, DelegateOpenListener httpListener) {
        this(bufferPool, undertowOptions, "http/1.1", httpListener);
    }

    /**
     * Creates a listener with empty options and no fallback protocol.
     */
    public AlpnOpenListener(ByteBufferPool bufferPool) {
        this(bufferPool, OptionMap.EMPTY, null, null);
    }

    /**
     * Creates a listener with no fallback protocol.
     */
    public AlpnOpenListener(ByteBufferPool bufferPool, OptionMap undertowOptions) {
        this(bufferPool, undertowOptions, null, null);
    }

    /**
     * Creates a listener.
     *
     * @param bufferPool       pool used for the buffer that receives the first bytes read after the handshake
     * @param undertowOptions  options; {@link UndertowOptions#ENABLE_CONNECTOR_STATISTICS} is read from it
     * @param fallbackProtocol protocol used when ALPN is not in use; may be {@code null}
     * @param fallbackListener listener for {@code fallbackProtocol}; registered with weight 0 only when both
     *                         it and {@code fallbackProtocol} are non-null
     */
    public AlpnOpenListener(ByteBufferPool bufferPool, OptionMap undertowOptions, String fallbackProtocol, DelegateOpenListener fallbackListener) {
        this.bufferPool = bufferPool;
        this.undertowOptions = undertowOptions;
        this.fallbackProtocol = fallbackProtocol;
        statisticsEnabled = undertowOptions.get(UndertowOptions.ENABLE_CONNECTOR_STATISTICS, false);
        if (fallbackProtocol != null && fallbackListener != null) {
            addProtocol(fallbackProtocol, fallbackListener, 0);
        }
    }

    @Override
    public HttpHandler getRootHandler() {
        return rootHandler;
    }

    /**
     * Sets the root handler here and on every delegate listener registered so far.
     */
    @Override
    public void setRootHandler(HttpHandler rootHandler) {
        this.rootHandler = rootHandler;
        for (ListenerEntry entry : listeners.values()) {
            entry.listener.setRootHandler(rootHandler);
        }
    }

    @Override
    public OptionMap getUndertowOptions() {
        return undertowOptions;
    }

    /**
     * Sets the options here and on every delegate listener registered so far, and re-reads the
     * connector-statistics flag.
     *
     * @throws IllegalArgumentException (via {@code UndertowMessages}) if {@code undertowOptions} is null
     */
    @Override
    public void setUndertowOptions(OptionMap undertowOptions) {
        if (undertowOptions == null) {
            throw UndertowMessages.MESSAGES.argumentCannotBeNull("undertowOptions");
        }
        this.undertowOptions = undertowOptions;
        // Propagate the options themselves; previously this re-applied the root handler by mistake.
        for (ListenerEntry entry : listeners.values()) {
            entry.listener.setUndertowOptions(undertowOptions);
        }
        statisticsEnabled = undertowOptions.get(UndertowOptions.ENABLE_CONNECTOR_STATISTICS, false);
    }

    @Override
    public ByteBufferPool getBufferPool() {
        return bufferPool;
    }

    /**
     * Returns the non-null statistics of all delegate listeners combined, or {@code null} when
     * statistics are disabled.
     */
    @Override
    public ConnectorStatistics getConnectorStatistics() {
        if (!statisticsEnabled) {
            return null;
        }
        List<ConnectorStatistics> stats = new ArrayList<>();
        for (ListenerEntry entry : listeners.values()) {
            ConnectorStatistics c = entry.listener.getConnectorStatistics();
            if (c != null) {
                stats.add(c);
            }
        }
        return new AggregateConnectorStatistics(stats.toArray(new ConnectorStatistics[stats.size()]));
    }

    /**
     * A delegate listener registered for one protocol name.
     * Natural ordering is by descending weight only, so it is not consistent with {@link #equals}.
     */
    private static class ListenerEntry implements Comparable<ListenerEntry> {
        final DelegateOpenListener listener;
        final int weight;
        final String protocol;

        ListenerEntry(DelegateOpenListener listener, int weight, String protocol) {
            this.listener = listener;
            this.weight = weight;
            this.protocol = protocol;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof ListenerEntry)) return false;
            ListenerEntry that = (ListenerEntry) o;
            if (weight != that.weight) return false;
            if (!listener.equals(that.listener)) return false;
            return protocol.equals(that.protocol);
        }

        @Override
        public int hashCode() {
            int result = listener.hashCode();
            result = 31 * result + weight;
            result = 31 * result + protocol.hashCode();
            return result;
        }

        @Override
        public int compareTo(ListenerEntry o) {
            // negated so that higher weights sort first
            return -Integer.compare(this.weight, o.weight);
        }
    }

    /**
     * Registers (or replaces) the delegate listener for {@code name} and rebuilds the protocol list that
     * is set on each engine, ordered by descending weight.
     *
     * @return this listener, for chaining
     */
    public AlpnOpenListener addProtocol(String name, DelegateOpenListener listener, int weight) {
        listeners.put(name, new ListenerEntry(listener, weight, name));
        List<ListenerEntry> list = new ArrayList<>(listeners.values());
        Collections.sort(list);
        protocols = new String[list.size()];
        for (int i = 0; i < list.size(); ++i) {
            protocols[i] = list.get(i).protocol;
        }
        return this;
    }

    /**
     * Sets up ALPN negotiation on a newly accepted connection.
     * <p>
     * If the engine does not meet {@link #engineSupportsHTTP2(SSLEngine)} or no ALPN provider is found for it,
     * the connection is handed straight to the fallback listener when one is registered. With no provider and
     * no fallback listener the connection is closed. Otherwise the registered protocols are set on the engine
     * and an {@link AlpnConnectionListener} is installed (and run once immediately) to dispatch the connection.
     *
     * @param channel the accepted connection; it is cast to {@link SslConnection} unconditionally
     */
    @Override
    public void handleEvent(final StreamConnection channel) {
        if (UndertowLogger.REQUEST_LOGGER.isTraceEnabled()) {
            UndertowLogger.REQUEST_LOGGER.tracef("Opened connection with %s", channel.getPeerAddress());
        }
        final SslConduit sslConduit = UndertowXnioSsl.getSslConduit((SslConnection) channel);
        final SSLEngine sslEngine = sslConduit.getSSLEngine();
        if (!engineSupportsHTTP2(sslEngine)) {
            if (!alpnFailLogged) {
                synchronized (this) {
                    if (!alpnFailLogged) {
                        UndertowLogger.REQUEST_LOGGER.debugf("ALPN has been configured however %s is not present or TLS1.2 is not enabled, falling back to default protocol", REQUIRED_CIPHER);
                        alpnFailLogged = true;
                    }
                }
            }
            if (delegateToFallback(channel)) {
                return;
            }
            // no fallback listener registered: ALPN is attempted anyway below
        }
        final ALPNProvider provider = alpnManager.getProvider(sslEngine);
        if (provider == null) {
            if (!providerLogged) {
                synchronized (this) {
                    if (!providerLogged) {
                        UndertowLogger.REQUEST_LOGGER.debugf("ALPN has been configured however no provider could be found for engine %s for connector at %s", sslEngine, channel.getLocalAddress());
                        providerLogged = true;
                    }
                }
            }
            if (delegateToFallback(channel)) {
                return;
            }
            UndertowLogger.REQUEST_LOGGER.debugf("No ALPN provider available and no fallback defined");
            IoUtils.safeClose(channel);
            return;
        }
        // NOTE(review): providerLogged is shared with the "no provider" message above, so only whichever
        // of the two messages fires first is ever logged.
        if (!providerLogged) {
            synchronized (this) {
                if (!providerLogged) {
                    UndertowLogger.REQUEST_LOGGER.debugf("Using ALPN provider %s for connector at %s", provider, channel.getLocalAddress());
                    providerLogged = true;
                }
            }
        }
        final SSLEngine newEngine = provider.setProtocols(sslEngine, protocols);
        // The callback restricts the engine's protocol list to the fallback protocol alone; when it is invoked
        // is decided by ALPNLimitingSSLEngine. NOTE(review): with a null fallbackProtocol this passes {null}.
        sslConduit.setSslEngine(new ALPNLimitingSSLEngine(newEngine, new Runnable() {
            @Override
            public void run() {
                provider.setProtocols(newEngine, new String[]{fallbackProtocol});
            }
        }));
        final AlpnConnectionListener potentialConnection = new AlpnConnectionListener(channel, newEngine, provider);
        channel.getSourceChannel().setReadListener(potentialConnection);
        potentialConnection.handleEvent(channel.getSourceChannel());
    }

    /**
     * Hands {@code channel} to the listener registered for the fallback protocol, if there is one.
     *
     * @return {@code true} if the connection was handed off, {@code false} if no fallback listener exists
     */
    private boolean delegateToFallback(StreamConnection channel) {
        if (fallbackProtocol == null) {
            return false;
        }
        ListenerEntry entry = listeners.get(fallbackProtocol);
        if (entry == null) {
            return false;
        }
        entry.listener.handleEvent(channel);
        return true;
    }

    /**
     * Checks that the engine meets the minimum requirements for HTTP/2; if not, ALPN will not be attempted.
     *
     * @return {@code true} if {@link #REQUIRED_PROTOCOL} is among the enabled protocols and
     *         {@link #REQUIRED_CIPHER} is among the enabled cipher suites
     */
    public static boolean engineSupportsHTTP2(SSLEngine engine) {
        boolean protocolEnabled = false;
        for (String proto : engine.getEnabledProtocols()) {
            if (proto.equals(REQUIRED_PROTOCOL)) {
                protocolEnabled = true;
                break;
            }
        }
        if (!protocolEnabled) {
            return false;
        }
        for (String cipher : engine.getEnabledCipherSuites()) {
            if (cipher.equals(REQUIRED_CIPHER)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Read listener that waits until the selected protocol is known, then passes the connection, together
     * with any bytes already read, to the matching delegate listener.
     */
    private class AlpnConnectionListener implements ChannelListener<StreamSourceChannel> {
        private final StreamConnection channel;
        private final SSLEngine engine;
        private final ALPNProvider provider;

        private AlpnConnectionListener(StreamConnection channel, SSLEngine engine, ALPNProvider provider) {
            this.channel = channel;
            this.engine = engine;
            this.provider = provider;
        }

        @Override
        public void handleEvent(StreamSourceChannel source) {
            PooledByteBuffer buffer = bufferPool.allocate();
            // cleared once ownership of the buffer has passed to a delegate listener
            boolean free = true;
            try {
                int res = channel.getSourceChannel().read(buffer.getBuffer());
                if (res == -1) {
                    IoUtils.safeClose(channel);
                    return;
                }
                buffer.getBuffer().flip();
                final String selected = provider.getSelectedProtocol(engine);
                final String protocol;
                if (selected != null && !selected.isEmpty()) {
                    protocol = selected;
                } else if (selected == null && res == 0) {
                    // no protocol selected yet and nothing read: wait for more data
                    channel.getSourceChannel().resumeReads();
                    return;
                } else {
                    // alpn not in use (empty selection, or data arrived without a selection)
                    if (fallbackProtocol == null) {
                        UndertowLogger.REQUEST_IO_LOGGER.noALPNFallback(channel.getPeerAddress());
                        IoUtils.safeClose(channel);
                        return;
                    }
                    protocol = fallbackProtocol;
                }
                free = !dispatch(source, protocol, buffer);
            } catch (IOException e) {
                UndertowLogger.REQUEST_IO_LOGGER.ioException(e);
                IoUtils.safeClose(channel);
            } finally {
                if (free) {
                    buffer.close();
                }
            }
        }

        /**
         * Passes the connection and {@code buffer} to the listener registered for {@code protocol}, or closes
         * the connection if no listener is registered under that name.
         *
         * @return {@code true} if ownership of {@code buffer} passed to a delegate listener
         */
        private boolean dispatch(StreamSourceChannel source, String protocol, PooledByteBuffer buffer) {
            ListenerEntry entry = listeners.get(protocol);
            if (entry == null) {
                UndertowLogger.REQUEST_LOGGER.debugf("No listener registered for protocol %s, closing connection to %s", protocol, channel.getPeerAddress());
                IoUtils.safeClose(channel);
                return false;
            }
            source.getReadSetter().set(null);
            entry.listener.handleEvent(channel, buffer);
            return true;
        }
    }
}