/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.ipc;
import static org.junit.Assert.*;
import java.io.IOException;
import java.net.InetSocketAddress;
import org.apache.commons.logging.*;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.RPC.VersionIncompatible;
import org.apache.hadoop.net.NetUtils;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
/**
 * Unit test for supporting across-version RPCs.
 *
 * <p>A single server exporting {@link TestImpl} is started once for the whole
 * class. Each test then connects as a client that declares an older, equal,
 * or newer protocol version and checks how the connection attempt and the
 * subsequent calls behave.
 */
public class TestRPCCompatibility {
  /** Address handed to the server when it is created. */
  private static final String ADDRESS = "0.0.0.0";

  /** Address clients connect to; resolved from the running server in setup. */
  private static InetSocketAddress addr;

  /** Server shared by every test in this class. */
  private static Server server;

  /** Proxy opened by the current test, stopped again in shutdownProxy. */
  private VersionedProtocol proxy;

  public static final Log LOG =
      LogFactory.getLog(TestRPCCompatibility.class);

  private static final Configuration conf = new Configuration();

  /** Version 0 of the test protocol: only {@code ping}. */
  public interface TestProtocol0 extends VersionedProtocol {
    public static final long versionID = 0L;
    void ping() throws IOException;
  }

  /** Version 1 of the test protocol: adds a string {@code echo}. */
  public interface TestProtocol1 extends TestProtocol0 {
    public static final long versionID = 1L;
    String echo(String value) throws IOException;
  }

  /** Version 2 of the test protocol: adds {@code add}. */
  public interface TestProtocol2 extends TestProtocol1 {
    public static final long versionID = 2L;
    int add(int v1, int v2);
  }

  /** Version 3 of the test protocol: adds an int {@code echo}. */
  public interface TestProtocol3 extends TestProtocol2 {
    public static final long versionID = 3L;
    int echo(int value) throws IOException;
  }

  /** Server-side implementation of version 2 of the test protocol. */
  public static class TestImpl implements TestProtocol2 {
    // Not read or written anywhere in this file; kept so that any code
    // outside this file that refers to it keeps compiling.
    int fastPingCounter = 0;

    /**
     * Reports the version this implementation speaks.
     *
     * @param protocol name of the protocol the client asked for
     * @param clientVersion version declared by the client
     * @return {@link TestProtocol2#versionID}
     * @throws RPC.VersionIncompatible if the client declares version 0
     */
    @Override
    public long getProtocolVersion(String protocol, long clientVersion)
        throws RPC.VersionIncompatible, IOException {
      // Although version 0 is compatible, it is too old,
      // so clients of that version are disallowed.
      if (clientVersion == TestProtocol0.versionID) {
        // Name the requested protocol (not this implementation class), the
        // same way Version3Client names the interface when it raises this
        // exception.
        throw new RPC.VersionIncompatible(
            protocol, clientVersion, TestProtocol2.versionID);
      }
      return TestProtocol2.versionID;
    }

    /** Returns {@code value} unchanged. */
    @Override
    public String echo(String value) {
      return value;
    }

    /** Returns the sum of {@code v1} and {@code v2}. */
    @Override
    public int add(int v1, int v2) {
      return v1 + v2;
    }

    /** Does nothing. */
    @Override
    public void ping() {
    }
  }

  /** Starts the shared server and records the address clients should use. */
  @BeforeClass
  public static void setup() throws IOException {
    // create a server with two handlers
    server = RPC.getServer(new TestImpl(), ADDRESS, 0, 2, false, conf);
    server.start();
    addr = NetUtils.getConnectAddress(server);
  }

  /** Stops the shared server if it was created. */
  @AfterClass
  public static void tearDown() throws IOException {
    if (server != null) {
      server.stop();
    }
  }

  /** Stops the proxy opened by the test that just ran, if there is one. */
  @After
  public void shutdownProxy() {
    if (proxy != null) {
      RPC.stopProxy(proxy);
      // Drop the reference so a stopped proxy is never stopped twice.
      proxy = null;
    }
  }

  /**
   * A client declaring version 0 must be refused: obtaining the proxy is
   * expected to fail with a {@link RemoteException} that carries
   * {@code RPC.VersionIncompatible} as its class name.
   */
  @Test
  public void testIncompatibleOldClient() throws Exception {
    try {
      proxy = RPC.getProxy(
          TestProtocol1.class, TestProtocol0.versionID, addr, conf);
      fail("Should not be able to connect to the server");
    } catch (RemoteException re) {
      assertEquals(RPC.VersionIncompatible.class.getName(), re.getClassName());
    }
  }

