/*
* Licensed to the Apache Software Foundation (ASF) under one or more contributor license
* agreements. See the NOTICE file distributed with this work for additional information regarding
* copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License. You may obtain a
* copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*/
package org.apache.geode.internal.net;
import static com.jayway.awaitility.Awaitility.*;
import static org.apache.geode.distributed.ConfigurationProperties.*;
import static org.apache.geode.internal.security.SecurableCommunicationChannel.*;
import static org.assertj.core.api.Assertions.*;
import com.jayway.awaitility.Awaitility;
import com.sun.tools.hat.internal.model.StackTrace;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.net.URL;
import java.util.Properties;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.geode.internal.security.SecurableCommunicationChannel;
import org.apache.geode.test.junit.categories.MembershipTest;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.contrib.java.lang.system.RestoreSystemProperties;
import org.junit.experimental.categories.Category;
import org.junit.rules.ErrorCollector;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.TestName;
import org.apache.geode.distributed.internal.DistributionConfig;
import org.apache.geode.distributed.internal.DistributionConfigImpl;
import org.apache.geode.internal.FileUtil;
import org.apache.geode.test.junit.categories.IntegrationTest;
/**
 * Integration tests for SocketCreatorFactory with SSL.
 * <p>
 * Renamed from {@code JSSESocketJUnitTest}.
 *
 * @see ClientSocketFactoryIntegrationTest
 */
@Category({IntegrationTest.class, MembershipTest.class})
public class SSLSocketIntegrationTest {

  /** Payload the client sends and the server is expected to receive unchanged. */
  private static final String MESSAGE = SSLSocketIntegrationTest.class.getName() + " Message";

  /** Written by the server thread once it has deserialized the client's message. */
  private final AtomicReference<String> messageFromClient = new AtomicReference<>();

  private DistributionConfig distributionConfig;
  private SocketCreator socketCreator;
  private InetAddress localHost;

  // The three fields below are only populated by tests that need them; tearDown() releases
  // whichever ones are non-null.
  private Thread serverThread;
  private ServerSocket serverSocket;
  private Socket clientSocket;

  /** Receives throwables that escape the server thread (see {@link MyThreadGroup}). */
  @Rule
  public ErrorCollector errorCollector = new ErrorCollector();

  @Rule
  public RestoreSystemProperties restoreSystemProperties = new RestoreSystemProperties();

  @Rule
  public TemporaryFolder temporaryFolder = new TemporaryFolder();

  @Rule
  public TestName testName = new TestName();

  /**
   * Points the {@code javax.net.ssl} key/trust store system properties at a copy of the test
   * keystore and initializes {@link SocketCreatorFactory} with a cluster-SSL-enabled
   * configuration.
   */
  @Before
  public void setUp() throws Exception {
    File keystore = findTestKeystore();
    String keystorePath = keystore.getCanonicalPath();

    System.setProperty("javax.net.ssl.trustStore", keystorePath);
    System.setProperty("javax.net.ssl.trustStorePassword", "password");
    System.setProperty("javax.net.ssl.keyStore", keystorePath);
    System.setProperty("javax.net.ssl.keyStorePassword", "password");

    Properties properties = new Properties();
    properties.setProperty(MCAST_PORT, "0");
    properties.setProperty(CLUSTER_SSL_ENABLED, "true");
    properties.setProperty(CLUSTER_SSL_REQUIRE_AUTHENTICATION, "true");
    properties.setProperty(CLUSTER_SSL_CIPHERS, "any");
    properties.setProperty(CLUSTER_SSL_PROTOCOLS, "TLSv1.2");

    this.distributionConfig = new DistributionConfigImpl(properties);
    SocketCreatorFactory.setDistributionConfig(this.distributionConfig);
    this.socketCreator = SocketCreatorFactory.getSocketCreatorForComponent(CLUSTER);

    this.localHost = InetAddress.getLocalHost();
  }

  /**
   * Releases the sockets and server thread created by a test and closes
   * {@link SocketCreatorFactory}.
   * <p>
   * The steps are chained with {@code try/finally} so that a failure while closing one socket
   * cannot skip the remaining cleanup; in particular {@code SocketCreatorFactory.close()} always
   * runs, so its static configuration is not left behind for the next test.
   */
  @After
  public void tearDown() throws Exception {
    try {
      if (this.clientSocket != null) {
        this.clientSocket.close();
      }
    } finally {
      try {
        if (this.serverSocket != null) {
          this.serverSocket.close();
        }
      } finally {
        if (this.serverThread != null && this.serverThread.isAlive()) {
          this.serverThread.interrupt();
        }
        SocketCreatorFactory.close();
      }
    }
  }

  @Test
  public void socketCreatorShouldUseSsl() throws Exception {
    assertThat(this.socketCreator.useSSL()).isTrue();
  }

  /**
   * Sends {@link #MESSAGE} from a client socket to a server socket, both created by the cluster
   * {@link SocketCreator}, and waits for the server thread to report the same string.
   */
  @Test
  public void securedSocketTransmissionShouldWork() throws Exception {
    this.serverSocket = this.socketCreator.createServerSocket(0, 0, this.localHost);
    this.serverThread = startServer(this.serverSocket);

    int serverPort = this.serverSocket.getLocalPort();
    this.clientSocket = this.socketCreator.connectForServer(this.localHost, serverPort);

    // transmit expected string from Client to Server
    ObjectOutputStream output = new ObjectOutputStream(this.clientSocket.getOutputStream());
    output.writeObject(MESSAGE);
    output.flush();

    // this is the real assertion of this test
    await().atMost(1, TimeUnit.MINUTES)
        .until(() -> assertThat(this.messageFromClient.get()).isEqualTo(MESSAGE));
  }

