/** * Copyright 2010 Wealthfront Inc. Licensed under the Apache License, * Version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law * or agreed to in writing, software distributed under the License is * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the specific language * governing permissions and limitations under the License. */ package com.kaching.platform.testing; import static java.lang.String.format; import static java.lang.Thread.currentThread; import java.io.FileDescriptor; import java.net.InetAddress; import java.security.Permission; import java.util.List; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import com.google.common.base.Predicate; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.kaching.platform.common.logging.Log; /** * A {@link SecurityManager} to spotlight and minimize IO access while allowing * fine-grained control access to IO resoucres. * * This class was designed to draw attention to any IO (file and network) your * test suite may perform under the hood. IO not only slows down your test * suite, but unit tests that accidentally modify their environment may result * to flakey builds. * * Should a unit test need to perform IO, you may grant fine-grained permission * by annotating the container class with {@link AllowDNSResolution}, * {@link AllowExternalProcess}, {@link AllowLocalFileAccess}, * {@link AllowNetworkAccess}, or {@link AllowNetworkMulticast}. Some of these * annotations allow further refinement via parameters. * * <i>Usage.</i> To use the {@link LessIOSecurityManager}, you must set the * "java.security.manager" system property to * "com.kaching.platform.testing.LessIOSecurityManager", or your subclass. * * <i>Usage via command-line arguments.</i> You may add * "-Djava.security.manager=com.kaching.platform.testing.LessIOSecurityManager" * to your command-line invocation of the JVM to use this class as your * {@link SecurityManager}. * * <i>Usage via Ant.</i> You may declare the "java.security.manager" system * property in the "junit" element of your "build.xml" file. You <b>must</b> set * the "fork" property to ensure a new JVM, with this class as the * {@link SecurityManager} is utilized. * * <pre> * {@code * <junit fork="true"> * <sysproperty key="java.security.manager" value="com.kaching.platform.testing.LessIOSecurityManager" /> * ... * </junit> * } * </pre> * * <i>Performance.</i> Circa late 2010, the {@link LessIOSecurityManager}'s * impact on the performance of our test suite was less than 1.00%. * * @see {@link AllowDNSResolution}, {@link AllowExternalProcess}, * {@link AllowLocalFileAccess}, {@link AllowNetworkAccess}, and * {@link AllowNetworkMulticast} */ public class LessIOSecurityManager extends SecurityManager { private static final Log log = Log.getLog(LessIOSecurityManager.class); protected static final String JAVA_HOME = System.getProperty("java.home"); protected static final String PATH_SEPARATOR = System.getProperty("path.separator"); // Updated at SecurityManager init and again at every ClassLoader init. protected static final AtomicReference<List<String>> CP_PARTS = new AtomicReference<List<String>>(getClassPath()); protected static final String TMP_DIR = System.getProperty("java.io.tmpdir").replaceFirst("/$", ""); private static final Set<Class<?>> whitelistedClasses = ImmutableSet.<Class<?>>of( java.lang.ClassLoader.class, java.net.URLClassLoader.class); private static final int lowestEphemeralPort = Integer.getInteger("kawala.testing.low-ephemeral-port", 32768); private static final int highestEphemeralPort = Integer.getInteger("kawala.testing.high-ephemeral-port", 65535); private static final Set<Integer> allocatedEphemeralPorts = Sets.newSetFromMap(Maps.<Integer, Boolean>newConcurrentMap()); /** * Any subclasses that override this method <b>must</b> include any Class<?> * elements returned by {@link LessIOSecurityManager#getWhitelistedClasses()}. * The recommended pattern is: * <blockquote><pre> * {@code private final Set<Class<?>> whitelistedClasses = ImmutableSet.<Class<?>>builder() .addAll(parentWhitelistedClasses) .add(javax.crypto.Cipher.class) .add(javax.xml.xpath.XPathFactory.class) .build(); protected Set<Class<?>> getWhitelistedClasses() { return whitelistedClasses; } } </pre></blockquote> */ protected Set<Class<?>> getWhitelistedClasses() { return whitelistedClasses; } private final boolean reporting; public LessIOSecurityManager() { this(true); } protected LessIOSecurityManager(boolean reporting) { this.reporting = reporting; } private static ImmutableList<String> getClassPath() { return ImmutableList.copyOf(System.getProperty("java.class.path").split(PATH_SEPARATOR)); } // {{ Allowed only via {@link @AllowNetworkAccess}, {@link @AllowDNSResolution}, or {@link @AllowNetworkMulticast}) protected void checkDNSResolution(Class<?>[] classContext) throws CantDoItException { if (traceWithoutExplicitlyAllowedClass(classContext)) { checkClassContextPermissions(classContext, new Predicate<Class<?>>() { @Override public boolean apply(Class<?> input) { if ((input.getAnnotation(AllowDNSResolution.class) != null) || (input.getAnnotation(AllowNetworkMulticast.class) != null) || (input.getAnnotation(AllowNetworkListen.class) != null) || (input.getAnnotation(AllowNetworkAccess.class) != null)) { return true; } return false; } @Override public String toString() { return String.format("@AllowDNSResolution permission"); } }); } } protected void checkNetworkEndpoint(final String host, final int port, final String description) throws CantDoItException { Class<?>[] classContext = getClassContext(); if (port == -1) { checkDNSResolution(classContext); return; } if (traceWithoutExplicitlyAllowedClass(classContext)) { checkClassContextPermissions(classContext, new Predicate<Class<?>>() { @Override public boolean apply(Class<?> input) { AllowNetworkAccess a = input.getAnnotation(AllowNetworkAccess.class); if (a == null) { return false; } for (String endpoint : a.endpoints()) { String[] parts = endpoint.split(":"); String portAsString = Integer.toString(port); if ((parts[0].equals(host) && parts[1].equals(portAsString)) || (parts[0].equals("*") && parts[1].equals(portAsString)) || (parts[0].equals(host) && parts[1].equals("*")) || (parts[0].equals(host) && parts[1].equals("0") && allocatedEphemeralPorts.contains(port))) { return true; } } return false; } @Override public String toString() { return String.format("@AllowNetworkAccess permission for %s:%d (%s)", host, port, description); } }); } } @Override public void checkAccept(String host, int port) throws CantDoItException { checkNetworkEndpoint(host, port, "accept"); } @Override public void checkConnect(String host, int port, Object context) throws CantDoItException { checkNetworkEndpoint(host, port, "connect"); } @Override public void checkConnect(String host, int port) throws CantDoItException { checkNetworkEndpoint(host, port, "connect"); } @Override public void checkListen(final int port) throws CantDoItException { Class<?>[] classContext = getClassContext(); if (traceWithoutExplicitlyAllowedClass(classContext)) { checkClassContextPermissions(classContext, new Predicate<Class<?>>() { @Override public boolean apply(Class<?> input) { AllowNetworkListen a = input.getAnnotation(AllowNetworkListen.class); if (a == null) { return false; } for (int p : a.ports()) { if (p == 0) { // Check for access to ephemeral ports if (port >= lowestEphemeralPort && port <= highestEphemeralPort) { p = port; allocatedEphemeralPorts.add(port); } } if (p == port) { return true; } } return false; } @Override public String toString() { return String.format("@AllowNetworkListen permission for port %d", port); } }); } } @Override public void checkMulticast(InetAddress maddr) throws CantDoItException { Class<?>[] classContext = getClassContext(); if (traceWithoutExplicitlyAllowedClass(classContext)) { checkClassContextPermissions(classContext, new Predicate<Class<?>>() { @Override public boolean apply(Class<?> input) { AllowNetworkMulticast a = input .getAnnotation(AllowNetworkMulticast.class); if (a != null) { return true; } else { return false; } } @Override public String toString() { return String.format("@AllowNetworkMulticast permission"); } }); } } @Override public void checkMulticast(InetAddress maddr, byte ttl) throws CantDoItException { checkMulticast(maddr); } // }} // {{ Allowed only via {@link @AllowLocalFileAccess} protected void checkFileAccess(final String file, final String description) throws CantDoItException { Class<?>[] classContext = getClassContext(); if (traceWithoutExplicitlyAllowedClass(classContext)) { if (file.startsWith(JAVA_HOME)) { // Files in JAVA_HOME are always allowed return; } // Ant's JUnit task writes to /tmp/junitXXX if (file.startsWith("/dev/random") || file.startsWith("/dev/urandom") || file.startsWith("/tmp/junit")) { return; } /* * Although this is an expensive operation, it needs to be here, in a * suboptimal location to avoid ClassCircularityErrors that can occur when * attempting to load an anonymous class. */ for (String part : CP_PARTS.get()) { if (file.startsWith(part)) { // Files in the CLASSPATH are always allowed return; } } try { checkClassContextPermissions(classContext, new Predicate<Class<?>>() { @Override public boolean apply(Class<?> input) { AllowLocalFileAccess a = input .getAnnotation(AllowLocalFileAccess.class); if (a == null) { return false; } for (String p : a.paths()) { if ((p.equals("*")) || (p.equals(file)) || (p.contains("%TMP_DIR%") && (file.startsWith(p.replaceAll("%TMP_DIR%", TMP_DIR)))) || (p.startsWith("*") && p.endsWith("*") && file.contains(p.split("\\*")[1])) || (p.startsWith("*") && file.endsWith(p.replaceFirst("^\\*", ""))) || (p.endsWith("*") && file.startsWith(p.replaceFirst("\\*$", "")))) { return true; } } return false; } @Override public String toString() { return String.format("@AllowLocalFileAccess for %s (%s)", file, description); } }); } catch (CantDoItException e) { throw e; } } } public void checkFileDescriptorAccess(final FileDescriptor fd, final String description) throws CantDoItException { Class<?>[] classContext = getClassContext(); if (traceWithoutExplicitlyAllowedClass(classContext)) { checkClassContextPermissions(classContext, new Predicate<Class<?>>() { @Override public boolean apply(Class<?> input) { if (input.getAnnotation(AllowExternalProcess.class) != null || input.getAnnotation(AllowNetworkAccess.class) != null) { // AllowExternalProcess and AllowNetworkAccess imply @AllowLocalFileAccess({"%FD%"}), // since it's required. return true; } AllowLocalFileAccess a = input .getAnnotation(AllowLocalFileAccess.class); if (a == null) { return false; } for (String p : a.paths()) { if (p.equals("%FD%")) { return true; } } return false; } @Override public String toString() { return String.format( "@AllowLocalFileAccess for FileDescriptor(%s) (%s)", fd, description); } }); } } @Override public void checkRead(String file, Object context) { checkFileAccess(file, "read"); } @Override public void checkRead(String file) { checkRead(file, null); } @Override public void checkRead(final FileDescriptor fd) { checkFileDescriptorAccess(fd, "read"); } @Override public void checkDelete(final String file) { checkFileAccess(file, "delete"); } @Override public void checkWrite(FileDescriptor fd) { checkFileDescriptorAccess(fd, "write"); } @Override public void checkWrite(String file) { checkFileAccess(file, "write"); } // }} // {{ Allowed only via {@link @AllowExternalProcess} @Override public void checkExec(final String cmd) throws CantDoItException { Class<?>[] classContext = getClassContext(); if (traceWithoutExplicitlyAllowedClass(classContext)) { checkClassContextPermissions(classContext, new Predicate<Class<?>>() { @Override public boolean apply(Class<?> input) { AllowExternalProcess a = input .getAnnotation(AllowExternalProcess.class); if (a != null) { return true; } else { return false; } } @Override public String toString() { return String.format("@AllowExternalProcess for %s (exec)", cmd); } }); } } // }} // {{ Closely Monitored @Override public void checkExit(int status) { log.info("%s: exit(%d)", currentTest(getClassContext()), status); } @Override public void checkLink(String lib) { log.info("%s: System.loadLibrary(\"%s\")", currentTest(getClassContext()), lib); } @Override public void checkAwtEventQueueAccess() { log.info("%s: AwtEventQueue Access", currentTest(getClassContext())); } @Override public void checkPrintJobAccess() { log.info("%s: PrintJob Access", currentTest(getClassContext())); } @Override public void checkSystemClipboardAccess() { log.info("%s: SystemClipboard Access", currentTest(getClassContext())); } @Override public boolean checkTopLevelWindow(Object window) { log.info("%s: checkTopLevelWindow aka AWTPermission(\"showWindowWithoutWarningBanner\")", currentTest(getClassContext())); return true; } // }} // {{ Always Allowed @Override public void checkAccess(Thread t) {} @Override public void checkAccess(ThreadGroup g) {} @Override public void checkMemberAccess(Class<?> clazz, int which) {} @Override public void checkPackageAccess(String pkg) {} @Override public void checkPackageDefinition(String pkg) {} @Override public void checkSetFactory() {} @Override public void checkCreateClassLoader() { // This is re-set on classloader creation in case the classpath has changed. // In particular, Maven's Surefire booter changes the classpath after the security // manager has been initialized. CP_PARTS.set(getClassPath()); } @Override public void checkPropertiesAccess() {} @Override public void checkPropertyAccess(String key) {} @Override public void checkSecurityAccess(String target) {} // }} // {{ Undecided -- Can these be called in the real functions' stead? @Override public void checkPermission(Permission perm, Object context) {} @Override public void checkPermission(Permission perm) {} // }} private boolean isClassWhitelisted(Class<?> clazz) { if (getWhitelistedClasses().contains(clazz)) { return true; } Class<?> enclosingClass = clazz.getEnclosingClass(); if (enclosingClass != null) { return isClassWhitelisted(enclosingClass); } return false; } private boolean traceWithoutExplicitlyAllowedClass(Class<?>[] classContext) { for (Class<?> clazz : classContext) { if (isClassWhitelisted(clazz)) { return false; } } return true; } private void checkClassContextPermissions(final Class<?>[] classContext, final Predicate<Class<?>> classAuthorized) throws CantDoItException { // Only check permissions when we're running in the context of a JUnit test. boolean encounteredTestMethodRunner = false; for (Class<?> clazz : classContext) { if (clazz.equals(org.junit.runners.ParentRunner.class) || clazz.equals(org.junit.internal.runners.statements.RunAfters.class) || clazz.equals(org.junit.internal.runners.statements.RunBefores.class)) { encounteredTestMethodRunner = true; } } if (!encounteredTestMethodRunner) { return; } for (Class<?> clazz : classContext) { if (classAuthorized.apply(clazz)) { return; } } // No class on the stack trace is properly authorized, throw an exception. CantDoItException e = new CantDoItException(String.format("No class in the class context satisfies %s", classAuthorized)); if (this.reporting) { StackTraceElement testClassStackFrame = currentTest(classContext); String testName = "unknown test"; if (testClassStackFrame != null) { testName = format("%s.%s():%d", testClassStackFrame.getClassName(), testClassStackFrame.getMethodName(), testClassStackFrame.getLineNumber()); } log.error("%s: No %s at %s", testName, classAuthorized, testName); for (StackTraceElement el : currentThread().getStackTrace()) { log.debug("%s: Stack: %s.%s():%d", testName, el.getClassName(), el.getMethodName(), el.getLineNumber()); } for (Class<?> cl : classContext) { log.debug("%s: Class Context: %s %s", testName, cl.getCanonicalName(), cl); } } throw e; } public StackTraceElement currentTest(Class<?>[] classContext) { // The first class right before TestMethodRunner in the class context // array is the class that contains our test. Class<?> testClass = null; for (Class<?> clazz : classContext) { if (clazz == org.junit.runners.ParentRunner.class || clazz == org.junit.internal.runners.statements.RunAfters.class || clazz == org.junit.internal.runners.statements.RunBefores.class) { break; } testClass = clazz; } final StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); StackTraceElement testClassStackFrame = null; for (StackTraceElement el : stackTrace) { if (el.getClassName().equals(testClass.getCanonicalName())) { testClassStackFrame = el; } } return testClassStackFrame; } public static class CantDoItException extends RuntimeException { private static final long serialVersionUID = -8858380898538847118L; public CantDoItException() { } public CantDoItException(String s) { super(s); } public CantDoItException(String s, CantDoItException e) { super(s, e); } } }