/* * Copyright (C) 2015 The Async HBase Authors. All rights reserved. * This file is part of Async HBase. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * - Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * - Neither the name of the StumbleUpon nor the names of its contributors * may be used to endorse or promote products derived from this software * without specific prior written permission. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. 
*/
package org.hbase.async.auth;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyMap;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.security.Principal;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.LoginException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;

import org.apache.zookeeper.server.auth.KerberosName;
import org.hbase.async.Config;
import org.hbase.async.HBaseClient;
import org.hbase.async.auth.KerberosClientAuthProvider.ClientCallbackHandler;
import org.jboss.netty.util.HashedWheelTimer;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;

/**
 * Unit tests for {@link KerberosClientAuthProvider} and its nested
 * {@link ClientCallbackHandler}.
 * <p>
 * {@link Login#getCurrentLogin()}, {@link Sasl#createSaslClient} and
 * {@link Subject#doAs} are statically mocked with PowerMock, and every
 * {@link KerberosName} construction returns a shared mock. The tests can
 * therefore exercise the provider without a real KDC or JAAS configuration.
 */
@RunWith(PowerMockRunner.class)
// Packages listed here are loaded by the system class loader instead of
// PowerMock's class loader. "com.sun.*" keeps the JDK's com.sun implementation
// classes on the same loader as the javax.xml.* APIs deferred alongside them.
@PowerMockIgnore({"javax.management.*", "javax.xml.*", "ch.qos.*",
  "org.slf4j.*", "com.sun.*", "org.xml.*"})
@PrepareForTest({ HBaseClient.class, Login.class, Subject.class, Sasl.class,
  SaslClient.class, KerberosName.class, KerberosClientAuthProvider.class })
public class TestKerberosClientAuthProvider {
  private HBaseClient client;
  private Config config;
  private Login login;
  private Subject subject;
  private SaslClient sasl_client;
  private Set<Principal> principals;
  private Principal principal;
  private KerberosName kerberos_name;

  // Captured by the Sasl.createSaslClient() stub in before() so tests can
  // verify the arguments the provider passed when creating a SaslClient.
  private String mechanism;
  private String service_name;
  private String service_hostname;
  private Map<String, String> properties;

  /**
   * Builds the mocks shared by every test.
   * <ul>
   * <li>The client's config carries a principal.</li>
   * <li>The current login returns a subject holding a single mocked
   * principal.</li>
   * <li>Any new KerberosName yields a stub reporting "Aching" / "feegle" /
   * "ephebe".</li>
   * <li>Sasl.createSaslClient() records its arguments and returns
   * {@code sasl_client}.</li>
   * <li>Subject.doAs() simply runs the supplied action.</li>
   * </ul>
   */
  @SuppressWarnings("unchecked")
  @Before
  public void before() throws Exception {
    config = new Config();
    client = mock(HBaseClient.class);
    login = mock(Login.class);
    subject = mock(Subject.class);
    sasl_client = mock(SaslClient.class);
    principal = mock(Principal.class);
    kerberos_name = mock(KerberosName.class);

    config.overrideConfig(KerberosClientAuthProvider.PRINCIPAL_KEY, "ephebe");
    when(client.getConfig()).thenReturn(config);
    when(login.getSubject()).thenReturn(subject);

    principals = new HashSet<Principal>();
    principals.add(principal);
    when(subject.getPrincipals()).thenReturn(principals);

    // Any "new KerberosName(...)" inside prepared classes returns this stub.
    PowerMockito.whenNew(KerberosName.class).withAnyArguments()
        .thenReturn(kerberos_name);
    when(kerberos_name.toString()).thenReturn("Aching");
    when(kerberos_name.getServiceName()).thenReturn("feegle");
    when(kerberos_name.getHostName()).thenReturn("ephebe");

    PowerMockito.mockStatic(Login.class);
    PowerMockito.when(Login.getCurrentLogin()).thenReturn(login);

    // Sasl.createSaslClient(mechanisms, authorizationId, protocol, serverName,
    // props, cbh): record the interesting arguments, hand back the mock.
    PowerMockito.mockStatic(Sasl.class);
    PowerMockito.when(Sasl.createSaslClient(any(String[].class), anyString(),
        anyString(), anyString(), anyMap(), any(CallbackHandler.class)))
        .thenAnswer(new Answer<SaslClient>() {
          @Override
          public SaslClient answer(final InvocationOnMock invocation)
              throws Throwable {
            mechanism = ((String[])invocation.getArguments()[0])[0];
            service_name = (String)invocation.getArguments()[2];
            service_hostname = (String)invocation.getArguments()[3];
            properties = (Map<String, String>)invocation.getArguments()[4];
            return sasl_client;
          }
        });

    // PowerMock static-void stubbing idiom: doAnswer(...).when(Class) followed
    // by a call to the static method with matchers. Here Subject.doAs() runs
    // the privileged action directly instead of switching security context.
    PowerMockito.mockStatic(Subject.class);
    PowerMockito.doAnswer(new Answer<SaslClient>() {
      @Override
      public SaslClient answer(final InvocationOnMock invocation)
          throws Throwable {
        final PrivilegedExceptionAction<SaslClient> cb =
            (PrivilegedExceptionAction<SaslClient>)invocation.getArguments()[1];
        return cb.run();
      }
    }).when(Subject.class);
    Subject.doAs(eq(subject), any(PrivilegedExceptionAction.class));
  }

