package ysoserial.secmgr;
import java.security.Permission;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;
/**
 * A {@link SecurityManager} that permits every permission check but intercepts process
 * execution: each command passed to {@link #checkExec(String)} is recorded and, when
 * {@code throwException} is set, an {@link ExecException} is thrown so a test can detect
 * that an exec was attempted.
 *
 * <p>NOTE(review): the recorded-command list is an unsynchronized {@link LinkedList}; if exec
 * may be triggered from several threads at once, access to it is not thread-safe.
 */
public class ExecCheckingSecurityManager extends SecurityManager {

    /** Whether exec attempts raise {@link ExecException} (otherwise they are only recorded). */
    private final boolean throwException;

    /** Commands seen by {@link #checkExec(String)}, in the order they were checked. Never cleared. */
    private final List<String> cmds = new LinkedList<String>();

    /** Creates a manager that throws {@link ExecException} on every exec attempt. */
    public ExecCheckingSecurityManager() {
        this(true);
    }

    /**
     * Creates a manager.
     *
     * @param throwException if {@code true}, exec attempts raise {@link ExecException};
     *     if {@code false}, they are only recorded
     */
    public ExecCheckingSecurityManager(boolean throwException) {
        this.throwException = throwException;
    }

    /**
     * Returns the commands recorded so far.
     *
     * @return a read-only, live view of the recorded commands, oldest first
     */
    public List<String> getCmds() {
        return Collections.unmodifiableList(cmds);
    }

    /** Permits everything; only exec is intercepted (see {@link #checkExec(String)}). */
    @Override
    public void checkPermission(final Permission perm) { }

    /** Permits everything; only exec is intercepted (see {@link #checkExec(String)}). */
    @Override
    public void checkPermission(final Permission perm, final Object context) { }

    /**
     * Records {@code cmd} and, if this manager was created with {@code throwException},
     * throws an {@link ExecException} naming it.
     *
     * @param cmd the command about to be executed
     * @throws ExecException if {@code throwException} is set
     */
    @Override
    public void checkExec(final String cmd) {
        // Keep the base-class check; its permission checks presumably route through the
        // no-op checkPermission overrides above, so it should not reject anything here.
        super.checkExec(cmd);
        cmds.add(cmd);
        if (throwException) {
            // throw a special exception to ensure we can detect exec() in the test
            throw new ExecException(cmd);
        }
    }

    /**
     * Signals that a command reached {@link ExecCheckingSecurityManager#checkExec(String)}.
     * The message names the command and the thread that constructed this exception.
     */
    @SuppressWarnings("serial")
    public static class ExecException extends RuntimeException {
        // Captured at construction time. When wrap() re-creates the exception after the
        // callable returns, this is the wrapping thread, which may differ from the thread
        // that actually attempted the exec.
        private final String threadName = Thread.currentThread().getName();
        private final String cmd;

        /**
         * Creates an exception for {@code cmd}.
         *
         * @param cmd the intercepted command
         */
        public ExecException(String cmd) {
            this.cmd = cmd;
        }

        /**
         * Creates an exception for {@code cmd} that keeps {@code cause} (the exception that
         * actually escaped the code under test) for diagnostics.
         *
         * @param cmd the intercepted command
         * @param cause the exception being replaced by this one
         */
        public ExecException(String cmd, Throwable cause) {
            super(cause);
            this.cmd = cmd;
        }

        /** @return the intercepted command */
        public String getCmd() {
            return cmd;
        }

        /** @return the name of the thread that constructed this exception */
        public String getThreadName() {
            return threadName;
        }

        @Override
        public String getMessage() {
            return "executed `" + getCmd() + "` in [" + getThreadName() + "]";
        }
    }

    /**
     * Runs {@code runnable} with this manager installed; see {@link #wrap(Callable)}.
     *
     * @param runnable the code to run
     * @throws Exception as described for {@link #wrap(Callable)}
     */
    public void wrap(final Runnable runnable) throws Exception {
        wrap(new Callable<Void>() {
            @Override
            public Void call() throws Exception {
                runnable.run();
                return null;
            }
        });
    }

    /**
     * Installs this manager as the JVM-wide security manager, runs {@code callable}, and
     * restores the previously installed manager afterwards, even on failure.
     *
     * <p>When {@code throwException} is set and at least one command has been recorded
     * (including commands recorded before this call, since the list is never cleared), an
     * {@link ExecException} for the first recorded command is thrown. This happens even if
     * the callable swallowed the original one or failed with a different exception; in the
     * latter case that exception is attached as the cause.
     *
     * @param callable the code to run
     * @param <T> the callable's result type
     * @return the callable's result
     * @throws ExecException if an exec was recorded and {@code throwException} is set
     * @throws Exception whatever {@code callable} throws, when no exec is reported instead
     */
    public <T> T wrap(final Callable<T> callable) throws Exception {
        final SecurityManager previous = System.getSecurityManager();
        System.setSecurityManager(this);
        try {
            final T result;
            try {
                result = callable.call();
            } catch (ExecException e) {
                throw e;
            } catch (Exception e) {
                if (throwException && !cmds.isEmpty()) {
                    // Report the exec, but keep the exception that actually escaped as the
                    // cause so its stack trace is not lost.
                    throw new ExecException(cmds.get(0), e);
                }
                throw e;
            }
            // An exec was recorded but no ExecException escaped (e.g. the code under test
            // caught it); surface it anyway.
            if (throwException && !cmds.isEmpty()) {
                throw new ExecException(cmds.get(0));
            }
            return result;
        } finally {
            System.setSecurityManager(previous);
        }
    }
}