/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.ipc;

import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.net.InetSocketAddress;
import java.security.PrivilegedExceptionAction;
import java.util.Collection;
import java.util.Set;

import javax.security.sasl.Sasl;

import junit.framework.Assert;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.logging.impl.Log4JLogger;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.ipc.Client.ConnectionId;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.SaslInputStream;
import org.apache.hadoop.security.SaslRpcClient;
import org.apache.hadoop.security.SaslRpcServer;
import org.apache.hadoop.security.SecurityInfo;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.TestUserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.TokenInfo;
import org.apache.hadoop.security.token.TokenSelector;
import org.apache.hadoop.security.token.SecretManager.InvalidToken;
import org.apache.log4j.Level;
import org.junit.Test;

/** Unit tests for using Sasl over RPC. */
public class TestSaslRPC {
  private static final String ADDRESS = "0.0.0.0";

  public static final Log LOG = LogFactory.getLog(TestSaslRPC.class);

  /** Message carried by the exception that {@link BadTokenSecretManager} throws. */
  static final String ERROR_MESSAGE = "Token is invalid";
  /** Configuration key under which the server principal is stored. */
  static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
  /** Configuration key under which the server keytab path is stored. */
  static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab";
  static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR";
  static final String SERVER_PRINCIPAL_2 = "p2/foo@BAR";

  /** Shared configuration with authentication set to "kerberos". */
  private static Configuration conf;

  static {
    conf = new Configuration();
    conf.set(HADOOP_SECURITY_AUTHENTICATION, "kerberos");
    UserGroupInformation.setConfiguration(conf);
  }

  // Turn on full logging for the classes taking part in the SASL exchange.
  static {
    ((Log4JLogger) Client.LOG).getLogger().setLevel(Level.ALL);
    ((Log4JLogger) Server.LOG).getLogger().setLevel(Level.ALL);
    ((Log4JLogger) SaslRpcClient.LOG).getLogger().setLevel(Level.ALL);
    ((Log4JLogger) SaslRpcServer.LOG).getLogger().setLevel(Level.ALL);
    ((Log4JLogger) SaslInputStream.LOG).getLogger().setLevel(Level.ALL);
    ((Log4JLogger) SecurityUtil.LOG).getLogger().setLevel(Level.ALL);
  }

  /**
   * Token identifier used by these tests. It carries a token id and an
   * optional real user; both are serialized as {@link Text}.
   */
  public static class TestTokenIdentifier extends TokenIdentifier {
    private Text tokenid;
    private Text realUser;
    final static Text KIND_NAME = new Text("test.token");

    /** Creates an identifier with an empty token id and an empty real user. */
    public TestTokenIdentifier() {
      this(new Text(), new Text());
    }

    /**
     * Creates an identifier with the given token id and an empty real user.
     *
     * @param tokenid the token id; {@code null} is replaced by an empty Text
     */
    public TestTokenIdentifier(Text tokenid) {
      this(tokenid, new Text());
    }

    /**
     * Creates an identifier with the given token id and real user.
     *
     * @param tokenid the token id; {@code null} is replaced by an empty Text
     * @param realUser the real user; {@code null} is replaced by an empty Text
     */
    public TestTokenIdentifier(Text tokenid, Text realUser) {
      this.tokenid = tokenid == null ? new Text() : tokenid;
      this.realUser = realUser == null ? new Text() : realUser;
    }

    @Override
    public Text getKind() {
      return KIND_NAME;
    }

    /**
     * Builds the user this identifier stands for.
     *
     * @return a remote user named after the token id when the real user is
     *         empty; otherwise a proxy user named after the token id whose
     *         real user is a remote user named after {@code realUser}
     */
    @Override
    public UserGroupInformation getUser() {
      if ("".equals(realUser.toString())) {
        return UserGroupInformation.createRemoteUser(tokenid.toString());
      } else {
        UserGroupInformation realUgi =
            UserGroupInformation.createRemoteUser(realUser.toString());
        return UserGroupInformation
            .createProxyUser(tokenid.toString(), realUgi);
      }
    }