  @Test
  public void ctor() throws Exception {
    final KerberosClientAuthProvider provider =
        new KerberosClientAuthProvider(client);
    // Expected to match the stubbed KerberosName#toString().
    assertEquals("Aching", provider.getClientUsername());
  }

  @Test (expected = IllegalStateException.class)
  public void ctorLoginFailure() throws Exception {
    // A LoginException during login initialization should surface as an
    // IllegalStateException from the constructor.
    PowerMockito.doThrow(new LoginException("Boo!")).when(Login.class);
    Login.initUserIfNeeded(any(Config.class), any(HashedWheelTimer.class),
        anyString(), any(ClientCallbackHandler.class));
    new KerberosClientAuthProvider(client);
  }

  @Test (expected = RuntimeException.class)
  public void ctorOtherException() throws Exception {
    PowerMockito.doThrow(new RuntimeException("Boo!")).when(Login.class);
    Login.initUserIfNeeded(any(Config.class), any(HashedWheelTimer.class),
        anyString(), any(ClientCallbackHandler.class));
    new KerberosClientAuthProvider(client);
  }

  @Test (expected = NullPointerException.class)
  public void nullClient() throws Exception {
    new KerberosClientAuthProvider(null);
  }

  @Test
  public void getAuthMethodCode() throws Exception {
    final KerberosClientAuthProvider provider =
        new KerberosClientAuthProvider(client);
    assertEquals(ClientAuthProvider.KEBEROS_CLIENT_AUTH_CODE,
        provider.getAuthMethodCode());
  }

  @Test
  public void newSaslClient() throws Exception {
    final KerberosClientAuthProvider provider =
        new KerberosClientAuthProvider(client);
    final Map<String, String> props = new HashMap<String, String>(0);
    final SaslClient new_client = provider.newSaslClient("localhost", props);
    assertSame(sasl_client, new_client);
    assertEquals("GSSAPI", mechanism);
    // Service name and host are expected to come from the KerberosName stub,
    // not from the "localhost" argument.
    assertEquals("ephebe", service_hostname);
    assertEquals("feegle", service_name);
    // The caller's property map must be passed through untouched.
    assertSame(props, properties);
  }

  @SuppressWarnings("unchecked")
  @Test (expected = IllegalStateException.class)
  public void newSaslClientFailedSubject() throws Exception {
    // Re-stub Subject.doAs() to throw instead of running the action.
    PowerMockito.doThrow(new RuntimeException("Boo!")).when(Subject.class);
    Subject.doAs(eq(subject), any(PrivilegedExceptionAction.class));
    final KerberosClientAuthProvider provider =
        new KerberosClientAuthProvider(client);
    final Map<String, String> props = new HashMap<String, String>(0);
    provider.newSaslClient("localhost", props);
  }

  @SuppressWarnings("unchecked")
  @Test (expected = IllegalStateException.class)
  public void newSaslClientFailedCreation() throws Exception {
    // Re-stub Sasl.createSaslClient() to fail with a SaslException.
    PowerMockito.mockStatic(Sasl.class);
    PowerMockito.when(Sasl.createSaslClient(any(String[].class), anyString(),
        anyString(), anyString(), anyMap(), any(CallbackHandler.class)))
        .thenThrow(new SaslException("Boo!"));
    final KerberosClientAuthProvider provider =
        new KerberosClientAuthProvider(client);
    final Map<String, String> props = new HashMap<String, String>(0);
    provider.newSaslClient("localhost", props);
  }

  @Test
  public void clientCallbackHandlerName() throws Exception {
    final Callback[] callbacks = new Callback[1];
    final NameCallback callback = new NameCallback("Enter a name", "Ogg");
    callbacks[0] = callback;
    assertNull(callback.getName());

    final ClientCallbackHandler cch = new ClientCallbackHandler(null);
    cch.handle(callbacks);
    // The handler is expected to fill in the callback's default name.
    assertEquals("Ogg", callback.getName());
  }

