/* * Copyright 2016, Google Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are * met: * * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above * copyright notice, this list of conditions and the following disclaimer * in the documentation and/or other materials provided with the * distribution. * * * Neither the name of Google Inc. nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/

package io.grpc.android.integrationtest;

import android.annotation.TargetApi;
import android.net.SSLCertificateSocketFactory;
import android.os.Build;
import android.support.annotation.Nullable;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.okhttp.NegotiationType;
import io.grpc.okhttp.OkHttpChannelBuilder;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.security.KeyStore;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;

/**
 * A helper class to create an OkHttp based channel.
 */
class TesterOkHttpChannelBuilder {
  /** Upper bound, in bytes, configured for inbound messages on every channel built here. */
  private static final int MAX_INBOUND_MESSAGE_SIZE = 16 * 1024 * 1024;

  /** Timeout, in milliseconds, passed to {@code SSLCertificateSocketFactory.getDefault(int)}. */
  private static final int SOCKET_FACTORY_TIMEOUT_MS = 5000;

  /**
   * Builds a channel to {@code host:port}.
   *
   * @param host the server host passed to {@code ManagedChannelBuilder.forAddress}
   * @param port the server port passed to {@code ManagedChannelBuilder.forAddress}
   * @param serverHostOverride if non-null, used as the channel authority so that the hostname
   *     matches the certificate the server presents
   * @param useTls whether to configure TLS; when {@code false} the channel uses plaintext
   * @param testCa if non-null, a stream holding the X.509 certificate to trust instead of the
   *     default trust configuration; only read when {@code useTls} is {@code true}
   * @param androidSocketFactoryTls if non-null, selects an {@link SSLCertificateSocketFactory}
   *     configured for {@code "alpn"} or {@code "npn"}; only read when {@code useTls} is
   *     {@code true}
   * @return the built channel
   * @throws IllegalStateException if TLS is requested but the builder obtained from
   *     {@code ManagedChannelBuilder.forAddress} is not an {@link OkHttpChannelBuilder}
   * @throws RuntimeException wrapping any failure while setting up the TLS socket factory
   */
  public static ManagedChannel build(String host, int port, @Nullable String serverHostOverride,
      boolean useTls, @Nullable InputStream testCa, @Nullable String androidSocketFactoryTls) {
    ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forAddress(host, port)
        .maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE);
    if (serverHostOverride != null) {
      // Force the hostname to match the cert the server uses.
      channelBuilder.overrideAuthority(serverHostOverride);
    }
    if (useTls) {
      // The TLS knobs used below exist only on the OkHttp builder, so fail with a clear message
      // instead of an opaque ClassCastException if a different builder was selected.
      if (!(channelBuilder instanceof OkHttpChannelBuilder)) {
        throw new IllegalStateException(
            "Expected an OkHttpChannelBuilder but got " + channelBuilder.getClass().getName());
      }
      OkHttpChannelBuilder okHttpBuilder = (OkHttpChannelBuilder) channelBuilder;
      try {
        SSLSocketFactory factory;
        if (androidSocketFactoryTls != null) {
          factory = getSslCertificateSocketFactory(testCa, androidSocketFactoryTls);
        } else {
          factory = getSslSocketFactory(testCa);
        }
        okHttpBuilder.negotiationType(NegotiationType.TLS);
        okHttpBuilder.sslSocketFactory(factory);
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    } else {
      channelBuilder.usePlaintext(true);
    }
    return channelBuilder.build();
  }

  /**
   * Returns the socket factory to use when no Android-specific factory was requested.
   *
   * @param testCa if non-null, the certificate to trust; if null, the default
   *     {@link SSLSocketFactory} is returned
   * @return the default factory, or one from a "TLS" {@link SSLContext} that trusts
   *     {@code testCa}
   * @throws Exception if the certificate cannot be loaded or the context cannot be initialized
   */
  private static SSLSocketFactory getSslSocketFactory(@Nullable InputStream testCa)
      throws Exception {
    if (testCa == null) {
      return (SSLSocketFactory) SSLSocketFactory.getDefault();
    }

    SSLContext context = SSLContext.getInstance("TLS");
    context.init(null, getTrustManagers(testCa), null);
    return context.getSocketFactory();
  }

  /**
   * Creates an {@link SSLCertificateSocketFactory} that advertises "h2" through ALPN or NPN.
   *
   * @param testCa if non-null, the certificate whose trust managers are installed on the factory
   * @param androidSocketFactoryTls either {@code "alpn"} or {@code "npn"}
   * @return the configured factory
   * @throws RuntimeException if the API level is below 14 or the protocol name is unknown
   * @throws Exception if the reflective setter call or the certificate loading fails
   */
  @TargetApi(14)
  private static SSLCertificateSocketFactory getSslCertificateSocketFactory(
      @Nullable InputStream testCa, String androidSocketFactoryTls) throws Exception {
    if (Build.VERSION.SDK_INT < Build.VERSION_CODES.ICE_CREAM_SANDWICH /* API level 14 */) {
      throw new RuntimeException(
          "android_socket_factory_tls doesn't work with API level less than 14.");
    }

    // Constant-first equals() so that an unexpected null is reported as an unknown protocol.
    String setterName;
    if ("alpn".equals(androidSocketFactoryTls)) {
      setterName = "setAlpnProtocols";
    } else if ("npn".equals(androidSocketFactoryTls)) {
      setterName = "setNpnProtocols";
    } else {
      throw new RuntimeException("Unknown protocol: " + androidSocketFactoryTls);
    }

    SSLCertificateSocketFactory factory = (SSLCertificateSocketFactory)
        SSLCertificateSocketFactory.getDefault(SOCKET_FACTORY_TIMEOUT_MS);

    // Use HTTP/2.0. The charset is named explicitly so the bytes never depend on the platform
    // default.
    byte[][] protocols = new byte[][] {"h2".getBytes("US-ASCII")};

    // The setter is looked up reflectively by name. The argument is wrapped in an Object[] so
    // the byte[][] is passed as one parameter rather than being spread as varargs.
    Method setProtocols = factory.getClass().getDeclaredMethod(setterName, byte[][].class);
    setProtocols.invoke(factory, new Object[] {protocols});

    if (testCa != null) {
      factory.setTrustManagers(getTrustManagers(testCa));
    }

    return factory;
  }

  /**
   * Builds trust managers that trust only the certificate read from {@code testCa}.
   *
   * <p>The stream is read but not closed here.
   *
   * @param testCa stream holding an X.509 certificate
   * @return the trust managers backed by a fresh key store containing that certificate
   * @throws Exception if the certificate cannot be parsed or the key store or trust manager
   *     factory cannot be set up
   */
  private static TrustManager[] getTrustManagers(InputStream testCa) throws Exception {
    KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
    // Passing null initializes an empty key store.
    ks.load(null);
    CertificateFactory cf = CertificateFactory.getInstance("X.509");
    X509Certificate cert = (X509Certificate) cf.generateCertificate(testCa);
    // The certificate's subject name (RFC 2253 form) serves as the key store alias.
    X500Principal principal = cert.getSubjectX500Principal();
    ks.setCertificateEntry(principal.getName("RFC2253"), cert);
    // Set up trust manager factory to use our key store.
    TrustManagerFactory trustManagerFactory =
        TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
    trustManagerFactory.init(ks);
    return trustManagerFactory.getTrustManagers();
  }
}