/*
* Copyright (C) 2015 The Async HBase Authors. All rights reserved.
* This file is part of Async HBase.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
* - Neither the name of the StumbleUpon nor the names of its contributors
* may be used to endorse or promote products derived from this software
* without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
package org.hbase.async;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyMap;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.PrivilegedExceptionAction;
import javax.security.auth.Subject;
import javax.security.sasl.SaslClient;
import org.hbase.async.auth.ClientAuthProvider;
import org.hbase.async.auth.KerberosClientAuthProvider;
import org.hbase.async.auth.Login;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.junit.Before;
import org.junit.runner.RunWith;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
@RunWith(PowerMockRunner.class)
@PowerMockIgnore({"javax.management.*", "javax.xml.*",
"ch.qos.*", "org.slf4j.*",
"com.sum.*", "org.xml.*"})
@PrepareForTest({ HBaseClient.class, Login.class, RegionClient.class,
SaslClient.class, KerberosClientAuthProvider.class, SecureRpcHelper.class,
Subject.class })
public class BaseTestSecureRpcHelper {
protected static byte[] unwrapped_payload =
{ 'p', 't', 'r', 'a', 'c', 'i' };
protected static byte[] wrapped_payload =
{ 0, 0, 0, 10, 0, 0, 0, 6, 'p', 't', 'r', 'a', 'c', 'i'};
protected HBaseClient client;
protected Config config;
protected RegionClient region_client;
protected SocketAddress remote_endpoint;
protected KerberosClientAuthProvider kerberos_provider;
protected SaslClient sasl_client;
@SuppressWarnings("unchecked")
@Before
public void before() throws Exception {
config = new Config();
client = mock(HBaseClient.class);
region_client = mock(RegionClient.class);
remote_endpoint = new InetSocketAddress("127.0.0.1", 50512);
kerberos_provider = mock(KerberosClientAuthProvider.class);
sasl_client = mock(SaslClient.class);
when(client.getConfig()).thenReturn(config);
PowerMockito.whenNew(KerberosClientAuthProvider.class).withAnyArguments()
.thenReturn(kerberos_provider);
when(kerberos_provider.newSaslClient(anyString(), anyMap()))
.thenReturn(sasl_client);
}
/**
* Super basic implementation of the SecureRpcHelper for unit testing
*/
protected class UTHelper extends SecureRpcHelper {
Channel chan;
ChannelBuffer buffer;
public UTHelper(final HBaseClient hbase_client, final RegionClient region_client,
final SocketAddress remote_endpoint) {
super(hbase_client, region_client, remote_endpoint);
}
@Override
public void sendHello(final Channel channel) {
chan = channel;
}
@Override
public ChannelBuffer handleResponse(ChannelBuffer buf, Channel chan) {
this.chan = chan;
buffer = buf;
return buf;
}
byte[] doProcessChallenge(final byte[] b) {
return processChallenge(b);
}
ClientAuthProvider getProvider() {
return client_auth_provider;
}
boolean useWrap() {
return use_wrap;
}
String getHostIP() {
return host_ip;
}
SaslClient getSaslClient() {
return sasl_client;
}
}
/**
* Prepends a byte array with it's length and creates a wrapped channel buffer
* @param payload The payload to wrap
* @return A channel buffer for testing
*/
protected ChannelBuffer getBuffer(final byte[] payload) {
final byte[] buf = new byte[payload.length + 4];
System.arraycopy(payload, 0, buf, 4, payload.length);
Bytes.setInt(buf, payload.length);
return ChannelBuffers.wrappedBuffer(buf);
}
/**
* Helper to unwrap a wrapped buffer, pretending the sasl client simply
* prepends the length.
* @throws Exception Exception it really shouldn't. Really.
*/
protected void setupUnwrap() throws Exception {
// TODO - figure out a way to use real wrapping. For now we just stick on
// two bytes or take em off.
when(sasl_client.unwrap(any(byte[].class), anyInt(), anyInt()))
.thenAnswer(new Answer<byte[]>() {
@Override
public byte[] answer(final InvocationOnMock invocation)
throws Throwable {
final byte[] buffer = (byte[])invocation.getArguments()[0];
final int length = (Integer)invocation.getArguments()[2];
final byte[] unwrapped = new byte[length - 4];
System.arraycopy(buffer, 4, unwrapped, 0, length - 4);
return unwrapped;
}
});
}
/**
* Helper to wrap a buffer, pretending the sasl client simply prepends the
* length.
* @throws Exception it really shouldn't. Really.
*/
protected void setupWrap() throws Exception {
when(sasl_client.wrap(any(byte[].class), anyInt(), anyInt()))
.thenAnswer(new Answer<byte[]>() {
@Override
public byte[] answer(final InvocationOnMock invocation)
throws Throwable {
final byte[] buffer = (byte[])invocation.getArguments()[0];
final int length = (Integer)invocation.getArguments()[2];
final byte[] wrapped = new byte[length + 4];
System.arraycopy(buffer, 0, wrapped, 4, length);
Bytes.setInt(wrapped, length);
return wrapped;
}
});
}
@SuppressWarnings("unchecked")
protected void setupChallenge() throws Exception {
PowerMockito.mockStatic(Subject.class);
PowerMockito.doAnswer(new Answer<byte[]>() {
@Override
public byte[] answer(final InvocationOnMock invocation) throws Throwable {
final PrivilegedExceptionAction<byte[]> cb =
(PrivilegedExceptionAction<byte[]>)invocation.getArguments()[1];
return cb.run();
}
}).when(Subject.class);
Subject.doAs(any(Subject.class), any(PrivilegedExceptionAction.class));
}
}