/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.query.netty;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
import org.apache.flink.runtime.query.KvStateID;
import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.query.KvStateServerAddress;
import org.apache.flink.runtime.query.netty.message.KvStateRequest;
import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.AbstractStateBackend;
import org.apache.flink.runtime.state.internal.InternalKvState;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.util.NetUtils;
import org.junit.AfterClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.Await;
import scala.concurrent.Future;
import scala.concurrent.duration.Deadline;
import scala.concurrent.duration.FiniteDuration;
import java.net.ConnectException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
public class KvStateClientTest {
/** Logger of this test; used to report progress and retries of the integration test. */
private static final Logger LOG = LoggerFactory.getLogger(KvStateClientTest.class);

/**
 * Event loop group handed to the server bootstrap of the mock servers created
 * by {@code createServerChannel}. It is shared between all tests of this class
 * and shut down after the last one.
 */
private static final NioEventLoopGroup NIO_GROUP = new NioEventLoopGroup();

/** Time budget from which every mock-server test derives its deadline. */
private static final FiniteDuration TEST_TIMEOUT = new FiniteDuration(100, TimeUnit.SECONDS);
/**
 * Initiates the shutdown of the event loop group shared by the mock servers
 * once all tests of this class have run.
 *
 * @throws Exception declared for the JUnit lifecycle signature
 */
@AfterClass
public static void tearDown() throws Exception {
    // NIO_GROUP is a static final field that is initialized when the class is
    // loaded, so it can never be null here and needs no guard.
    NIO_GROUP.shutdownGracefully();
}
/**
 * Fires a batch of queries at a mock server that answers every request with
 * an even index with a result and every other request with a failure, and
 * verifies both the returned futures and the reported statistics.
 */
@Test
public void testSimpleRequests() throws Exception {
    final Deadline deadline = TEST_TIMEOUT.fromNow();
    final AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();

    KvStateClient client = null;
    Channel serverChannel = null;

    try {
        client = new KvStateClient(1, stats);

        // Random payload that is returned for the successful queries
        final byte[] expected = new byte[1024];
        ThreadLocalRandom.current().nextBytes(expected);

        // The mock server only records the client channel and the raw
        // requests; the test thread answers them further below.
        final LinkedBlockingQueue<ByteBuf> requestQueue = new LinkedBlockingQueue<>();
        final AtomicReference<Channel> clientChannel = new AtomicReference<>();

        final ChannelHandler recordingHandler = new ChannelInboundHandlerAdapter() {
            @Override
            public void channelActive(ChannelHandlerContext ctx) throws Exception {
                clientChannel.set(ctx.channel());
            }

            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                requestQueue.add((ByteBuf) msg);
            }
        };
        serverChannel = createServerChannel(recordingHandler);

        final KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);

        // Dispatch all queries up front
        final int numQueries = 1024;
        final List<Future<byte[]>> pending = new ArrayList<>();
        for (int query = 0; query < numQueries; query++) {
            pending.add(client.getKvState(serverAddress, new KvStateID(), new byte[0]));
        }

        // Answer the requests in the order in which they arrive
        final Exception testException = new RuntimeException("Expected test Exception");
        for (int query = 0; query < numQueries; query++) {
            final ByteBuf requestBuf = requestQueue.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
            assertNotNull("Receive timed out", requestBuf);

            final Channel ch = clientChannel.get();
            assertNotNull("Channel not active", ch);

            assertEquals(KvStateRequestType.REQUEST, KvStateRequestSerializer.deserializeHeader(requestBuf));
            final KvStateRequest request = KvStateRequestSerializer.deserializeKvStateRequest(requestBuf);
            requestBuf.release();

            final ByteBuf response;
            if (query % 2 == 0) {
                response = KvStateRequestSerializer.serializeKvStateRequestResult(
                        serverChannel.alloc(),
                        request.getRequestId(),
                        expected);
            } else {
                response = KvStateRequestSerializer.serializeKvStateRequestFailure(
                        serverChannel.alloc(),
                        request.getRequestId(),
                        testException);
            }
            ch.writeAndFlush(response);
        }