    /** Reads the token id followed by the real user. */
    @Override
    public void readFields(DataInput in) throws IOException {
      tokenid.readFields(in);
      realUser.readFields(in);
    }

    /** Writes the token id followed by the real user. */
    @Override
    public void write(DataOutput out) throws IOException {
      tokenid.write(out);
      realUser.write(out);
    }
  }

  /**
   * Secret manager whose password for an identifier is simply the
   * identifier's own bytes.
   */
  public static class TestTokenSecretManager
      extends SecretManager<TestTokenIdentifier> {
    @Override
    public byte[] createPassword(TestTokenIdentifier id) {
      return id.getBytes();
    }

    @Override
    public byte[] retrievePassword(TestTokenIdentifier id)
        throws InvalidToken {
      return id.getBytes();
    }

    @Override
    public TestTokenIdentifier createIdentifier() {
      return new TestTokenIdentifier();
    }
  }

  /**
   * Secret manager whose {@link #retrievePassword} always throws
   * {@link InvalidToken} with {@link #ERROR_MESSAGE}.
   */
  public static class BadTokenSecretManager extends TestTokenSecretManager {
    @Override
    public byte[] retrievePassword(TestTokenIdentifier id)
        throws InvalidToken {
      throw new InvalidToken(ERROR_MESSAGE);
    }
  }

  /**
   * Selects the first token whose kind is
   * {@link TestTokenIdentifier#KIND_NAME} and whose service equals the
   * requested service.
   */
  public static class TestTokenSelector
      implements TokenSelector<TestTokenIdentifier> {
    /**
     * @param service the service to match; when {@code null} nothing matches
     * @param tokens the candidate tokens
     * @return the first matching token, or {@code null} when there is none
     */
    @SuppressWarnings("unchecked")
    @Override
    public Token<TestTokenIdentifier> selectToken(Text service,
        Collection<Token<? extends TokenIdentifier>> tokens) {
      if (service == null) {
        return null;
      }
      for (Token<? extends TokenIdentifier> token : tokens) {
        if (TestTokenIdentifier.KIND_NAME.equals(token.getKind())
            && service.equals(token.getService())) {
          return (Token<TestTokenIdentifier>) token;
        }
      }
      return null;
    }
  }

  /**
   * Test protocol annotated with the Kerberos server principal key and the
   * token selector used by these tests.
   */
  @KerberosInfo(
      serverPrincipal = SERVER_PRINCIPAL_KEY)
  @TokenInfo(TestTokenSelector.class)
  public interface TestSaslProtocol extends TestRPC.TestProtocol {
    /** @return the authentication method of the caller as seen by the server */
    public AuthenticationMethod getAuthMethod() throws IOException;
  }

  /** Server-side implementation of {@link TestSaslProtocol}. */
  public static class TestSaslImpl extends TestRPC.TestImpl
      implements TestSaslProtocol {
    /** @return the authentication method of the current user */
    @Override
    public AuthenticationMethod getAuthMethod() throws IOException {
      return UserGroupInformation.getCurrentUser().getAuthenticationMethod();
    }
  }

  /**
   * {@link SecurityInfo} that hands out, programmatically, the same
   * server principal key and token selector that {@link TestSaslProtocol}
   * declares through annotations.
   */
  public static class CustomSecurityInfo extends SecurityInfo {

    @Override
    public KerberosInfo getKerberosInfo(Class<?> protocol,
        Configuration conf) {
      return new KerberosInfo() {
        @Override
        public Class<? extends Annotation> annotationType() {
          return null;
        }

        @Override
        public String serverPrincipal() {
          return SERVER_PRINCIPAL_KEY;
        }

        @Override
        public String clientPrincipal() {
          return null;
        }
      };
    }

    @Override
    public TokenInfo getTokenInfo(Class<?> protocol, Configuration conf) {
      return new TokenInfo() {
        @Override
        public Class<? extends TokenSelector<? extends TokenIdentifier>>
            value() {
          return TestTokenSelector.class;
        }

        @Override
        public Class<? extends Annotation> annotationType() {
          return null;
        }
      };
    }
  }

