/* * Copyright 2016 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.linecorp.armeria.client.http; import static com.linecorp.armeria.common.http.HttpSessionProtocols.HTTPS; import static org.junit.Assert.assertEquals; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.nio.charset.StandardCharsets; import java.security.cert.X509Certificate; import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import com.linecorp.armeria.client.ClientFactory; import com.linecorp.armeria.client.Clients; import com.linecorp.armeria.client.SessionOption; import com.linecorp.armeria.client.SessionOptions; import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.http.AggregatedHttpMessage; import com.linecorp.armeria.common.http.HttpRequest; import com.linecorp.armeria.common.http.HttpResponseWriter; import com.linecorp.armeria.common.http.HttpStatus; import com.linecorp.armeria.server.Server; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.server.VirtualHostBuilder; import com.linecorp.armeria.server.http.AbstractHttpService; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import 
io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.resolver.AddressResolver;
import io.netty.resolver.AddressResolverGroup;
import io.netty.resolver.InetNameResolver;
import io.netty.resolver.InetSocketAddressResolver;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Promise;

import java.util.concurrent.TimeUnit;

/**
 * Tests that the virtual host (and its certificate) serving an HTTPS request matches the host name in
 * the client's URI. Host names that match no virtual host are expected to be served by the default
 * virtual host ({@code b.com}).
 */
public class HttpClientSniTest {

    /**
     * Upper bound, in seconds, for every blocking wait in this test. A stuck server start or TLS
     * handshake then fails the test instead of hanging the build.
     */
    private static final long TIMEOUT_SECONDS = 30;

    private static final Server server;

    private static int httpsPort;
    private static ClientFactory clientFactory;

    private static final SelfSignedCertificate sscA;
    private static final SelfSignedCertificate sscB;

    static {
        final ServerBuilder sb = new ServerBuilder();
        try {
            // One self-signed certificate per virtual host; each is created for its own host name.
            sscA = new SelfSignedCertificate("a.com");
            sscB = new SelfSignedCertificate("b.com");

            // Port 0: the bound port is looked up from server.activePorts() in init().
            sb.port(0, HTTPS);

            final VirtualHostBuilder a = new VirtualHostBuilder("a.com");
            final VirtualHostBuilder b = new VirtualHostBuilder("b.com");

            a.serviceAt("/", new SniTestService("a.com"));
            b.serviceAt("/", new SniTestService("b.com"));

            a.sslContext(HTTPS, sscA.certificate(), sscA.privateKey());
            b.sslContext(HTTPS, sscB.certificate(), sscB.privateKey());

            sb.virtualHost(a.build());
            // b.com is also the default virtual host; testMismatch() relies on this fallback.
            sb.defaultVirtualHost(b.build());
        } catch (Exception e) {
            throw new Error(e);
        }
        server = sb.build();
    }

    /**
     * Starts the server, records its HTTPS port and creates a client factory. The factory trusts any
     * certificate and resolves every host name to 127.0.0.1.
     */
    @BeforeClass
    public static void init() throws Exception {
        server.start().get(TIMEOUT_SECONDS, TimeUnit.SECONDS);

        httpsPort = server.activePorts().values().stream()
                          .filter(p -> p.protocol() == HTTPS)
                          .findAny()
                          .orElseThrow(() -> new IllegalStateException("no active HTTPS port"))
                          .localAddress()
                          .getPort();

        clientFactory = new HttpClientFactory(SessionOptions.of(
                SessionOption.TRUST_MANAGER_FACTORY.newValue(InsecureTrustManagerFactory.INSTANCE),
                SessionOption.ADDRESS_RESOLVER_GROUP.newValue(new DummyAddressResolverGroup())));
    }

    /**
     * Releases the client factory, the server and the temporary certificates. The work runs on a
     * separate task that is not awaited, so this method returns immediately.
     */
    @AfterClass
    public static void destroy() throws Exception {
        CompletableFuture.runAsync(() -> {
            // JUnit still runs @AfterClass when @BeforeClass fails. If init() failed before creating
            // the factory, an unguarded close() would throw inside this un-awaited task and silently
            // skip stopping the server and deleting the certificates.
            if (clientFactory != null) {
                clientFactory.close();
            }
            server.stop();
            sscA.delete();
            sscB.delete();
        });
    }

