/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.registration;
import org.apache.flink.runtime.registration.RetryingRegistrationTest.TestRegistrationSuccess;
import org.apache.flink.runtime.rpc.RpcService;
import org.apache.flink.runtime.rpc.TestingRpcService;
import org.apache.flink.util.TestLogger;
import org.junit.Test;
import org.slf4j.LoggerFactory;
import java.util.UUID;
import java.util.concurrent.Executor;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Tests for {@link RegisteredRpcConnection}, validating the successful, failure and close behavior.
 */
public class RegisteredRpcConnectionTest extends TestLogger {

	/** Interval between two checks while waiting for a registration outcome. */
	private static final long POLL_INTERVAL_MILLIS = 10L;

	@Test
	public void testSuccessfulRpcConnection() throws Exception {
		final String testRpcConnectionEndpointAddress = "<TestRpcConnectionEndpointAddress>";
		final UUID leaderId = UUID.randomUUID();
		final String connectionID = "Test RPC Connection ID";

		// an endpoint that immediately returns success
		final TestRegistrationGateway testGateway = new TestRegistrationGateway(new TestRegistrationSuccess(connectionID));
		final TestingRpcService rpcService = new TestingRpcService();

		try {
			rpcService.registerGateway(testRpcConnectionEndpointAddress, testGateway);

			final TestRpcConnection connection = new TestRpcConnection(testRpcConnectionEndpointAddress, leaderId, rpcService.getExecutor(), rpcService);
			connection.start();

			// wait (bounded) until the connection is established
			waitForRegistrationOutcome(connection);

			// validate correct invocation and result
			assertTrue(connection.isConnected());
			assertEquals(testRpcConnectionEndpointAddress, connection.getTargetAddress());
			assertEquals(leaderId, connection.getTargetLeaderId());
			assertEquals(testGateway, connection.getTargetGateway());
			assertEquals(connectionID, connection.getConnectionId());
		}
		finally {
			testGateway.stop();
			rpcService.stopService();
		}
	}

	@Test
	public void testRpcConnectionFailures() throws Exception {
		final String connectionFailureMessage = "Test RPC Connection failure";
		final String testRpcConnectionEndpointAddress = "<TestRpcConnectionEndpointAddress>";
		final UUID leaderId = UUID.randomUUID();

		final TestingRpcService rpcService = new TestingRpcService();

		try {
			// gateway whose registration call throws an exception
			final TestRegistrationGateway testGateway = mock(TestRegistrationGateway.class);
			when(testGateway.registrationCall(any(UUID.class), anyLong())).thenThrow(
				new RuntimeException(connectionFailureMessage));

			rpcService.registerGateway(testRpcConnectionEndpointAddress, testGateway);

			final TestRpcConnection connection = new TestRpcConnection(testRpcConnectionEndpointAddress, leaderId, rpcService.getExecutor(), rpcService);
			connection.start();

			// wait (bounded) until the registration failure has been reported
			waitForRegistrationOutcome(connection);

			// validate correct invocation and result
			assertFalse(connection.isConnected());
			assertEquals(testRpcConnectionEndpointAddress, connection.getTargetAddress());
			assertEquals(leaderId, connection.getTargetLeaderId());
			assertNull(connection.getTargetGateway());
			assertEquals(connectionFailureMessage, connection.getFailureMessage());
		}
		finally {
			rpcService.stopService();
		}
	}