        // Even queries must carry the payload, odd queries must fail
        for (int query = 0; query < numQueries; query++) {
            final Future<byte[]> future = pending.get(query);

            if (query % 2 != 0) {
                try {
                    Await.result(future, deadline.timeLeft());
                    fail("Did not throw expected Exception");
                } catch (RuntimeException ignored) {
                    // Expected
                }
                continue;
            }

            final byte[] serializedResult = Await.result(future, deadline.timeLeft());
            assertArrayEquals(expected, serializedResult);
        }

        assertEquals(numQueries, stats.getNumRequests());

        final int expectedRequests = numQueries / 2;

        // Counts can take some time to propagate
        while (deadline.hasTimeLeft()
                && !(stats.getNumSuccessful() == expectedRequests
                    && stats.getNumFailed() == expectedRequests)) {
            Thread.sleep(100);
        }

        assertEquals(expectedRequests, stats.getNumSuccessful());
        assertEquals(expectedRequests, stats.getNumFailed());
    } finally {
        if (client != null) {
            client.shutDown();
        }

        if (serverChannel != null) {
            serverChannel.close();
        }

        assertEquals("Channel leak", 0, stats.getNumConnections());
    }
}
/**
 * Verifies that a query against an address obtained from
 * {@code NetUtils.getAvailablePort()} (no server is started on it by this
 * test) fails the returned future with a {@link ConnectException}.
 */
@Test
public void testRequestUnavailableHost() throws Exception {
    final Deadline deadline = TEST_TIMEOUT.fromNow();
    final AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();

    KvStateClient client = null;

    try {
        client = new KvStateClient(1, stats);

        final int unusedPort = NetUtils.getAvailablePort();
        final InetAddress localHost = InetAddress.getLocalHost();
        final KvStateServerAddress serverAddress = new KvStateServerAddress(localHost, unusedPort);

        final Future<byte[]> future = client.getKvState(serverAddress, new KvStateID(), new byte[0]);

        boolean connectFailed = false;
        try {
            Await.result(future, deadline.timeLeft());
        } catch (ConnectException ignored) {
            // Expected
            connectFailed = true;
        }

        if (!connectFailed) {
            fail("Did not throw expected ConnectException");
        }
    } finally {
        if (client != null) {
            client.shutDown();
        }

        assertEquals("Channel leak", 0, stats.getNumConnections());
    }
}
/**
 * Several threads fire queries through the same client at the same time; a
 * mock server answers every request with the same payload.
 */
