/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.thrift.transport; import java.io.IOException; import java.util.HashMap; import java.util.Map; import javax.security.auth.callback.Callback; import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.NameCallback; import javax.security.auth.callback.PasswordCallback; import javax.security.auth.callback.UnsupportedCallbackException; import javax.security.sasl.AuthorizeCallback; import javax.security.sasl.RealmCallback; import javax.security.sasl.Sasl; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslClientFactory; import javax.security.sasl.SaslException; import javax.security.sasl.SaslServer; import javax.security.sasl.SaslServerFactory; import junit.framework.TestCase; import org.apache.thrift.TProcessor; import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.server.ServerTestBase; import org.apache.thrift.server.TServer; import org.apache.thrift.server.TSimpleServer; import org.apache.thrift.server.TServer.Args; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class TestTSaslTransports extends TestCase { private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class); private static final String HOST = "localhost"; private static final String SERVICE = "thrift-test"; private static final String PRINCIPAL = "thrift-test-principal"; private static final String PASSWORD = "super secret password"; private static final String REALM = "thrift-test-realm"; private static final String UNWRAPPED_MECHANISM = "CRAM-MD5"; private static final Map<String, String> UNWRAPPED_PROPS = null; private static final String WRAPPED_MECHANISM = "DIGEST-MD5"; private static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>(); static { WRAPPED_PROPS.put(Sasl.QOP, "auth-int"); WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM); } private static final String testMessage1 = "Hello, world! Also, four " + "score and seven years ago our fathers brought forth on this " + "continent a new nation, conceived in liberty, and dedicated to the " + "proposition that all men are created equal."; private static final String testMessage2 = "I have a dream that one day " + "this nation will rise up and live out the true meaning of its creed: " + "'We hold these truths to be self-evident, that all men are created equal.'"; private static class TestSaslCallbackHandler implements CallbackHandler { private final String password; public TestSaslCallbackHandler(String password) { this.password = password; } @Override public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { for (Callback c : callbacks) { if (c instanceof NameCallback) { ((NameCallback) c).setName(PRINCIPAL); } else if (c instanceof PasswordCallback) { ((PasswordCallback) c).setPassword(password.toCharArray()); } else if (c instanceof AuthorizeCallback) { ((AuthorizeCallback) c).setAuthorized(true); } else if (c instanceof RealmCallback) { ((RealmCallback) c).setText(REALM); } else { throw new UnsupportedCallbackException(c); } } } } private class ServerThread extends Thread { final String mechanism; final Map<String, String> props; volatile Throwable thrown; public ServerThread(String mechanism, Map<String, String> props) { this.mechanism = mechanism; this.props = props; } public void run() { try { internalRun(); } catch (Throwable t) { thrown = t; } } private void internalRun() throws Exception { TServerSocket serverSocket = new TServerSocket(ServerTestBase.PORT); try { acceptAndWrite(serverSocket); } finally { serverSocket.close(); } } private void acceptAndWrite(TServerSocket serverSocket) throws Exception { TTransport serverTransport = serverSocket.accept(); TTransport saslServerTransport = new TSaslServerTransport( mechanism, SERVICE, HOST, props, new TestSaslCallbackHandler(PASSWORD), serverTransport); saslServerTransport.open(); byte[] inBuf = new byte[testMessage1.getBytes().length]; // Deliberately read less than the full buffer to ensure // that TSaslTransport is correctly buffering reads. This // will fail for the WRAPPED test, if it doesn't work. saslServerTransport.readAll(inBuf, 0, 5); saslServerTransport.readAll(inBuf, 5, 10); saslServerTransport.readAll(inBuf, 15, inBuf.length - 15); LOGGER.debug("server got: {}", new String(inBuf)); assertEquals(new String(inBuf), testMessage1); LOGGER.debug("server writing: {}", testMessage2); saslServerTransport.write(testMessage2.getBytes()); saslServerTransport.flush(); saslServerTransport.close(); } } private void testSaslOpen(final String mechanism, final Map<String, String> props) throws Exception { ServerThread serverThread = new ServerThread(mechanism, props); serverThread.start(); try { Thread.sleep(1000); } catch (InterruptedException e) { // Ah well. } try { TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT); TTransport saslClientTransport = new TSaslClientTransport(mechanism, PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(PASSWORD), clientSocket); saslClientTransport.open(); LOGGER.debug("client writing: {}", testMessage1); saslClientTransport.write(testMessage1.getBytes()); saslClientTransport.flush(); byte[] inBuf = new byte[testMessage2.getBytes().length]; saslClientTransport.readAll(inBuf, 0, inBuf.length); LOGGER.debug("client got: {}", new String(inBuf)); assertEquals(new String(inBuf), testMessage2); TTransportException expectedException = null; try { saslClientTransport.open(); } catch (TTransportException e) { expectedException = e; } assertNotNull(expectedException); saslClientTransport.close(); } catch (Exception e) { LOGGER.warn("Exception caught", e); throw e; } finally { serverThread.interrupt(); try { serverThread.join(); } catch (InterruptedException e) { // Ah well. } assertNull(serverThread.thrown); } } public void testUnwrappedOpen() throws Exception { testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS); } public void testWrappedOpen() throws Exception { testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS); } public void testAnonymousOpen() throws Exception { testSaslOpen("ANONYMOUS", null); } /** * Test that we get the proper exceptions thrown back the server when * the client provides invalid password. */ public void testBadPassword() throws Exception { ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS); serverThread.start(); try { Thread.sleep(1000); } catch (InterruptedException e) { // Ah well. } boolean clientSidePassed = true; try { TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT); TTransport saslClientTransport = new TSaslClientTransport( UNWRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, UNWRAPPED_PROPS, new TestSaslCallbackHandler("NOT THE PASSWORD"), clientSocket); saslClientTransport.open(); clientSidePassed = false; fail("Was able to open transport with bad password"); } catch (TTransportException tte) { LOGGER.error("Exception for bad password", tte); assertNotNull(tte.getMessage()); assertTrue(tte.getMessage().contains("Invalid response")); } finally { serverThread.interrupt(); serverThread.join(); if (clientSidePassed) { assertNotNull(serverThread.thrown); assertTrue(serverThread.thrown.getMessage().contains("Invalid response")); } } } public void testWithServer() throws Exception { new TestTSaslTransportsWithServer().testIt(); } private static class TestTSaslTransportsWithServer extends ServerTestBase { private Thread serverThread; private TServer server; @Override public TTransport getClientTransport(TTransport underlyingTransport) throws Exception { return new TSaslClientTransport( WRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS, new TestSaslCallbackHandler(PASSWORD), underlyingTransport); } @Override public void startServer(final TProcessor processor, final TProtocolFactory protoFactory) throws Exception { serverThread = new Thread() { public void run() { try { // Transport TServerSocket socket = new TServerSocket(PORT); TTransportFactory factory = new TSaslServerTransport.Factory( WRAPPED_MECHANISM, SERVICE, HOST, WRAPPED_PROPS, new TestSaslCallbackHandler(PASSWORD)); server = new TSimpleServer(new Args(socket).processor(processor).transportFactory(factory).protocolFactory(protoFactory)); // Run it LOGGER.debug("Starting the server on port {}", PORT); server.serve(); } catch (Exception e) { e.printStackTrace(); fail(); } } }; serverThread.start(); Thread.sleep(1000); } @Override public void stopServer() throws Exception { server.stop(); try { serverThread.join(); } catch (InterruptedException e) {} } } /** * Implementation of SASL ANONYMOUS, used for testing client-side * initial responses. */ private static class AnonymousClient implements SaslClient { private final String username; private boolean hasProvidedInitialResponse; public AnonymousClient(String username) { this.username = username; } public String getMechanismName() { return "ANONYMOUS"; } public boolean hasInitialResponse() { return true; } public byte[] evaluateChallenge(byte[] challenge) throws SaslException { if (hasProvidedInitialResponse) { throw new SaslException("Already complete!"); } try { hasProvidedInitialResponse = true; return username.getBytes("UTF-8"); } catch (IOException e) { throw new SaslException(e.toString()); } } public boolean isComplete() { return hasProvidedInitialResponse; } public byte[] unwrap(byte[] incoming, int offset, int len) { throw new UnsupportedOperationException(); } public byte[] wrap(byte[] outgoing, int offset, int len) { throw new UnsupportedOperationException(); } public Object getNegotiatedProperty(String propName) { return null; } public void dispose() {} } private static class AnonymousServer implements SaslServer { private String user; public String getMechanismName() { return "ANONYMOUS"; } public byte[] evaluateResponse(byte[] response) throws SaslException { try { this.user = new String(response, "UTF-8"); } catch (IOException e) { throw new SaslException(e.toString()); } return null; } public boolean isComplete() { return user != null; } public String getAuthorizationID() { return user; } public byte[] unwrap(byte[] incoming, int offset, int len) { throw new UnsupportedOperationException(); } public byte[] wrap(byte[] outgoing, int offset, int len) { throw new UnsupportedOperationException(); } public Object getNegotiatedProperty(String propName) { return null; } public void dispose() {} } public static class SaslAnonymousFactory implements SaslClientFactory, SaslServerFactory { public SaslClient createSaslClient( String[] mechanisms, String authorizationId, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh) { for (String mech : mechanisms) { if ("ANONYMOUS".equals(mech)) { return new AnonymousClient(authorizationId); } } return null; } public SaslServer createSaslServer( String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh) { if ("ANONYMOUS".equals(mechanism)) { return new AnonymousServer(); } return null; } public String[] getMechanismNames(Map<String, ?> props) { return new String[] { "ANONYMOUS" }; } } static { java.security.Security.addProvider(new SaslAnonymousProvider()); } public static class SaslAnonymousProvider extends java.security.Provider { public SaslAnonymousProvider() { super("ThriftSaslAnonymous", 1.0, "Thrift Anonymous SASL provider"); put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName()); put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName()); } } }