  @Test
  public void clientCallbackHandlerNameNullDefault() throws Exception {
    final Callback[] callbacks = new Callback[1];
    final NameCallback callback = new NameCallback("Enter a name");
    callbacks[0] = callback;
    assertNull(callback.getName());

    final ClientCallbackHandler cch = new ClientCallbackHandler(null);
    cch.handle(callbacks);
    assertNull(callback.getName());
  }

  @Test
  public void clientCallbackHandlerPassword() throws Exception {
    final Callback[] callbacks = new Callback[1];
    final PasswordCallback callback =
        new PasswordCallback("Gimme a password", false);
    callbacks[0] = callback;
    assertNull(callback.getPassword());

    final ClientCallbackHandler cch = new ClientCallbackHandler("Adora Belle");
    cch.handle(callbacks);
    assertArrayEquals("Adora Belle".toCharArray(), callback.getPassword());
  }

  @Test
  public void clientCallbackHandlerPasswordNoPassword() throws Exception {
    final Callback[] callbacks = new Callback[1];
    final PasswordCallback callback =
        new PasswordCallback("Gimme a password", false);
    callbacks[0] = callback;
    assertNull(callback.getPassword());

    final ClientCallbackHandler cch = new ClientCallbackHandler(null);
    cch.handle(callbacks);
    assertNull(callback.getPassword());
  }

  @Test
  public void clientCallbackHandlerRealm() throws Exception {
    final Callback[] callbacks = new Callback[1];
    final RealmCallback callback =
        new RealmCallback("Gimme a realm", "Buggarup");
    callbacks[0] = callback;
    assertNull(callback.getText());

    final ClientCallbackHandler cch = new ClientCallbackHandler(null);
    cch.handle(callbacks);
    // The handler is expected to fill in the callback's default realm.
    assertEquals("Buggarup", callback.getText());
  }

  @Test
  public void clientCallbackHandlerRealmNullDefault() throws Exception {
    final Callback[] callbacks = new Callback[1];
    final RealmCallback callback = new RealmCallback("Gimme a realm");
    callbacks[0] = callback;
    assertNull(callback.getText());

    final ClientCallbackHandler cch = new ClientCallbackHandler(null);
    cch.handle(callbacks);
    assertNull(callback.getText());
  }

  @Test
  public void clientCallbackHandlerAuthorize() throws Exception {
    final Callback[] callbacks = new Callback[1];
    final AuthorizeCallback callback =
        new AuthorizeCallback("Roland", "Roland");
    callbacks[0] = callback;
    assertFalse(callback.isAuthorized());
    assertNull(callback.getAuthorizedID());

    final ClientCallbackHandler cch = new ClientCallbackHandler(null);
    cch.handle(callbacks);
    // Matching authentication and authorization IDs should be authorized.
    assertTrue(callback.isAuthorized());
    assertEquals("Roland", callback.getAuthorizedID());
  }

  @Test
  public void clientCallbackHandlerAuthorizeNoMatch() throws Exception {
    final Callback[] callbacks = new Callback[1];
    final AuthorizeCallback callback =
        new AuthorizeCallback("Dean", "Ridcully");
    callbacks[0] = callback;
    assertFalse(callback.isAuthorized());
    assertNull(callback.getAuthorizedID());

    final ClientCallbackHandler cch = new ClientCallbackHandler(null);
    cch.handle(callbacks);
    // Mismatched IDs must be left unauthorized.
    assertFalse(callback.isAuthorized());
    assertNull(callback.getAuthorizedID());
  }

  @Test (expected = UnsupportedCallbackException.class)
  public void clientCallbackHandlerUnrecognized() throws Exception {
    final Callback[] callbacks = new Callback[1];
    callbacks[0] = new UnknownCallback();
    final ClientCallbackHandler cch = new ClientCallbackHandler(null);
    cch.handle(callbacks);
  }

  @Test
  public void clientCallbackHandlerEmptyCallbacks() throws Exception {
    // shouldn't ever happen, but who knows?
    final Callback[] callbacks = new Callback[0];
    final ClientCallbackHandler cch = new ClientCallbackHandler(null);
    cch.handle(callbacks);
  }

  @Test (expected = NullPointerException.class)
  public void clientCallbackHandlerNullCallbacks() throws Exception {
    // shouldn't ever happen, but who knows?
    final ClientCallbackHandler cch = new ClientCallbackHandler(null);
    cch.handle(null);
  }

  /** A callback type the handler does not recognize. */
  static class UnknownCallback implements Callback {
    // just a dummy class
  }
}