  /**
   * Connects an SSL client socket to a plain (non-SSL) server socket and verifies that the
   * connect attempt fails with a {@link SocketTimeoutException} whose stack trace passes through
   * {@code configureClientSSLSocket}, i.e. that the timeout happened during the handshake.
   */
  @Test
  public void configureClientSSLSocketCanTimeOut() throws Exception {
    final Semaphore serverCoordination = new Semaphore(0);

    // configure a non-SSL server socket. We will connect
    // a client SSL socket to it and demonstrate that the
    // handshake times out
    final ServerSocket plainServerSocket = new ServerSocket();
    plainServerSocket.bind(new InetSocketAddress(SocketCreator.getLocalHost(), 0));

    Thread acceptor = new Thread() {
      @Override
      public void run() {
        // tell the test thread that this thread is running
        serverCoordination.release();
        try (Socket accepted = plainServerSocket.accept()) {
          System.out.println("server thread accepted a connection");
          // hold the connection open, without ever handshaking, until the test is done
          serverCoordination.acquire();
        } catch (Exception e) {
          System.err.println("accept failed");
          e.printStackTrace();
        }
        try {
          plainServerSocket.close();
        } catch (IOException e) {
          // ignored
        }
        System.out.println("server thread is exiting");
      }
    };
    acceptor.setName("SocketCreatorJUnitTest serverSocket thread");
    acceptor.setDaemon(true);
    acceptor.start();

    // wait until the acceptor thread has started
    serverCoordination.acquire();

    final SocketCreator clientSocketCreator =
        SocketCreatorFactory.getSocketCreatorForComponent(SecurableCommunicationChannel.SERVER);
    final int serverSocketPort = plainServerSocket.getLocalPort();

    try {
      Awaitility.await("connect to server socket").atMost(30, TimeUnit.SECONDS).until(() -> {
        try {
          Socket connected = clientSocketCreator.connectForClient(
              SocketCreator.getLocalHost().getHostAddress(), serverSocketPort, 2000);
          connected.close();
          System.err.println(
              "client successfully connected to server but should not have been able to do so");
          return false;
        } catch (SocketTimeoutException e) {
          // we need to verify that this timed out in the handshake
          // code
          System.out.println("client connect attempt timed out - checking stack trace");
          StackTraceElement[] trace = e.getStackTrace();
          for (StackTraceElement element : trace) {
            if (element.getMethodName().equals("configureClientSSLSocket")) {
              System.out.println("client connect attempt timed out in the appropriate method");
              return true;
            }
          }
          // it wasn't in the configuration method so we need to try again
        } catch (IOException e) {
          // server socket may not be in accept() yet, causing a connection-refused
          // exception
        }
        return false;
      });
    } finally {
      // let the acceptor thread drop its connection and exit
      serverCoordination.release();
    }
  }

  private File findTestKeystore() throws IOException {
    return copyKeystoreResourceToFile("/ssl/trusted.keystore");
  }

  /**
   * Copies a classpath resource into a new file inside the test's temporary folder.
   *
   * @param name resource name, resolved relative to this class; the new file is named after the
   *        part of {@code name} following its last {@code '/'}
   * @return the newly created file containing the resource's content
   * @throws IOException if the file cannot be created or the copy fails
   */
  public File copyKeystoreResourceToFile(final String name) throws IOException {
    URL resource = getClass().getResource(name);
    assertThat(resource).isNotNull();

    File file = this.temporaryFolder.newFile(name.replaceFirst(".*/", ""));
    FileUtil.copy(resource, file);
    return file;
  }

  /**
   * Starts a thread that accepts a single connection on {@code serverSocket}, configures it as a
   * server-side SSL socket, reads one serialized {@code String} and stores it in
   * {@link #messageFromClient}.
   * <p>
   * Failures are rethrown as {@link Error} so that they reach
   * {@link MyThreadGroup#uncaughtException(Thread, Throwable)} and are recorded by the
   * {@link #errorCollector}. The accepted socket is always closed before the thread exits.
   *
   * @param serverSocket the listening socket to accept on
   * @return the started server thread
   */
  private Thread startServer(final ServerSocket serverSocket) throws Exception {
    final String testMethod = this.testName.getMethodName();

    Thread thread = new Thread(new MyThreadGroup(testMethod), () -> {
      Socket socket = null;
      try {
        socket = serverSocket.accept();
        SocketCreatorFactory.getSocketCreatorForComponent(CLUSTER).configureServerSSLSocket(socket);
        ObjectInputStream ois = new ObjectInputStream(socket.getInputStream());
        messageFromClient.set((String) ois.readObject());
      } catch (IOException | ClassNotFoundException e) {
        throw new Error(e);
      } finally {
        // closing the socket also closes the stream wrapped around it
        closeQuietly(socket);
      }
    }, testMethod + "-server");

    thread.start();
    return thread;
  }

  /**
   * Closes {@code socket} if it is non-null, ignoring any {@link IOException}: a failure to close
   * must not be reported as a test error once the message has already been handled.
   */
  private static void closeQuietly(final Socket socket) {
    if (socket == null) {
      return;
    }
    try {
      socket.close();
    } catch (IOException ignored) {
      // best-effort cleanup
    }
  }

  /**
   * Thread group that forwards uncaught exceptions from its threads to the enclosing test's
   * {@link #errorCollector}. It is a (non-static) inner class because it needs that instance
   * field.
   */
  private class MyThreadGroup extends ThreadGroup {

    public MyThreadGroup(final String name) {
      super(name);
    }

    @Override
    public void uncaughtException(final Thread thread, final Throwable throwable) {
      errorCollector.addError(throwable);
    }
  }
}