@Test
public void testConcurrentQueries() throws Exception {
    final Deadline deadline = TEST_TIMEOUT.fromNow();
    final AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();

    ExecutorService executor = null;
    KvStateClient client = null;
    Channel serverChannel = null;

    // Payload the mock server returns for every request
    final byte[] serializedResult = new byte[1024];
    ThreadLocalRandom.current().nextBytes(serializedResult);

    try {
        final int numQueryTasks = 4;
        final int numQueriesPerTask = 1024;

        executor = Executors.newFixedThreadPool(numQueryTasks);
        client = new KvStateClient(1, stats);

        // Mock server: parse the request and immediately respond with the payload
        final ChannelHandler respondingHandler = new ChannelInboundHandlerAdapter() {
            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                final ByteBuf requestBuf = (ByteBuf) msg;
                assertEquals(KvStateRequestType.REQUEST, KvStateRequestSerializer.deserializeHeader(requestBuf));
                final KvStateRequest request = KvStateRequestSerializer.deserializeKvStateRequest(requestBuf);
                requestBuf.release();

                final ByteBuf response = KvStateRequestSerializer.serializeKvStateRequestResult(
                        ctx.alloc(),
                        request.getRequestId(),
                        serializedResult);

                ctx.channel().writeAndFlush(response);
            }
        };
        serverChannel = createServerChannel(respondingHandler);

        final KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);
        final KvStateClient sharedClient = client;

        // Each task dispatches its queries and hands back the pending futures
        final Callable<List<Future<byte[]>>> queryTask = new Callable<List<Future<byte[]>>>() {
            @Override
            public List<Future<byte[]>> call() throws Exception {
                final List<Future<byte[]>> pending = new ArrayList<>(numQueriesPerTask);

                for (int query = 0; query < numQueriesPerTask; query++) {
                    final Future<byte[]> future = sharedClient.getKvState(
                            serverAddress,
                            new KvStateID(),
                            new byte[0]);
                    pending.add(future);
                }

                return pending;
            }
        };

        // Submit query tasks
        final List<java.util.concurrent.Future<List<Future<byte[]>>>> taskFutures = new ArrayList<>();
        for (int task = 0; task < numQueryTasks; task++) {
            taskFutures.add(executor.submit(queryTask));
        }

        // Verify results
        for (java.util.concurrent.Future<List<Future<byte[]>>> taskFuture : taskFutures) {
            final List<Future<byte[]>> pending =
                    taskFuture.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);

            for (Future<byte[]> future : pending) {
                final byte[] actual = Await.result(future, deadline.timeLeft());
                assertArrayEquals(serializedResult, actual);
            }
        }

        final int totalQueries = numQueryTasks * numQueriesPerTask;

        // Counts can take some time to propagate
        while (deadline.hasTimeLeft() && stats.getNumSuccessful() != totalQueries) {
            Thread.sleep(100);
        }

        assertEquals(totalQueries, stats.getNumRequests());
        assertEquals(totalQueries, stats.getNumSuccessful());
    } finally {
        if (executor != null) {
            executor.shutdown();
        }

        if (serverChannel != null) {
            serverChannel.close();
        }

        if (client != null) {
            client.shutDown();
        }

        assertEquals("Channel leak", 0, stats.getNumConnections());
    }
}
/**
 * Verifies that a server failure message fails all pending requests, closes
 * the connection and removes it from the established connections.
 */
@Test
public void testFailureClosesChannel() throws Exception {
    final Deadline deadline = TEST_TIMEOUT.fromNow();
    final AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();

    KvStateClient client = null;
    Channel serverChannel = null;

    try {
        client = new KvStateClient(1, stats);

        // The mock server records the client channel and the raw requests
        final LinkedBlockingQueue<ByteBuf> requestQueue = new LinkedBlockingQueue<>();
        final AtomicReference<Channel> clientChannel = new AtomicReference<>();

        final ChannelHandler recordingHandler = new ChannelInboundHandlerAdapter() {
            @Override
            public void channelActive(ChannelHandlerContext ctx) throws Exception {
                clientChannel.set(ctx.channel());
            }

            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                requestQueue.add((ByteBuf) msg);
            }
        };
        serverChannel = createServerChannel(recordingHandler);

        final KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);

        // Requests
        final int numPending = 2;
        final List<Future<byte[]>> pending = new ArrayList<>();
        for (int request = 0; request < numPending; request++) {
            pending.add(client.getKvState(serverAddress, new KvStateID(), new byte[0]));
        }

        // Wait until the server has seen both requests
        for (int request = 0; request < numPending; request++) {
            final ByteBuf requestBuf = requestQueue.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
            assertNotNull("Receive timed out", requestBuf);
            requestBuf.release();
        }

        assertEquals(1, stats.getNumConnections());

        final Channel ch = clientChannel.get();
        assertNotNull("Channel not active", ch);

        // Respond with failure
        final ByteBuf serverFailure = KvStateRequestSerializer.serializeServerFailure(
                serverChannel.alloc(),
                new RuntimeException("Expected test server failure"));
        ch.writeAndFlush(serverFailure);

        // Every pending request has to be failed
        for (Future<byte[]> future : pending) {
            try {
                Await.result(future, deadline.timeLeft());
                fail("Did not throw expected server failure");
            } catch (RuntimeException ignored) {
                // Expected
            }
        }

        assertEquals(0, stats.getNumConnections());

        // Counts can take some time to propagate
        while (deadline.hasTimeLeft()
                && !(stats.getNumSuccessful() == 0 && stats.getNumFailed() == 2)) {
            Thread.sleep(100);
        }

        assertEquals(2, stats.getNumRequests());
        assertEquals(0, stats.getNumSuccessful());
        assertEquals(2, stats.getNumFailed());
    } finally {
        if (client != null) {
            client.shutDown();
        }

        if (serverChannel != null) {
            serverChannel.close();
        }

        assertEquals("Channel leak", 0, stats.getNumConnections());
    }
}
/**
 * Tests that a server channel close, closes the connection and removes it
 * from the established connections.
 */