	@Test
	public void testRpcConnectionClose() throws Exception {
		final String testRpcConnectionEndpointAddress = "<TestRpcConnectionEndpointAddress>";
		final UUID leaderId = UUID.randomUUID();
		final String connectionID = "Test RPC Connection ID";

		final TestRegistrationGateway testGateway = new TestRegistrationGateway(new TestRegistrationSuccess(connectionID));
		final TestingRpcService rpcService = new TestingRpcService();

		try {
			rpcService.registerGateway(testRpcConnectionEndpointAddress, testGateway);

			final TestRpcConnection connection = new TestRpcConnection(testRpcConnectionEndpointAddress, leaderId, rpcService.getExecutor(), rpcService);
			connection.start();

			// close the connection
			connection.close();

			// validate connection is closed
			assertEquals(testRpcConnectionEndpointAddress, connection.getTargetAddress());
			assertEquals(leaderId, connection.getTargetLeaderId());
			assertTrue(connection.isClosed());
		}
		finally {
			testGateway.stop();
			rpcService.stopService();
		}
	}

	// ------------------------------------------------------------------------
	//  utilities
	// ------------------------------------------------------------------------

	/**
	 * Polls the given connection until a registration outcome is observable, waiting at most
	 * {@code TestRetryingRegistration.MAX_TIMEOUT} milliseconds. The outcome is either: connected
	 * with the success callback run, or the failure callback run.
	 *
	 * <p>This replaces a fixed sleep of the same length. The test continues as soon as the
	 * outcome is known, and the worst-case wait is unchanged. If the deadline passes without an
	 * outcome, the method simply returns and the caller's assertions report the problem.
	 *
	 * @param connection the started connection to wait for
	 * @throws InterruptedException if the test thread is interrupted while waiting
	 */
	private static void waitForRegistrationOutcome(TestRpcConnection connection) throws InterruptedException {
		final long deadline = System.currentTimeMillis() + RetryingRegistrationTest.TestRetryingRegistration.MAX_TIMEOUT;

		while (!hasRegistrationOutcome(connection) && System.currentTimeMillis() < deadline) {
			Thread.sleep(POLL_INTERVAL_MILLIS);
		}
	}

	/**
	 * Returns whether the connection has finished registering, either successfully or with a failure.
	 */
	private static boolean hasRegistrationOutcome(TestRpcConnection connection) {
		// For success, require both observable effects the tests assert on. This avoids racing
		// between the connected state and the success callback.
		return (connection.isConnected() && connection.getConnectionId() != null)
			|| connection.getFailureMessage() != null;
	}

	// ------------------------------------------------------------------------
	//  test RegisteredRpcConnection
	// ------------------------------------------------------------------------

	/**
	 * Minimal {@link RegisteredRpcConnection} that registers via a {@code TestRetryingRegistration}
	 * and records the outcome of its registration callbacks for the tests to inspect.
	 */
	private static class TestRpcConnection extends RegisteredRpcConnection<TestRegistrationGateway, TestRegistrationSuccess> {

		private final RpcService rpcService;

		// Written by the registration callbacks, which may run on a thread other than the test
		// thread (the connection is given the rpc service's executor). These fields are read by
		// the test thread, hence volatile.
		private volatile String connectionId;

		private volatile String failureMessage;

		public TestRpcConnection(
				String targetAddress,
				UUID targetLeaderId,
				Executor executor,
				RpcService rpcService) {
			super(LoggerFactory.getLogger(RegisteredRpcConnectionTest.class), targetAddress, targetLeaderId, executor);
			this.rpcService = rpcService;
		}

		@Override
		protected RetryingRegistration<TestRegistrationGateway, TestRegistrationSuccess> generateRegistration() {
			return new RetryingRegistrationTest.TestRetryingRegistration(rpcService, getTargetAddress(), getTargetLeaderId());
		}

		@Override
		protected void onRegistrationSuccess(TestRegistrationSuccess success) {
			connectionId = success.getCorrelationId();
		}

		@Override
		protected void onRegistrationFailure(Throwable failure) {
			failureMessage = failure.getMessage();
		}

		/** Returns the correlation id of the successful registration, or {@code null} if none yet. */
		public String getConnectionId() {
			return connectionId;
		}

		/** Returns the message of the registration failure, or {@code null} if none yet. */
		public String getFailureMessage() {
			return failureMessage;
		}
	}
}