package com.twitter.common.security.unittest;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.security.Permission;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.Locale;
import java.util.Set;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterators;
import com.google.common.net.InetAddresses;
/**
 * A {@link SecurityManager} designed to provide a secure environment for unit tests.
 *
 * <p>Every check routed through {@link #checkPermission} is allowed. Outbound connection checks
 * ({@link #checkConnect}) are restricted to this machine: "localhost", this machine's host name,
 * or an IP literal that is loopback, the wildcard address, or bound to one of this machine's
 * network interfaces. Anything else is rejected with a {@link SecurityException}.
 */
public class UnitTestSecurityManager extends SecurityManager {
  private static final String LOCALHOST = "localhost";

  /**
   * Returns this machine's host name in lower case, or an empty string if the local host name
   * cannot be resolved.
   */
  private static String getLocalHostName() {
    try {
      // Locale.ROOT keeps the lower-casing independent of the JVM's default locale
      // (e.g. 'I' would otherwise become a dotless 'ı' under a Turkish locale).
      return InetAddress.getLocalHost().getHostName().toLowerCase(Locale.ROOT);
    } catch (UnknownHostException e) {
      return ""; // this node does not have a resolvable name
    }
  }

  /**
   * Returns every address bound to any network interface of this machine, or an empty set if the
   * interfaces cannot be enumerated.
   */
  private static Set<InetAddress> getMyAddresses() {
    Enumeration<NetworkInterface> interfaces;
    try {
      interfaces = NetworkInterface.getNetworkInterfaces();
    } catch (SocketException e) {
      return Collections.emptySet();
    }
    // On JDK 8 and earlier, getNetworkInterfaces() returns null (rather than throwing) when the
    // machine has no network interfaces; Iterators.forEnumeration(null) would throw an NPE.
    if (interfaces == null) {
      return Collections.emptySet();
    }
    return ImmutableSet.copyOf(Iterators.concat(Iterators.transform(
        Iterators.forEnumeration(interfaces),
        new Function<NetworkInterface, Iterator<InetAddress>>() {
          @Override public Iterator<InetAddress> apply(NetworkInterface iface) {
            return Iterators.forEnumeration(iface.getInetAddresses());
          }
        })));
  }

  /** Lower-cased local host name; empty if it could not be resolved. */
  private final String myName;
  /** Immutable snapshot of the addresses considered to belong to this machine. */
  private final Set<InetAddress> myAddresses;

  /**
   * Creates a manager for the current machine, using its host name and interface addresses.
   *
   * <p>To construct this class, caller needs to have NetPermission("getNetworkInformation").
   */
  public UnitTestSecurityManager() {
    this(getLocalHostName(), getMyAddresses());
  }

  /**
   * Creates a manager that treats {@code name} and {@code addresses} as this machine.
   *
   * @param name host name accepted (case-insensitively) in addition to "localhost"
   * @param addresses addresses accepted in addition to loopback and wildcard addresses; copied
   *     so later changes to the caller's set do not affect this manager
   */
  @VisibleForTesting
  UnitTestSecurityManager(String name, Set<InetAddress> addresses) {
    myName = name;
    myAddresses = ImmutableSet.copyOf(addresses);
  }

  /**
   * Allows the connection only if {@code host} refers to this machine.
   *
   * <p>Per the {@link SecurityManager} contract a port of -1 denotes a host-name lookup, so
   * resolving non-local host names is blocked as well.
   *
   * @throws SecurityException if {@code host} is not this machine
   */
  @Override
  public void checkConnect(String host, int port) {
    validateHost(host);
  }

  /**
   * Same as {@link #checkConnect(String, int)}; {@code context} is ignored.
   *
   * @throws SecurityException if {@code host} is not this machine
   */
  @Override
  public void checkConnect(String host, int port, Object context) {
    validateHost(host);
  }

  /** No-op: permits any action checked through this method. */
  @Override
  public void checkPermission(Permission perm) {
    // no-op; permit any action
  }

  /** No-op: permits any action checked through this method. */
  @Override
  public void checkPermission(Permission perm, Object context) {
    // no-op; permit any action
  }

  /**
   * Check if:
   * <ul>
   * <li>host is "localhost" or this machine name (case-insensitive), or
   * <li>host is valid IP address (not hostname), and
   * <ul>
   * <li>loopback address, or
   * <li>the wildcard ("any local") address, or
   * <li>one of the addresses assigned to this host.
   * </ul>
   * </ul>
   *
   * @throws SecurityException if none of the above holds
   */
  private void validateHost(String host) {
    if (LOCALHOST.equalsIgnoreCase(host) || myName.equalsIgnoreCase(host)) {
      return;
    }
    // Host names were handled above; only IP literals can still be accepted.
    if (InetAddresses.isInetAddress(host)) {
      InetAddress addr = InetAddresses.forString(host);
      if (addr.isAnyLocalAddress()
          || addr.isLoopbackAddress()
          || myAddresses.contains(addr)) {
        return;
      }
    }
    throw new SecurityException(String.format("Connecting to %s is blocked by %s.",
        host, getClass().getSimpleName()));
  }
}