@Test
public void testServerClosesChannel() throws Exception {
    Deadline deadline = TEST_TIMEOUT.fromNow();
    AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();

    KvStateClient client = null;
    Channel serverChannel = null;

    try {
        client = new KvStateClient(1, stats);

        final AtomicBoolean received = new AtomicBoolean();
        final AtomicReference<Channel> channel = new AtomicReference<>();

        serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
            @Override
            public void channelActive(ChannelHandlerContext ctx) throws Exception {
                channel.set(ctx.channel());
            }

            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                // The request content is irrelevant for this test, but the
                // buffer is reference counted and nobody else consumes it,
                // so it has to be released here to avoid leaking it.
                ((ByteBuf) msg).release();
                received.set(true);
            }
        });

        KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);

        // Requests
        Future<byte[]> future = client.getKvState(serverAddress, new KvStateID(), new byte[0]);

        // Wait until the server has seen the request
        while (!received.get() && deadline.hasTimeLeft()) {
            Thread.sleep(50);
        }
        assertTrue("Receive timed out", received.get());

        assertEquals(1, stats.getNumConnections());

        // Close the connection from the server side
        Channel ch = channel.get();
        assertNotNull("Channel not active", ch);
        ch.close().await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);

        try {
            Await.result(future, deadline.timeLeft());
            fail("Did not throw expected server failure");
        } catch (ClosedChannelException ignored) {
            // Expected
        }

        assertEquals(0, stats.getNumConnections());

        // Counts can take some time to propagate
        while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != 0 ||
                stats.getNumFailed() != 1)) {
            Thread.sleep(100);
        }

        assertEquals(1, stats.getNumRequests());
        assertEquals(0, stats.getNumSuccessful());
        assertEquals(1, stats.getNumFailed());
    } finally {
        if (client != null) {
            client.shutDown();
        }

        if (serverChannel != null) {
            serverChannel.close();
        }

        assertEquals("Channel leak", 0, stats.getNumConnections());
    }
}
/**
 * Tests multiple clients querying multiple servers until 100k queries have
 * been processed. At this point, the client is shut down and it is verified
 * that all ongoing requests are failed.
 *
 * <p>The query tasks run in an endless loop and can therefore only terminate
 * by throwing an exception. After the client shut down, the only accepted
 * causes are {@link ClosedChannelException} and {@link IllegalStateException}.
 */
