package org.testfun.jee;
import org.jboss.resteasy.plugins.server.embedded.SecurityDomain;
import org.jboss.resteasy.plugins.server.embedded.SimplePrincipal;
import org.jboss.resteasy.plugins.server.tjws.TJWSEmbeddedJaxrsServer;
import org.jboss.resteasy.spi.ResteasyDeployment;
import org.junit.rules.MethodRule;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.Statement;
import org.testfun.jee.runner.DependencyInjector;
import org.testfun.jee.runner.jaxrs.JaxRsException;
import org.testfun.jee.runner.jaxrs.RestRequest;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.lang.reflect.Field;
import java.net.ServerSocket;
import java.security.Principal;
/**
 * A JUnit rule that launches a JAX-RS server (using RESTEasy and TJWS) running in the same JVM as the test itself.
 * <p>
 * Injection of EJBs and mocks into the JAX-RS resources requires running the test using the
 * {@link EjbWithMockitoRunner} runner.
 * <p>
 * The server is started before each test method and stopped after it, whether the test passes or fails.
 * Configuration methods such as {@link #port(int)} and {@link #providers(Class[])} do not modify this instance;
 * each returns a new {@code JaxRsServer} carrying the updated setting.
 */
public class JaxRsServer implements MethodRule {

    /**
     * TCP port to bind to. 0 requests automatic selection of a free port; in that case the field is replaced
     * by the actually-bound port once the server has started.
     */
    private int port = 0;

    /** The embedded server; {@code null} until {@link #startJaxRsServer()} is called. */
    private TJWSEmbeddedJaxrsServer jaxRsServer;

    /** Resource classes instantiated and deployed as singletons on every start. */
    private final Class<?>[] resourceClasses;

    /** Optional provider classes registered on every start; {@code null} when none were configured. */
    private Class<?>[] providerClasses;

    private final ExpectedClientResponseFailure expectedClientResponseFailure = ExpectedClientResponseFailure.none();

    /**
     * Creates a JaxRsServer and deploys the specified resource classes.
     *
     * @param resourceClasses one or more resource classes that should be deployed
     * @return a new JaxRsServer for the given resources
     */
    public static JaxRsServer forResources(Class<?>... resourceClasses) {
        return new JaxRsServer(resourceClasses);
    }

    private JaxRsServer(Class<?>[] resourceClasses) {
        this.resourceClasses = resourceClasses;
    }

    /**
     * Optionally override the default selected port to bind to.
     *
     * @param port TCP port to listen to
     * @return a new JaxRsServer with the same resources and providers, bound to {@code port}
     */
    public JaxRsServer port(int port) {
        JaxRsServer newServer = copy();
        newServer.port = port;
        return newServer;
    }

    /**
     * Registers additional JAX-RS provider classes (e.g. message body readers/writers, exception mappers).
     *
     * @param providerClasses provider classes to register with the deployment's provider factory
     * @return a new JaxRsServer with the same resources and port, plus the given providers
     */
    public JaxRsServer providers(Class<?>... providerClasses) {
        JaxRsServer newServer = copy();
        newServer.providerClasses = providerClasses;
        return newServer;
    }

    /** Returns a new instance sharing this instance's resources, port and providers. */
    private JaxRsServer copy() {
        JaxRsServer newServer = new JaxRsServer(resourceClasses);
        newServer.port = this.port;
        newServer.providerClasses = this.providerClasses;
        return newServer;
    }

    /**
     * Gets the automatically-selected or manually-set TCP port used by the server.
     * Before the server has started with automatic selection, this returns 0.
     *
     * @return selected TCP port
     */
    public int getPort() {
        return port;
    }

    /**
     * Constructs a new JSON REST request builder.
     *
     * @param uri base request URI
     * @return REST request builder
     */
    public RestRequest jsonRequest(String uri) {
        return new RestRequest(uri, port).accept(MediaType.APPLICATION_JSON_TYPE);
    }

    /**
     * Constructs a new FORM REST request (application/x-www-form-urlencoded) builder.
     *
     * @param uri base request URI
     * @return REST request builder
     */
    public RestRequest formRequest(String uri) {
        return new RestRequest(uri, port).accept(MediaType.APPLICATION_FORM_URLENCODED_TYPE);
    }

    /**
     * Wraps the test statement so that the server runs around it.
     * Response-failure expectations are applied inside the server's lifetime.
     */
    @Override
    public Statement apply(Statement base, FrameworkMethod method, Object target) {
        return new JaxRsServerStatement(expectedClientResponseFailure.apply(base, method, target));
    }

    /**
     * Set expectation for REST failure with a particular status code and a substring that should appear in the failure message.
     *
     * @param expectedResponseStatus the HTTP status expected to be returned from the server
     * @param expectedMessageSubstring a substring of the expected message
     */
    public void expectFailureResponse(Response.Status expectedResponseStatus, String expectedMessageSubstring) {
        expectedClientResponseFailure.expectFailureResponse(expectedResponseStatus, expectedMessageSubstring);
    }

