/*
* Copyright (C) 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.internal.tls;
import java.io.IOException;
import java.net.SocketException;
import java.security.GeneralSecurityException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.security.auth.x500.X500Principal;
import okhttp3.Call;
import okhttp3.DelegatingSSLSocketFactory;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import static okhttp3.TestUtil.defaultClient;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
/**
 * Tests TLS client authentication between an OkHttp client and a {@link MockWebServer} whose
 * SSL sockets are configured to need, want, or not request a client certificate.
 */
public final class ClientAuthTest {
  @Rule public final MockWebServer server = new MockWebServer();

  /** How the server's SSL sockets are configured to request a client certificate. */
  public enum ClientAuth {
    NONE, WANTS, NEEDS
  }

  // Server chain: serverRootCa -> serverIntermediateCa -> serverCert.
  private HeldCertificate serverRootCa;
  private HeldCertificate serverIntermediateCa;
  private HeldCertificate serverCert;
  // Client chain: clientRootCa -> clientIntermediateCa -> clientCert.
  private HeldCertificate clientRootCa;
  private HeldCertificate clientIntermediateCa;
  private HeldCertificate clientCert;

  @Before
  public void setUp() throws GeneralSecurityException {
    serverRootCa = new HeldCertificate.Builder()
        .serialNumber("1")
        .ca(3)
        .commonName("root")
        .build();
    serverIntermediateCa = new HeldCertificate.Builder()
        .issuedBy(serverRootCa)
        .ca(2)
        .serialNumber("2")
        .commonName("intermediate_ca")
        .build();
    serverCert = new HeldCertificate.Builder()
        .issuedBy(serverIntermediateCa)
        .serialNumber("3")
        .commonName(server.getHostName())
        .build();

    clientRootCa = new HeldCertificate.Builder()
        .serialNumber("1")
        .ca(13)
        .commonName("root")
        .build();
    // Issued by the client root (not the server root) so that the server's trust in
    // clientRootCa, configured in buildServerSslSocketFactory, is actually exercised.
    clientIntermediateCa = new HeldCertificate.Builder()
        .issuedBy(clientRootCa)
        .ca(12)
        .serialNumber("2")
        .commonName("intermediate_ca")
        .build();
    clientCert = new HeldCertificate.Builder()
        .issuedBy(clientIntermediateCa)
        .serialNumber("4")
        .commonName("Jethro Willis")
        .build();
  }

  @Test public void clientAuthForWants() throws Exception {
    OkHttpClient client = buildClient(clientCert, clientIntermediateCa);

    SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.WANTS);
    server.useHttps(socketFactory, false);
    server.enqueue(new MockResponse().setBody("abc"));