    @Test
    public void testMatch() throws Exception {
        testMatch("a.com");
        testMatch("b.com");
    }

    /**
     * Asserts that a request to {@code fqdn} is answered by the service registered for {@code fqdn},
     * using the certificate whose subject is {@code CN=<fqdn>}.
     */
    private static void testMatch(String fqdn) throws Exception {
        assertEquals(fqdn + ": CN=" + fqdn, get(fqdn));
    }

    @Test
    public void testMismatch() throws Exception {
        testMismatch("127.0.0.1");
        testMismatch("mismatch.com");
    }

    /** Asserts that a request to an unconfigured host name is answered by the default (b.com) host. */
    private static void testMismatch(String fqdn) throws Exception {
        assertEquals("b.com: CN=b.com", get(fqdn));
    }

    /**
     * Sends {@code GET /} to {@code https://<fqdn>:<httpsPort>}, asserts a 200 response and returns
     * the body decoded as UTF-8.
     *
     * <p>{@link DummyAddressResolverGroup} maps every host name to 127.0.0.1. As a result,
     * {@code fqdn} changes only the host name the client presents, not the address it connects to.
     *
     * @param fqdn the host name to put in the request URI
     * @return the response body, in the form {@code "<virtual host>: <certificate subject>"}
     */
    private static String get(String fqdn) throws Exception {
        final HttpClient client = Clients.newClient(
                clientFactory, "none+https://" + fqdn + ':' + httpsPort, HttpClient.class);

        final AggregatedHttpMessage response =
                client.get("/").aggregate().get(TIMEOUT_SECONDS, TimeUnit.SECONDS);

        assertEquals(HttpStatus.OK, response.headers().status());
        return response.content().toString(StandardCharsets.UTF_8);
    }

    /**
     * An address resolver group whose resolvers resolve every host name to 127.0.0.1 while keeping
     * the original host name on the returned {@link InetAddress}.
     */
    private static class DummyAddressResolverGroup extends AddressResolverGroup<InetSocketAddress> {
        @Override
        protected AddressResolver<InetSocketAddress> newResolver(EventExecutor eventExecutor) {
            return new InetSocketAddressResolver(eventExecutor, new InetNameResolver(eventExecutor) {
                @Override
                protected void doResolve(String hostname, Promise<InetAddress> promise) {
                    try {
                        promise.setSuccess(newAddress(hostname));
                    } catch (UnknownHostException e) {
                        promise.setFailure(e);
                    }
                }

                @Override
                protected void doResolveAll(String hostname, Promise<List<InetAddress>> promise) {
                    try {
                        promise.setSuccess(Collections.singletonList(newAddress(hostname)));
                    } catch (UnknownHostException e) {
                        promise.setFailure(e);
                    }
                }

                /** Returns 127.0.0.1 labeled with {@code hostname}. */
                private InetAddress newAddress(String hostname) throws UnknownHostException {
                    return InetAddress.getByAddress(hostname, new byte[] { 127, 0, 0, 1 });
                }
            });
        }
    }

    /**
     * Responds to {@code GET} with {@code "<domainName>: <subject>"}. Here {@code <subject>} is the
     * X.500 subject name of the first local certificate of the request's TLS session.
     */
    private static class SniTestService extends AbstractHttpService {

        private final String domainName;

        SniTestService(String domainName) {
            this.domainName = domainName;
        }

        @Override
        protected void doGet(ServiceRequestContext ctx, HttpRequest req, HttpResponseWriter res) {
            final X509Certificate c = (X509Certificate) ctx.sslSession().getLocalCertificates()[0];
            final String name = c.getSubjectX500Principal().getName();
            res.respond(HttpStatus.OK, MediaType.PLAIN_TEXT_UTF_8, domainName + ": " + name);
        }
    }
}