    /**
     * Starts the embedded server, resolves the bound port (if automatic), deploys the resources as singletons
     * (after dependency injection) and registers the configured providers.
     * <p>
     * If any step from {@code start()} onward fails, the server is stopped before the exception propagates, so a
     * failed start does not leave the port bound.
     *
     * @throws IllegalArgumentException if a resource class cannot be instantiated
     */
    public void startJaxRsServer() {
        jaxRsServer = new TJWSEmbeddedJaxrsServer();
        // Test-only security domain: every user is authenticated and granted every role, so secured resources
        // can be exercised without real credentials.
        jaxRsServer.setSecurityDomain(new SecurityDomain() {
            @Override
            public Principal authenticate(String username, String password) throws SecurityException {
                return new SimplePrincipal(username);
            }

            @Override
            public boolean isUserInRole(Principal username, String role) {
                return true;
            }
        });
        jaxRsServer.setPort(port);

        // JaxRsServerStatement only calls shutdownJaxRsServer() after a successful start, so cleanup of a
        // partially-started server has to happen here.
        boolean started = false;
        try {
            jaxRsServer.start();
            // If no port was set, then a free one was automatically selected - need to find its number.
            if (port == 0) {
                port = resolveBoundPort();
            }
            deployResources();
            registerProviders();
            started = true;
        } finally {
            if (!started) {
                stopAfterFailedStart();
            }
        }
    }

    /**
     * Stops the embedded server and its RESTEasy deployment. The deployment is stopped even if stopping the
     * server throws. Does nothing if the server was never started.
     */
    public void shutdownJaxRsServer() {
        if (jaxRsServer == null) {
            return;
        }
        try {
            jaxRsServer.stop();
        } finally {
            getDeployment().stop();
        }
    }

    /** @return the RESTEasy deployment of the embedded server (requires {@link #startJaxRsServer()} to have run) */
    public ResteasyDeployment getDeployment() {
        return jaxRsServer.getDeployment();
    }

    /**
     * Reads the locally bound port out of TJWS internals (server -> acceptor -> socket).
     * NOTE: depends on TJWS private field names and will break if they change.
     */
    private int resolveBoundPort() {
        Object server = getFromPrivateField(jaxRsServer, "server");
        Object acceptor = getFromPrivateField(server, "acceptor");
        ServerSocket socket = getFromPrivateField(acceptor, "socket");
        return socket.getLocalPort();
    }

    /** Instantiates each resource class, injects its dependencies and registers it as a singleton resource. */
    private void deployResources() {
        for (Class<?> resourceClass : resourceClasses) {
            Object resourceInstance;
            try {
                resourceInstance = resourceClass.newInstance();
            } catch (Exception e) {
                throw new IllegalArgumentException("Could not instantiate JAX-RS resource " + resourceClass.getName(), e);
            }
            DependencyInjector.getInstance().injectDependencies(resourceInstance);
            getDeployment().getRegistry().addSingletonResource(resourceInstance);
        }
    }

    /** Registers the configured provider classes, if any, with the deployment's provider factory. */
    private void registerProviders() {
        if (providerClasses == null) {
            return;
        }
        for (Class<?> providerClass : providerClasses) {
            getDeployment().getProviderFactory().registerProvider(providerClass);
        }
    }

    /** Best-effort shutdown after a failed start; the exception that aborted startup is the one worth reporting. */
    private void stopAfterFailedStart() {
        try {
            shutdownJaxRsServer();
        } catch (RuntimeException ignored) {
            // Deliberately ignored so it does not mask the original startup failure.
        }
    }

    /** Statement that runs the wrapped test between server start and (guaranteed) server shutdown. */
    private class JaxRsServerStatement extends Statement {
        private final Statement next;

        private JaxRsServerStatement(Statement next) {
            this.next = next;
        }

        @Override
        public void evaluate() throws Throwable {
            startJaxRsServer();
            try {
                next.evaluate();
            } finally {
                shutdownJaxRsServer();
            }
        }
    }

    /**
     * Reads a (possibly private) field declared on {@code obj}'s class or any of its superclasses.
     * The field's accessibility flag is restored afterwards.
     *
     * @param obj object to read from
     * @param fieldName name of the field
     * @return the field's value, unchecked-cast to the caller's expected type
     * @throws JaxRsException if {@code obj} is null, the field does not exist, or it cannot be read
     */
    @SuppressWarnings("unchecked") // the caller states the expected field type; a mismatch surfaces as ClassCastException
    private <S> S getFromPrivateField(Object obj, String fieldName) {
        if (obj == null) {
            throw new JaxRsException("Cannot read field '" + fieldName + "' from null");
        }
        Field field = findField(obj.getClass(), fieldName);
        if (field == null) {
            throw new JaxRsException("Could not find field '" + fieldName + "' in: " + obj);
        }
        boolean previousAccessState = field.isAccessible();
        try {
            field.setAccessible(true);
            return (S) field.get(obj);
        } catch (Exception e) {
            throw new JaxRsException("Could not get field '" + fieldName + "'", e);
        } finally {
            // Restore the field's original accessibility.
            field.setAccessible(previousAccessState);
        }
    }

    /**
     * Finds the most-derived declaration of {@code fieldName}, walking from {@code type} up to (but excluding)
     * {@link Object}. Returns {@code null} if no class in the hierarchy declares it.
     */
    private static Field findField(Class<?> type, String fieldName) {
        for (Class<?> current = type; current != null && !current.equals(Object.class); current = current.getSuperclass()) {
            try {
                return current.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // Not declared at this level; keep walking up the hierarchy.
            }
        }
        return null;
    }
}