@Test
public void testClientServerIntegration() throws Exception {
    // Config
    final int numServers = 2;
    final int numServerEventLoopThreads = 2;
    final int numServerQueryThreads = 2;

    final int numClientEventLoopThreads = 4;
    final int numClientsTasks = 8;

    final int batchSize = 16;

    final int numKeyGroups = 1;

    AbstractStateBackend abstractBackend = new MemoryStateBackend();
    KvStateRegistry dummyRegistry = new KvStateRegistry();
    DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
    dummyEnv.setKvStateRegistry(dummyRegistry);

    AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
            dummyEnv,
            new JobID(),
            "test_op",
            IntSerializer.INSTANCE,
            numKeyGroups,
            new KeyGroupRange(0, 0),
            dummyRegistry.createTaskRegistry(new JobID(), new JobVertexID()));

    // Timeout for every single query issued by the query tasks
    final FiniteDuration timeout = new FiniteDuration(10, TimeUnit.SECONDS);

    AtomicKvStateRequestStats clientStats = new AtomicKvStateRequestStats();

    KvStateClient client = null;
    ExecutorService clientTaskExecutor = null;
    final KvStateServer[] server = new KvStateServer[numServers];

    try {
        client = new KvStateClient(numClientEventLoopThreads, clientStats);
        clientTaskExecutor = Executors.newFixedThreadPool(numClientsTasks);

        // Create state
        ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE);
        desc.setQueryable("any");

        // Create servers
        KvStateRegistry[] registry = new KvStateRegistry[numServers];
        AtomicKvStateRequestStats[] serverStats = new AtomicKvStateRequestStats[numServers];
        final KvStateID[] ids = new KvStateID[numServers];

        for (int i = 0; i < numServers; i++) {
            registry[i] = new KvStateRegistry();
            serverStats[i] = new AtomicKvStateRequestStats();
            server[i] = new KvStateServer(
                    InetAddress.getLocalHost(),
                    0,
                    numServerEventLoopThreads,
                    numServerQueryThreads,
                    registry[i],
                    serverStats[i]);

            server[i].start();

            // Server i serves the value (201 + i) under the key (1010 + i)
            backend.setCurrentKey(1010 + i);

            // Value per server
            ValueState<Integer> state = backend.getPartitionedState(VoidNamespace.INSTANCE,
                    VoidNamespaceSerializer.INSTANCE,
                    desc);

            state.update(201 + i);

            // we know it must be a KvState but this is not exposed to the user via State
            InternalKvState<?> kvState = (InternalKvState<?>) state;

            // Register KvState (one state instance for all server)
            ids[i] = registry[i].registerKvState(new JobID(), new JobVertexID(), new KeyGroupRange(0, 0), "any", kvState);
        }

        final KvStateClient finalClient = client;
        Callable<Void> queryTask = new Callable<Void>() {
            @Override
            public Void call() throws Exception {
                // Runs until interrupted or until a query fails
                while (true) {
                    if (Thread.interrupted()) {
                        throw new InterruptedException();
                    }

                    // Random server permutation
                    List<Integer> random = new ArrayList<>();
                    for (int j = 0; j < batchSize; j++) {
                        random.add(j);
                    }
                    Collections.shuffle(random);

                    // Dispatch queries
                    List<Future<byte[]>> futures = new ArrayList<>(batchSize);

                    for (int j = 0; j < batchSize; j++) {
                        int targetServer = random.get(j) % numServers;

                        byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
                                1010 + targetServer,
                                IntSerializer.INSTANCE,
                                VoidNamespace.INSTANCE,
                                VoidNamespaceSerializer.INSTANCE);

                        futures.add(finalClient.getKvState(
                                server[targetServer].getAddress(),
                                ids[targetServer],
                                serializedKeyAndNamespace));
                    }

                    // Verify results
                    for (int j = 0; j < batchSize; j++) {
                        int targetServer = random.get(j) % numServers;

                        Future<byte[]> future = futures.get(j);
                        byte[] buf = Await.result(future, timeout);
                        int value = KvStateRequestSerializer.deserializeValue(buf, IntSerializer.INSTANCE);
                        assertEquals(201 + targetServer, value);
                    }
                }
            }
        };

        // Submit tasks
        List<java.util.concurrent.Future<Void>> taskFutures = new ArrayList<>();
        for (int i = 0; i < numClientsTasks; i++) {
            taskFutures.add(clientTaskExecutor.submit(queryTask));
        }

        long numRequests;
        while ((numRequests = clientStats.getNumRequests()) < 100_000) {
            // The query tasks only terminate by throwing. A task that is
            // already done has therefore failed: surface its exception right
            // away instead of carrying on (or, if all tasks died, waiting
            // forever for a request count that is never reached).
            for (java.util.concurrent.Future<Void> future : taskFutures) {
                if (future.isDone()) {
                    future.get();
                    fail("Query task finished before the client was shut down");
                }
            }

            Thread.sleep(100);
            LOG.info("Number of requests {}/100_000", numRequests);
        }

        // Shut down
        client.shutDown();

        for (java.util.concurrent.Future<Void> future : taskFutures) {
            try {
                future.get();
                fail("Did not throw expected Exception after shut down");
            } catch (ExecutionException t) {
                Throwable cause = t.getCause();
                boolean expected = cause instanceof ClosedChannelException ||
                        cause instanceof IllegalStateException;

                if (!expected) {
                    LOG.error("Query task failed with unexpected exception", t);
                    // Report the type of the actual failure, not of the
                    // ExecutionException wrapper
                    Throwable reported = cause != null ? cause : t;
                    fail("Failed with unexpected Exception type: " + reported.getClass().getName());
                }
            }
        }

        assertEquals("Connection leak (client)", 0, clientStats.getNumConnections());
        for (int i = 0; i < numServers; i++) {
            // Retry the check with increasing pauses before giving up
            boolean success = false;
            int numRetries = 0;
            while (!success) {
                try {
                    assertEquals("Connection leak (server)", 0, serverStats[i].getNumConnections());
                    success = true;
                } catch (AssertionError t) {
                    if (numRetries < 10) {
                        LOG.info("Retrying connection leak check (server)");
                        Thread.sleep((numRetries + 1) * 50);
                        numRetries++;
                    } else {
                        throw t;
                    }
                }
            }
        }
    } finally {
        if (client != null) {
            client.shutDown();
        }

        for (int i = 0; i < numServers; i++) {
            if (server[i] != null) {
                server[i].shutDown();
            }
        }

        if (clientTaskExecutor != null) {
            // Interrupt query tasks that are still running; they loop
            // endlessly and would otherwise keep the pool threads alive.
            clientTaskExecutor.shutdownNow();
        }
    }
}
// ------------------------------------------------------------------------
/**
 * Starts a mock server on the local host and returns its bound server channel.
 *
 * <p>The server is bound to port 0 and runs on the shared {@code NIO_GROUP}.
 * The pipeline of every accepted connection consists of a length-field based
 * frame decoder (4 byte length field at offset 0, 4 bytes stripped) followed
 * by the given handlers.
 *
 * @param handlers handlers appended to the pipeline of every accepted connection
 * @return the server channel after the bind has completed
 * @throws UnknownHostException if the local host address cannot be determined
 * @throws InterruptedException if interrupted while waiting for the bind to complete
 */