    Call call = client.newCall(new Request.Builder().url(server.url("/")).build());
    Response response = call.execute();
    assertEquals(new X500Principal("CN=localhost"), response.handshake().peerPrincipal());
    assertEquals(new X500Principal("CN=Jethro Willis"), response.handshake().localPrincipal());
    assertEquals("abc", response.body().string());
  }

  @Test public void clientAuthForNeeds() throws Exception {
    OkHttpClient client = buildClient(clientCert, clientIntermediateCa);

    SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.NEEDS);
    server.useHttps(socketFactory, false);
    server.enqueue(new MockResponse().setBody("abc"));

    Call call = client.newCall(new Request.Builder().url(server.url("/")).build());
    Response response = call.execute();
    assertEquals(new X500Principal("CN=localhost"), response.handshake().peerPrincipal());
    assertEquals(new X500Principal("CN=Jethro Willis"), response.handshake().localPrincipal());
    assertEquals("abc", response.body().string());
  }

  @Test public void clientAuthSkippedForNone() throws Exception {
    OkHttpClient client = buildClient(clientCert, clientIntermediateCa);

    SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.NONE);
    server.useHttps(socketFactory, false);
    server.enqueue(new MockResponse().setBody("abc"));

    Call call = client.newCall(new Request.Builder().url(server.url("/")).build());
    Response response = call.execute();
    assertEquals(new X500Principal("CN=localhost"), response.handshake().peerPrincipal());
    // The server never asked for a certificate, so the client presented none.
    assertEquals(null, response.handshake().localPrincipal());
    assertEquals("abc", response.body().string());
  }

  @Test public void missingClientAuthSkippedForWantsOnly() throws Exception {
    OkHttpClient client = buildClient(null, clientIntermediateCa);

    SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.WANTS);
    server.useHttps(socketFactory, false);
    server.enqueue(new MockResponse().setBody("abc"));

    Call call = client.newCall(new Request.Builder().url(server.url("/")).build());
    Response response = call.execute();
    assertEquals(new X500Principal("CN=localhost"), response.handshake().peerPrincipal());
    assertEquals(null, response.handshake().localPrincipal());
    assertEquals("abc", response.body().string());
  }

  @Test public void missingClientAuthFailsForNeeds() throws Exception {
    OkHttpClient client = buildClient(null, clientIntermediateCa);

    SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.NEEDS);
    server.useHttps(socketFactory, false);

    Call call = client.newCall(new Request.Builder().url(server.url("/")).build());

    try {
      call.execute();
      fail("expected the handshake to fail without a client certificate");
    } catch (SSLHandshakeException expected) {
      // Expected: the server requires a client certificate and none was sent.
    } catch (SocketException expected) {
      // JDK 9 may surface the rejected handshake as a SocketException instead.
    }
  }

  @Test public void invalidClientAuthFails() throws Exception {
    // Self-signed certificate that chains to no root the server trusts.
    HeldCertificate clientCert2 = new HeldCertificate.Builder()
        .serialNumber("4")
        .commonName("Jethro Willis")
        .build();

    OkHttpClient client = buildClient(clientCert2);

    SSLSocketFactory socketFactory = buildServerSslSocketFactory(ClientAuth.NEEDS);
    server.useHttps(socketFactory, false);

    Call call = client.newCall(new Request.Builder().url(server.url("/")).build());

    try {
      call.execute();
      fail("expected the handshake to fail with an untrusted client certificate");
    } catch (SSLHandshakeException expected) {
      // Expected: the server rejects the untrusted client certificate.
    } catch (SocketException expected) {
      // JDK 9 may surface the rejected handshake as a SocketException instead.
    }
  }

  /**
   * Returns a client that trusts {@code serverRootCa}.
   *
   * @param cert the client certificate to present, or null to present none
   * @param chain intermediate certificates sent along with {@code cert}; ignored when
   *     {@code cert} is null
   */
  public OkHttpClient buildClient(HeldCertificate cert, HeldCertificate... chain) {
    SslClient.Builder sslClientBuilder = new SslClient.Builder()
        .addTrustedCertificate(serverRootCa.certificate);

    if (cert != null) {
      sslClientBuilder.certificateChain(cert, chain);
    }

    SslClient sslClient = sslClientBuilder.build();
    return defaultClient().newBuilder()
        .sslSocketFactory(sslClient.socketFactory, sslClient.trustManager)
        .build();
  }

  /**
   * Returns a server socket factory that presents {@code serverCert} (with
   * {@code serverIntermediateCa}), trusts both the server and client roots, and configures each
   * socket to need or want client authentication according to {@code clientAuth}.
   * {@link ClientAuth#NONE} leaves the socket's client-auth settings untouched.
   */
  public SSLSocketFactory buildServerSslSocketFactory(final ClientAuth clientAuth) {
    SslClient serverSslClient = new SslClient.Builder()
        .addTrustedCertificate(serverRootCa.certificate)
        .addTrustedCertificate(clientRootCa.certificate)
        .certificateChain(serverCert, serverIntermediateCa)
        .build();

    return new DelegatingSSLSocketFactory(serverSslClient.socketFactory) {
      @Override protected SSLSocket configureSocket(SSLSocket sslSocket) throws IOException {
        if (clientAuth == ClientAuth.NEEDS) {
          sslSocket.setNeedClientAuth(true);
        } else if (clientAuth == ClientAuth.WANTS) {
          sslSocket.setWantClientAuth(true);
        }

        return super.configureSocket(sslSocket);
      }
    };
  }
}