/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hive.spark.client.rpc;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetAddress;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import javax.security.sasl.SaslException;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.util.concurrent.Future;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.hive.conf.HiveConf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.*;
public class TestRpc {
private static final Logger LOG = LoggerFactory.getLogger(TestRpc.class);
private Collection<Closeable> closeables;
private Map<String, String> emptyConfig =
ImmutableMap.of(HiveConf.ConfVars.SPARK_RPC_CHANNEL_LOG_LEVEL.varname, "DEBUG");
@Before
public void setUp() {
closeables = Lists.newArrayList();
}
@After
public void cleanUp() throws Exception {
for (Closeable c : closeables) {
IOUtils.closeQuietly(c);
}
}
private <T extends Closeable> T autoClose(T closeable) {
closeables.add(closeable);
return closeable;
}
@Test
public void testRpcDispatcher() throws Exception {
Rpc serverRpc = autoClose(Rpc.createEmbedded(new TestDispatcher()));
Rpc clientRpc = autoClose(Rpc.createEmbedded(new TestDispatcher()));
TestMessage outbound = new TestMessage("Hello World!");
Future<TestMessage> call = clientRpc.call(outbound, TestMessage.class);
LOG.debug("Transferring messages...");
transfer(serverRpc, clientRpc);
TestMessage reply = call.get(10, TimeUnit.SECONDS);
assertEquals(outbound.message, reply.message);
}
@Test
public void testClientServer() throws Exception {
RpcServer server = autoClose(new RpcServer(emptyConfig));
Rpc[] rpcs = createRpcConnection(server);
Rpc serverRpc = rpcs[0];
Rpc client = rpcs[1];
TestMessage outbound = new TestMessage("Hello World!");
Future<TestMessage> call = client.call(outbound, TestMessage.class);
TestMessage reply = call.get(10, TimeUnit.SECONDS);
assertEquals(outbound.message, reply.message);
TestMessage another = new TestMessage("Hello again!");
Future<TestMessage> anotherCall = client.call(another, TestMessage.class);
TestMessage anotherReply = anotherCall.get(10, TimeUnit.SECONDS);
assertEquals(another.message, anotherReply.message);
String errorMsg = "This is an error.";
try {
client.call(new ErrorCall(errorMsg)).get(10, TimeUnit.SECONDS);
} catch (ExecutionException ee) {
assertTrue(ee.getCause() instanceof RpcException);
assertTrue(ee.getCause().getMessage().indexOf(errorMsg) >= 0);
}
// Test from server to client too.
TestMessage serverMsg = new TestMessage("Hello from the server!");
Future<TestMessage> serverCall = serverRpc.call(serverMsg, TestMessage.class);
TestMessage serverReply = serverCall.get(10, TimeUnit.SECONDS);
assertEquals(serverMsg.message, serverReply.message);
}
@Test
public void testServerAddress() throws Exception {
String hostAddress = InetAddress.getLocalHost().getHostName();
Map<String, String> config = new HashMap<String, String>();
// Test if rpc_server_address is configured
config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_ADDRESS.varname, hostAddress);
RpcServer server1 = autoClose(new RpcServer(config));
assertTrue("Host address should match the expected one", server1.getAddress() == hostAddress);
// Test if rpc_server_address is not configured but HS2 server host is configured
config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_ADDRESS.varname, "");
config.put(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST.varname, hostAddress);
RpcServer server2 = autoClose(new RpcServer(config));
assertTrue("Host address should match the expected one", server2.getAddress() == hostAddress);
// Test if both are not configured
config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_ADDRESS.varname, "");
config.put(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST.varname, "");
RpcServer server3 = autoClose(new RpcServer(config));
assertTrue("Host address should match the expected one", server3.getAddress() == InetAddress.getLocalHost().getHostName());
}
@Test
public void testBadHello() throws Exception {
RpcServer server = autoClose(new RpcServer(emptyConfig));
Future<Rpc> serverRpcFuture = server.registerClient("client", "newClient",
new TestDispatcher());
NioEventLoopGroup eloop = new NioEventLoopGroup();
Future<Rpc> clientRpcFuture = Rpc.createClient(emptyConfig, eloop,
"localhost", server.getPort(), "client", "wrongClient", new TestDispatcher());
try {
autoClose(clientRpcFuture.get(10, TimeUnit.SECONDS));
fail("Should have failed to create client with wrong secret.");
} catch (ExecutionException ee) {
// On failure, the SASL handler will throw an exception indicating that the SASL
// negotiation failed.
assertTrue("Unexpected exception: " + ee.getCause(),
ee.getCause() instanceof SaslException);
}
serverRpcFuture.cancel(true);
}
@Test
public void testServerPort() throws Exception {
Map<String, String> config = new HashMap<String, String>();
RpcServer server0 = new RpcServer(config);
assertTrue("Empty port range should return a random valid port: " + server0.getPort(), server0.getPort() >= 0);
IOUtils.closeQuietly(server0);
config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, "49152-49222,49223,49224-49333");
RpcServer server1 = new RpcServer(config);
assertTrue("Port should be within configured port range:" + server1.getPort(), server1.getPort() >= 49152 && server1.getPort() <= 49333);
IOUtils.closeQuietly(server1);
int expectedPort = 65535;
config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, String.valueOf(expectedPort));
RpcServer server2 = new RpcServer(config);
assertTrue("Port should match configured one: " + server2.getPort(), server2.getPort() == expectedPort);
IOUtils.closeQuietly(server2);
config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, "49552-49222,49223,49224-49333");
try {
autoClose(new RpcServer(config));
assertTrue("Invalid port range should throw an exception", false); // Should not reach here
} catch(IOException e) {
assertEquals("Incorrect RPC server port configuration for HiveServer2", e.getMessage());
}
// Retry logic
expectedPort = 65535;
config.put(HiveConf.ConfVars.SPARK_RPC_SERVER_PORT.varname, String.valueOf(expectedPort) + ",21-23");
RpcServer server3 = new RpcServer(config);
assertTrue("Port should match configured one:" + server3.getPort(), server3.getPort() == expectedPort);
IOUtils.closeQuietly(server3);
}
@Test
public void testCloseListener() throws Exception {
RpcServer server = autoClose(new RpcServer(emptyConfig));
Rpc[] rpcs = createRpcConnection(server);
Rpc client = rpcs[1];
final AtomicInteger closeCount = new AtomicInteger();
client.addListener(new Rpc.Listener() {
@Override
public void rpcClosed(Rpc rpc) {
closeCount.incrementAndGet();
}
});
client.close();
client.close();
assertEquals(1, closeCount.get());
}
@Test
public void testNotDeserializableRpc() throws Exception {
RpcServer server = autoClose(new RpcServer(emptyConfig));
Rpc[] rpcs = createRpcConnection(server);
Rpc client = rpcs[1];
try {
client.call(new NotDeserializable(42)).get(10, TimeUnit.SECONDS);
} catch (ExecutionException ee) {
assertTrue(ee.getCause() instanceof RpcException);
assertTrue(ee.getCause().getMessage().indexOf("KryoException") >= 0);
}
}
@Test
public void testEncryption() throws Exception {
Map<String, String> eConf = ImmutableMap.<String,String>builder()
.putAll(emptyConfig)
.put(RpcConfiguration.RPC_SASL_OPT_PREFIX + "qop", Rpc.SASL_AUTH_CONF)
.build();
RpcServer server = autoClose(new RpcServer(eConf));
Rpc[] rpcs = createRpcConnection(server, eConf);
Rpc client = rpcs[1];
TestMessage outbound = new TestMessage("Hello World!");
Future<TestMessage> call = client.call(outbound, TestMessage.class);
TestMessage reply = call.get(10, TimeUnit.SECONDS);
assertEquals(outbound.message, reply.message);
}
@Test
public void testClientTimeout() throws Exception {
Map<String, String> conf = ImmutableMap.<String,String>builder()
.putAll(emptyConfig)
.build();
RpcServer server = autoClose(new RpcServer(conf));
String secret = server.createSecret();
try {
autoClose(server.registerClient("client", secret, new TestDispatcher(), 1L).get());
fail("Server should have timed out client.");
} catch (ExecutionException ee) {
assertTrue(ee.getCause() instanceof TimeoutException);
}
NioEventLoopGroup eloop = new NioEventLoopGroup();
Future<Rpc> clientRpcFuture = Rpc.createClient(conf, eloop,
"localhost", server.getPort(), "client", secret, new TestDispatcher());
try {
autoClose(clientRpcFuture.get());
fail("Client should have failed to connect to server.");
} catch (ExecutionException ee) {
// Error should not be a timeout.
assertFalse(ee.getCause() instanceof TimeoutException);
}
}
private void transfer(Rpc serverRpc, Rpc clientRpc) {
EmbeddedChannel client = (EmbeddedChannel) clientRpc.getChannel();
EmbeddedChannel server = (EmbeddedChannel) serverRpc.getChannel();
server.runPendingTasks();
client.runPendingTasks();
int count = 0;
while (!client.outboundMessages().isEmpty()) {
server.writeInbound(client.readOutbound());
count++;
}
server.flush();
LOG.debug("Transferred {} outbound client messages.", count);
count = 0;
while (!server.outboundMessages().isEmpty()) {
client.writeInbound(server.readOutbound());
count++;
}
client.flush();
LOG.debug("Transferred {} outbound server messages.", count);
}
/**
* Creates a client connection between the server and a client.
*
* @return two-tuple (server rpc, client rpc)
*/
private Rpc[] createRpcConnection(RpcServer server) throws Exception {
return createRpcConnection(server, emptyConfig);
}
private Rpc[] createRpcConnection(RpcServer server, Map<String, String> clientConf)
throws Exception {
String secret = server.createSecret();
Future<Rpc> serverRpcFuture = server.registerClient("client", secret, new TestDispatcher());
NioEventLoopGroup eloop = new NioEventLoopGroup();
Future<Rpc> clientRpcFuture = Rpc.createClient(clientConf, eloop,
"localhost", server.getPort(), "client", secret, new TestDispatcher());
Rpc serverRpc = autoClose(serverRpcFuture.get(10, TimeUnit.SECONDS));
Rpc clientRpc = autoClose(clientRpcFuture.get(10, TimeUnit.SECONDS));
return new Rpc[] { serverRpc, clientRpc };
}
private static class TestMessage {
final String message;
public TestMessage() {
this(null);
}
public TestMessage(String message) {
this.message = message;
}
}
private static class ErrorCall {
final String error;
public ErrorCall() {
this(null);
}
public ErrorCall(String error) {
this.error = error;
}
}
private static class NotDeserializable {
NotDeserializable(int unused) {
}
}
private static class TestDispatcher extends RpcDispatcher {
protected TestMessage handle(ChannelHandlerContext ctx, TestMessage msg) {
return msg;
}
protected void handle(ChannelHandlerContext ctx, ErrorCall msg) {
throw new IllegalArgumentException(msg.error);
}
protected void handle(ChannelHandlerContext ctx, NotDeserializable msg) {
// No op. Shouldn't actually be called, if it is, the test will fail.
}
}
}