/* * Copyright (c) 2008-2012, Hazel Bilisim Ltd. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.hazelcast.client; import com.hazelcast.config.Config; import com.hazelcast.config.GroupConfig; import com.hazelcast.config.SocketInterceptorConfig; import com.hazelcast.core.Hazelcast; import com.hazelcast.core.HazelcastInstance; import com.hazelcast.impl.GroupProperties; import com.hazelcast.nio.MemberSocketInterceptor; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.Socket; import java.util.Random; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @RunWith(com.hazelcast.util.RandomBlockJUnit4ClassRunner.class) public class SocketInterceptorTest { @After @Before public void cleanup() throws Exception { Hazelcast.shutdownAll(); HazelcastClient.shutdownAll(); } @Test(timeout = 120000) public void testSuccessfulSocketInterceptor() { Config config = new Config(); SocketInterceptorConfig sic = new SocketInterceptorConfig(); MySocketInterceptor mySocketInterceptor = new MySocketInterceptor(true); sic.setImplementation(mySocketInterceptor); config.getNetworkConfig().setSocketInterceptorConfig(sic); HazelcastInstance h1 = Hazelcast.newHazelcastInstance(config); HazelcastInstance h2 = Hazelcast.newHazelcastInstance(config); HazelcastInstance h3 = Hazelcast.newHazelcastInstance(config); HazelcastInstance h4 = Hazelcast.newHazelcastInstance(config); int count = 1000; for (int i = 0; i < count; i++) { h1.getMap("default").put(i, "value" + i); h2.getMap("default").put(i, "value" + i); h3.getMap("default").put(i, "value" + i); h4.getMap("default").put(i, "value" + i); } assertEquals(4, h4.getCluster().getMembers().size()); assertTrue(mySocketInterceptor.getAcceptCallCount() >= 6); assertTrue(mySocketInterceptor.getConnectCallCount() >= 6); assertEquals(4, mySocketInterceptor.getInitCallCount()); assertEquals(0, mySocketInterceptor.getAcceptFailureCount()); assertEquals(0, mySocketInterceptor.getConnectFailureCount()); ClientConfig clientConfig = new ClientConfig(); clientConfig.setGroupConfig(new GroupConfig("dev", "dev-pass")).addAddress("localhost"); MySocketInterceptor myClientSocketInterceptor = new MySocketInterceptor(true); clientConfig.setSocketInterceptor(myClientSocketInterceptor); HazelcastInstance client = HazelcastClient.newHazelcastClient(clientConfig); for (int i = 0; i < count; i++) { client.getMap("default").put(i, "value" + i); } assertTrue(mySocketInterceptor.getAcceptCallCount() >= 7); assertTrue(mySocketInterceptor.getConnectCallCount() >= 6); assertEquals(1, myClientSocketInterceptor.getConnectCallCount()); assertEquals(0, myClientSocketInterceptor.getAcceptCallCount()); assertEquals(0, mySocketInterceptor.getAcceptFailureCount()); assertEquals(0, mySocketInterceptor.getConnectFailureCount()); assertEquals(0, myClientSocketInterceptor.getAcceptFailureCount()); assertEquals(0, myClientSocketInterceptor.getConnectFailureCount()); } @Test(expected = RuntimeException.class, timeout = 120000) public void testFailingSocketInterceptor() { Config config = new Config(); config.setProperty(GroupProperties.PROP_MAX_JOIN_SECONDS, "3"); SocketInterceptorConfig sic = new SocketInterceptorConfig(); MySocketInterceptor mySocketInterceptor = new MySocketInterceptor(false); sic.setImplementation(mySocketInterceptor); config.getNetworkConfig().setSocketInterceptorConfig(sic); HazelcastInstance h1 = Hazelcast.newHazelcastInstance(config); HazelcastInstance h2 = Hazelcast.newHazelcastInstance(config); } @Test(expected = RuntimeException.class, timeout = 120000) public void testFailingClientSocketInterceptor() { Config config = new Config(); SocketInterceptorConfig sic = new SocketInterceptorConfig(); MySocketInterceptor mySocketInterceptor = new MySocketInterceptor(true); sic.setImplementation(mySocketInterceptor); config.getNetworkConfig().setSocketInterceptorConfig(sic); HazelcastInstance h1 = Hazelcast.newHazelcastInstance(config); HazelcastInstance h2 = Hazelcast.newHazelcastInstance(config); int count = 1000; for (int i = 0; i < count; i++) { h1.getMap("default").put(i, "value" + i); h2.getMap("default").put(i, "value" + i); } assertEquals(2, h2.getCluster().getMembers().size()); assertTrue(mySocketInterceptor.getAcceptCallCount() >= 1); assertTrue(mySocketInterceptor.getConnectCallCount() >= 1); assertEquals(2, mySocketInterceptor.getInitCallCount()); assertEquals(0, mySocketInterceptor.getAcceptFailureCount()); assertEquals(0, mySocketInterceptor.getConnectFailureCount()); ClientConfig clientConfig = new ClientConfig(); clientConfig.setGroupConfig(new GroupConfig("dev", "dev-pass")).addAddress("localhost"); MySocketInterceptor myClientSocketInterceptor = new MySocketInterceptor(false); clientConfig.setSocketInterceptor(myClientSocketInterceptor); HazelcastInstance client = HazelcastClient.newHazelcastClient(clientConfig); for (int i = 0; i < count; i++) { client.getMap("default").put(i, "value" + i); } assertTrue(mySocketInterceptor.getAcceptCallCount() >= 2); assertTrue(mySocketInterceptor.getConnectCallCount() >= 1); assertEquals(1, myClientSocketInterceptor.getConnectCallCount()); assertEquals(0, myClientSocketInterceptor.getAcceptCallCount()); assertEquals(1, mySocketInterceptor.getAcceptFailureCount()); assertEquals(0, myClientSocketInterceptor.getAcceptFailureCount()); assertEquals(1, myClientSocketInterceptor.getConnectFailureCount()); } public static class MySocketInterceptor implements MemberSocketInterceptor { final AtomicInteger initCallCount = new AtomicInteger(); final AtomicInteger acceptCallCount = new AtomicInteger(); final AtomicInteger connectCallCount = new AtomicInteger(); final AtomicInteger acceptFailureCount = new AtomicInteger(); final AtomicInteger connectFailureCount = new AtomicInteger(); final boolean successful; public MySocketInterceptor(boolean successful) { this.successful = successful; } public void init(SocketInterceptorConfig socketInterceptorConfig) { initCallCount.incrementAndGet(); } public void onAccept(Socket acceptedSocket) throws IOException { acceptCallCount.incrementAndGet(); try { OutputStream out = acceptedSocket.getOutputStream(); InputStream in = acceptedSocket.getInputStream(); int loop = new Random().nextInt(2) + 1; int secretValue = 1; int expected = (int) Math.pow(2, loop); for (int i = 0; i < loop; i++) { out.write(secretValue); int read = in.read(); if (read != 2 * secretValue) { throw new IOException("Authentication Failed"); } secretValue = read; } if (secretValue != expected) { throw new IOException("Authentication Failed"); } out.write(0); } catch (IOException e) { acceptFailureCount.incrementAndGet(); throw e; } } public void onConnect(Socket connectedSocket) throws IOException { connectCallCount.incrementAndGet(); try { OutputStream out = connectedSocket.getOutputStream(); InputStream in = connectedSocket.getInputStream(); int multiplyBy = (successful) ? 2 : 1; while (true) { int read = in.read(); if (read == 0) return; out.write(read * multiplyBy); } } catch (IOException e) { connectFailureCount.incrementAndGet(); throw e; } } public int getInitCallCount() { return initCallCount.get(); } public int getAcceptCallCount() { return acceptCallCount.get(); } public int getConnectCallCount() { return connectCallCount.get(); } public int getAcceptFailureCount() { return acceptFailureCount.get(); } public int getConnectFailureCount() { return connectFailureCount.get(); } @Override public String toString() { return "MySocketInterceptor{" + "initCallCount=" + initCallCount + ", acceptCallCount=" + acceptCallCount + ", connectCallCount=" + connectCallCount + ", acceptFailureCount=" + acceptFailureCount + ", connectFailureCount=" + connectFailureCount + '}'; } } }