  /** Runs the digest RPC round trip against a server with a working secret manager. */
  @Test
  public void testDigestRpc() throws Exception {
    TestTokenSecretManager sm = new TestTokenSecretManager();
    final Server server = RPC.getServer(TestSaslProtocol.class,
        new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);

    doDigestRpc(server, sm);
  }

  /**
   * Runs the digest RPC round trip with {@link CustomSecurityInfo} registered
   * as a security info provider, and resets the providers afterwards.
   */
  @Test
  public void testDigestRpcWithoutAnnotation() throws Exception {
    TestTokenSecretManager sm = new TestTokenSecretManager();
    try {
      SecurityUtil.setSecurityInfoProviders(new CustomSecurityInfo());
      final Server server = RPC.getServer(TestSaslProtocol.class,
          new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
      doDigestRpc(server, sm);
    } finally {
      SecurityUtil.setSecurityInfoProviders(new SecurityInfo[0]);
    }
  }

  /**
   * Runs the digest RPC round trip against a server that was created without
   * a secret manager and has had security disabled.
   */
  @Test
  public void testSecureToInsecureRpc() throws Exception {
    Server server = RPC.getServer(TestSaslProtocol.class,
        new TestSaslImpl(), ADDRESS, 0, 5, true, conf, null);
    server.disableSecurity();
    TestTokenSecretManager sm = new TestTokenSecretManager();
    doDigestRpc(server, sm);
  }

  /**
   * Checks that a server using {@link BadTokenSecretManager} reports
   * {@link #ERROR_MESSAGE} to the client as a wrapped {@link InvalidToken}.
   */
  @Test
  public void testErrorMessage() throws Exception {
    BadTokenSecretManager sm = new BadTokenSecretManager();
    final Server server = RPC.getServer(TestSaslProtocol.class,
        new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);

    boolean succeeded = false;
    try {
      doDigestRpc(server, sm);
    } catch (RemoteException e) {
      LOG.info("LOGGING MESSAGE: " + e.getLocalizedMessage());
      // assertEquals reports the actual message when the check fails.
      assertEquals(ERROR_MESSAGE, e.getLocalizedMessage());
      assertTrue(e.unwrapRemoteException() instanceof InvalidToken);
      succeeded = true;
    }
    assertTrue(succeeded);
  }

  /**
   * Starts the server, adds a token for the current user that targets the
   * server address, pings the server through a fresh proxy and finally stops
   * both server and proxy.
   *
   * @param server the server to start and call; always stopped on return
   * @param sm the secret manager used to create the client token
   */
  private void doDigestRpc(Server server, TestTokenSecretManager sm)
      throws Exception {
    server.start();

    TestSaslProtocol proxy = null;
    // Everything after start() is inside the try so the server is stopped
    // even when the token setup fails.
    try {
      final UserGroupInformation current =
          UserGroupInformation.getCurrentUser();
      final InetSocketAddress addr = NetUtils.getConnectAddress(server);
      TestTokenIdentifier tokenId =
          new TestTokenIdentifier(new Text(current.getUserName()));
      Token<TestTokenIdentifier> token =
          new Token<TestTokenIdentifier>(tokenId, sm);
      SecurityUtil.setTokenService(token, addr);
      current.addToken(token);

      proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
          TestSaslProtocol.versionID, addr, conf);
      //QOP must be auth
      Assert.assertEquals("auth", SaslRpcServer.SASL_PROPS.get(Sasl.QOP));
      proxy.ping();
    } finally {
      server.stop();
      if (proxy != null) {
        RPC.stopProxy(proxy);
      }
    }
  }

  /**
   * Checks that the ping interval of a {@link ConnectionId} is the configured
   * interval when client ping is enabled and 0 when it is disabled.
   */
  @Test
  public void testPingInterval() throws Exception {
    Configuration newConf = new Configuration(conf);
    newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
    // Set on the local copy: this is the configuration handed to
    // getConnectionId, and the shared static conf must stay untouched.
    newConf.setInt(CommonConfigurationKeys.IPC_PING_INTERVAL_KEY,
        CommonConfigurationKeys.IPC_PING_INTERVAL_DEFAULT);

    // set doPing to true
    newConf.setBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, true);
    ConnectionId remoteId = ConnectionId.getConnectionId(
        new InetSocketAddress(0), TestSaslProtocol.class, null, 0, newConf);
    assertEquals(CommonConfigurationKeys.IPC_PING_INTERVAL_DEFAULT,
        remoteId.getPingInterval());