private Channel createServerChannel(final ChannelHandler... handlers) throws UnknownHostException, InterruptedException {
    // Pipeline set up for each accepted connection
    final ChannelInitializer<SocketChannel> pipelineInitializer = new ChannelInitializer<SocketChannel>() {
        @Override
        protected void initChannel(SocketChannel ch) throws Exception {
            ch.pipeline()
                    .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
                    .addLast(handlers);
        }
    };

    final ServerBootstrap bootstrap = new ServerBootstrap();

    // Bind address and port
    bootstrap.localAddress(InetAddress.getLocalHost(), 0);

    // NIO server channels
    bootstrap.group(NIO_GROUP);
    bootstrap.channel(NioServerSocketChannel.class);

    bootstrap.childHandler(pipelineInitializer);

    return bootstrap.bind().sync().channel();
}
/**
 * Builds the {@code KvStateServerAddress} of a mock server from the local
 * socket address of its server channel.
 *
 * @param serverChannel server channel whose local address is an {@link InetSocketAddress}
 * @return address consisting of the channel's local IP address and port
 */
private KvStateServerAddress getKvStateServerAddress(Channel serverChannel) {
    final InetSocketAddress boundAddress = (InetSocketAddress) serverChannel.localAddress();
    final InetAddress host = boundAddress.getAddress();
    final int port = boundAddress.getPort();

    return new KvStateServerAddress(host, port);
}
}