/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.CharsetUtil;
import io.netty.util.DomainNameMapping;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.net.IDN;
import java.util.List;
import java.util.Locale;
/**
* <p>Enables <a href="https://tools.ietf.org/html/rfc3546#section-3.1">SNI
* (Server Name Indication)</a> extension for server side SSL. For clients
* support SNI, the server could have multiple host name bound on a single IP.
* The client will send host name in the handshake data so server could decide
* which certificate to choose for the host name. </p>
*/
public class SniHandler extends ByteToMessageDecoder {
private static final InternalLogger logger =
InternalLoggerFactory.getInstance(SniHandler.class);
private final DomainNameMapping<SslContext> mapping;
private boolean handshaken;
private volatile String hostname;
private volatile SslContext selectedContext;
/**
* Create a SNI detection handler with configured {@link SslContext}
* maintained by {@link DomainNameMapping}
*
* @param mapping the mapping of domain name to {@link SslContext}
*/
@SuppressWarnings("unchecked")
public SniHandler(DomainNameMapping<? extends SslContext> mapping) {
if (mapping == null) {
throw new NullPointerException("mapping");
}
this.mapping = (DomainNameMapping<SslContext>) mapping;
handshaken = false;
}
/**
* @return the selected hostname
*/
public String hostname() {
return hostname;
}
/**
* @return the selected sslcontext
*/
public SslContext sslContext() {
return selectedContext;
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (!handshaken && in.readableBytes() >= 5) {
String hostname = sniHostNameFromHandshakeInfo(in);
if (hostname != null) {
hostname = IDN.toASCII(hostname, IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US);
}
this.hostname = hostname;
// the mapping will return default context when this.hostname is null
selectedContext = mapping.map(hostname);
}
if (handshaken) {
SslHandler sslHandler = selectedContext.newHandler(ctx.alloc());
ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
}
}
private String sniHostNameFromHandshakeInfo(ByteBuf in) {
int readerIndex = in.readerIndex();
try {
int command = in.getUnsignedByte(readerIndex);
// tls, but not handshake command
switch (command) {
case SslConstants.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case SslConstants.SSL_CONTENT_TYPE_ALERT:
case SslConstants.SSL_CONTENT_TYPE_APPLICATION_DATA:
return null;
case SslConstants.SSL_CONTENT_TYPE_HANDSHAKE:
break;
default:
//not tls or sslv3, do not try sni
handshaken = true;
return null;
}
int majorVersion = in.getUnsignedByte(readerIndex + 1);
// SSLv3 or TLS
if (majorVersion == 3) {
int packetLength = in.getUnsignedShort(readerIndex + 3) + 5;
if (in.readableBytes() >= packetLength) {
// decode the ssl client hello packet
// we have to skip some var-length fields
int offset = readerIndex + 43;
int sessionIdLength = in.getUnsignedByte(offset);
offset += sessionIdLength + 1;
int cipherSuitesLength = in.getUnsignedShort(offset);
offset += cipherSuitesLength + 2;
int compressionMethodLength = in.getUnsignedByte(offset);
offset += compressionMethodLength + 1;
int extensionsLength = in.getUnsignedShort(offset);
offset += 2;
int extensionsLimit = offset + extensionsLength;
while (offset < extensionsLimit) {
int extensionType = in.getUnsignedShort(offset);
offset += 2;
int extensionLength = in.getUnsignedShort(offset);
offset += 2;
// SNI
if (extensionType == 0) {
handshaken = true;
int serverNameType = in.getUnsignedByte(offset + 2);
if (serverNameType == 0) {
int serverNameLength = in.getUnsignedShort(offset + 3);
return in.toString(offset + 5, serverNameLength,
CharsetUtil.UTF_8);
} else {
// invalid enum value
return null;
}
}
offset += extensionLength;
}
handshaken = true;
return null;
} else {
// client hello incomplete
return null;
}
} else {
handshaken = true;
return null;
}
} catch (Throwable e) {
// unexpected encoding, ignore sni and use default
if (logger.isDebugEnabled()) {
logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
}
handshaken = true;
return null;
}
}
}