  /**
   * A client declaring version 1 is expected to get a
   * {@code RPC.VersionMismatch} that reports the server's version 2 and
   * carries a proxy the client can still use for the version 1 calls.
   */
  @Test
  public void testCompatibleOldClient() throws Exception {
    try {
      proxy = RPC.getProxy(
          TestProtocol1.class, TestProtocol1.versionID, addr, conf);
      fail("Expect to get a version mismatch exception");
    } catch (RPC.VersionMismatch e) {
      assertEquals(TestProtocol2.versionID, e.getServerVersion());
      // Keep talking to the server through the proxy the exception carries.
      proxy = e.getProxy();
    }
    TestProtocol1 proxy1 = (TestProtocol1) proxy;
    assertEquals("hello", proxy1.echo("hello"));
  }

  /** A client declaring version 2 connects and uses every version 2 call. */
  @Test
  public void testEqualVersionClient() throws Exception {
    proxy = RPC.getProxy(
        TestProtocol2.class, TestProtocol2.versionID, addr, conf);
    TestProtocol2 proxy2 = (TestProtocol2) proxy;
    assertEquals(3, proxy2.add(1, 2));
    assertEquals("hello", proxy2.echo("hello"));
    proxy2.ping();
  }

  /**
   * Client-side wrapper that speaks version 3 of the protocol but can fall
   * back to a version 2 server by routing the int {@code echo} through the
   * string {@code echo}.
   */
  private class Version3Client implements TestProtocol3 {
    /** Proxy every call is forwarded to. */
    private TestProtocol3 proxy3;

    /** Version reported by the server; stays at 3 when no mismatch occurs. */
    private long serverVersion = versionID;

    /**
     * Connects to the test server as a version 3 client.
     *
     * @throws RPC.VersionIncompatible if the server reports a mismatching
     *         version other than 2
     * @throws IOException if obtaining the proxy fails
     */
    private Version3Client() throws IOException {
      VersionedProtocol connection;
      try {
        connection = RPC.getProxy(
            TestProtocol3.class, TestProtocol3.versionID, addr, conf);
      } catch (RPC.VersionMismatch e) {
        serverVersion = e.getServerVersion();
        // Version 2 is the only older server this client can work with.
        if (serverVersion != TestProtocol2.versionID) {
          throw new RPC.VersionIncompatible(TestProtocol3.class.getName(),
              versionID, serverVersion);
        }
        connection = e.getProxy();
      }
      // Publish the proxy on the enclosing test so shutdownProxy stops it.
      TestRPCCompatibility.this.proxy = connection;
      proxy3 = (TestProtocol3) connection;
    }

    /**
     * Echoes an int, using the string echo when the server is not version 3.
     *
     * @throws NumberFormatException if the string echoed back by the server
     *         does not parse as an int
     */
    @Override
    public int echo(int value) throws IOException, NumberFormatException {
      if (serverVersion == versionID) { // same version
        return proxy3.echo(value); // use version 3 echo int
      } else { // server is version 2
        return Integer.parseInt(proxy3.echo(String.valueOf(value)));
      }
    }

    /** Forwards to the server's {@code add}. */
    @Override
    public int add(int v1, int v2) {
      return proxy3.add(v1, v2);
    }

    /** Forwards to the server's string {@code echo}. */
    @Override
    public String echo(String value) throws IOException {
      return proxy3.echo(value);
    }

    /** Forwards to the server's {@code ping}. */
    @Override
    public void ping() throws IOException {
      proxy3.ping();
    }

    /** Returns this client's own protocol version, 3. */
    @Override
    public long getProtocolVersion(String protocol, long clientVersion)
        throws VersionIncompatible, IOException {
      return versionID;
    }
  }

  /**
   * A version 3 client exercises every call through {@link Version3Client}.
   * Presumably the int echo takes the string-based fallback here, since the
   * server in this file reports version 2.
   */
  @Test
  public void testCompatibleNewClient() throws Exception {
    Version3Client client = new Version3Client();
    assertEquals(3, client.add(1, 2));
    assertEquals("hello", client.echo("hello"));
    assertEquals(3, client.echo(3));
    client.ping();
  }
}