    // set doPing to false
    newConf.setBoolean(CommonConfigurationKeys.IPC_CLIENT_PING_KEY, false);
    remoteId = ConnectionId.getConnectionId(
        new InetSocketAddress(0), TestSaslProtocol.class, null, 0, newConf);
    assertEquals(0, remoteId.getPingInterval());
  }

  /**
   * Checks that a {@link ConnectionId} exposes the configured server
   * principal while security is on, and {@code null} once authentication is
   * switched to "simple".
   */
  @Test
  public void testGetRemotePrincipal() throws Exception {
    try {
      Configuration newConf = new Configuration(conf);
      newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
      ConnectionId remoteId = ConnectionId.getConnectionId(
          new InetSocketAddress(0), TestSaslProtocol.class, null, 0, newConf);
      assertEquals(SERVER_PRINCIPAL_1, remoteId.getServerPrincipal());

      // this following test needs security to be off
      newConf.set(HADOOP_SECURITY_AUTHENTICATION, "simple");
      UserGroupInformation.setConfiguration(newConf);
      remoteId = ConnectionId.getConnectionId(new InetSocketAddress(0),
          TestSaslProtocol.class, null, 0, newConf);
      assertEquals(
          "serverPrincipal should be null when security is turned off",
          null, remoteId.getServerPrincipal());
    } finally {
      // revert back to security is on
      UserGroupInformation.setConfiguration(conf);
    }
  }

  /**
   * Checks that proxies created with the same configuration share one cached
   * connection, and that changing the server principal in the configuration
   * results in a second connection.
   */
  @Test
  public void testPerConnectionConf() throws Exception {
    TestTokenSecretManager sm = new TestTokenSecretManager();
    final Server server = RPC.getServer(TestSaslProtocol.class,
        new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
    server.start();

    TestSaslProtocol proxy1 = null;
    TestSaslProtocol proxy2 = null;
    TestSaslProtocol proxy3 = null;
    // Everything after start() is inside the try so the server is stopped
    // even when the token or configuration setup fails.
    try {
      final UserGroupInformation current =
          UserGroupInformation.getCurrentUser();
      final InetSocketAddress addr = NetUtils.getConnectAddress(server);
      TestTokenIdentifier tokenId =
          new TestTokenIdentifier(new Text(current.getUserName()));
      Token<TestTokenIdentifier> token =
          new Token<TestTokenIdentifier>(tokenId, sm);
      SecurityUtil.setTokenService(token, addr);
      current.addToken(token);

      Configuration newConf = new Configuration(conf);
      newConf.set("hadoop.rpc.socket.factory.class.default", "");
      newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);

      proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
          TestSaslProtocol.versionID, addr, newConf);
      proxy1.getAuthMethod();
      Client client = WritableRpcEngine.getClient(conf);
      Set<ConnectionId> conns = client.getConnectionIds();
      assertEquals("number of connections in cache is wrong", 1,
          conns.size());

      // same conf, connection should be re-used
      proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
          TestSaslProtocol.versionID, addr, newConf);
      proxy2.getAuthMethod();
      assertEquals("number of connections in cache is wrong", 1,
          conns.size());

      // different conf, new connection should be set up
      newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2);
      proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
          TestSaslProtocol.versionID, addr, newConf);
      proxy3.getAuthMethod();
      ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]);
      assertEquals("number of connections in cache is wrong", 2,
          connsArray.length);
      String p1 = connsArray[0].getServerPrincipal();
      String p2 = connsArray[1].getServerPrincipal();
      assertFalse("should have different principals", p1.equals(p2));
      assertTrue("principal not as expected", p1.equals(SERVER_PRINCIPAL_1)
          || p1.equals(SERVER_PRINCIPAL_2));
      assertTrue("principal not as expected", p2.equals(SERVER_PRINCIPAL_1)
          || p2.equals(SERVER_PRINCIPAL_2));
    } finally {
      server.stop();
      // A proxy is still null if the test failed before creating it.
      if (proxy1 != null) {
        RPC.stopProxy(proxy1);
      }
      if (proxy2 != null) {
        RPC.stopProxy(proxy2);
      }
      if (proxy3 != null) {
        RPC.stopProxy(proxy3);
      }
    }
  }

  /**
   * Logs in from the given keytab, starts a server and pings it through a
   * proxy. Invoked from {@link #main(String[])}; not a JUnit test.
   *
   * @param principal value stored under {@link #SERVER_PRINCIPAL_KEY}
   * @param keytab value stored under {@link #SERVER_KEYTAB_KEY}
   */
  static void testKerberosRpc(String principal, String keytab)
      throws Exception {
    final Configuration newConf = new Configuration(conf);
    newConf.set(SERVER_PRINCIPAL_KEY, principal);
    newConf.set(SERVER_KEYTAB_KEY, keytab);
    SecurityUtil.login(newConf, SERVER_KEYTAB_KEY, SERVER_PRINCIPAL_KEY);
    TestUserGroupInformation.verifyLoginMetrics(1, 0);
    UserGroupInformation current = UserGroupInformation.getCurrentUser();
    System.out.println("UGI: " + current);

    Server server = RPC.getServer(TestSaslProtocol.class, new TestSaslImpl(),
        ADDRESS, 0, 5, true, newConf, null);
    TestSaslProtocol proxy = null;

    server.start();

    try {
      InetSocketAddress addr = NetUtils.getConnectAddress(server);
      proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
          TestSaslProtocol.versionID, addr, newConf);
      proxy.ping();
    } finally {
      server.stop();
      if (proxy != null) {
        RPC.stopProxy(proxy);
      }
    }
    System.out.println("Test is successful.");
  }

  /**
   * Checks that a call made with a token added to the current user is seen
   * by the server as using {@link AuthenticationMethod#TOKEN}.
   */
  @Test
  public void testDigestAuthMethod() throws Exception {
    TestTokenSecretManager sm = new TestTokenSecretManager();
    Server server = RPC.getServer(TestSaslProtocol.class,
        new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
    server.start();

    // Stop the server even when the call or the assertion fails.
    try {
      final UserGroupInformation current =
          UserGroupInformation.getCurrentUser();
      final InetSocketAddress addr = NetUtils.getConnectAddress(server);
      TestTokenIdentifier tokenId =
          new TestTokenIdentifier(new Text(current.getUserName()));
      Token<TestTokenIdentifier> token =
          new Token<TestTokenIdentifier>(tokenId, sm);
      SecurityUtil.setTokenService(token, addr);
      current.addToken(token);

      current.doAs(new PrivilegedExceptionAction<Object>() {
        @Override
        public Object run() throws IOException {
          TestSaslProtocol proxy = null;
          try {
            proxy = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
                TestSaslProtocol.versionID, addr, conf);
            Assert.assertEquals(AuthenticationMethod.TOKEN,
                proxy.getAuthMethod());
          } finally {
            if (proxy != null) {
              RPC.stopProxy(proxy);
            }
          }
          return null;
        }
      });
    } finally {
      server.stop();
    }
  }

  /**
   * Command-line entry point for the Kerberos RPC check.
   *
   * @param args exactly two arguments: the server principal and the keytab
   *        file; otherwise usage is printed and the JVM exits with -1
   */
  public static void main(String[] args) throws Exception {
    System.out.println("Testing Kerberos authentication over RPC");
    if (args.length != 2) {
      System.err
          .println("Usage: java <options> org.apache.hadoop.ipc.TestSaslRPC "
              + " <serverPrincipal> <keytabFile>");
      System.exit(-1);
    }
    String principal = args[0];
    String keytab = args[1];
    testKerberosRpc(principal, keytab);
  }

}