package be.raildelays.javafx.test; import javafx.application.Platform; import javafx.embed.swing.JFXPanel; import org.junit.Rule; import org.junit.rules.TestRule; import org.junit.runner.Description; import org.junit.runners.model.Statement; import javax.swing.*; import java.lang.reflect.Method; import java.net.URLClassLoader; import java.util.concurrent.CountDownLatch; /** * A JUnit {@link Rule} for running tests on the JavaFX thread and performing * JavaFX initialisation. To include in your test case, add the following code: * <p> * <pre> * {@literal @}Rule * public JavaFXThreadingRule jfxRule = new JavaFXThreadingRule(); * </pre> * * @author Andy Till */ public class JavaFXThreadingRule implements TestRule { /** * Flag for setting up the JavaFX, we only need to do this once for all tests. */ private static boolean loaded; @Override public Statement apply(Statement statement, Description description) { return new OnJFXThreadStatement(statement); } private static class OnJFXThreadStatement extends Statement { private final Statement statement; public OnJFXThreadStatement(Statement aStatement) { statement = aStatement; } private Throwable rethrownException = null; @Override public void evaluate() throws Throwable { CountDownLatch countDownLatch = new CountDownLatch(1); if (!loaded) { setupJavaFX(); loaded = true; } Platform.runLater(() -> { try { statement.evaluate(); } catch (Throwable e) { rethrownException = e; } countDownLatch.countDown(); }); countDownLatch.await(); // if an exception was thrown by the statement during evaluation, // then re-throw it to fail the test if (rethrownException != null) { throw rethrownException; } } protected void setupJavaFX() throws InterruptedException { long timeMillis = System.currentTimeMillis(); CountDownLatch latch = new CountDownLatch(1); SwingUtilities.invokeLater(() -> { // initializes JavaFX environment new JFXPanel(); latch.countDown(); }); System.out.println("JavaFX initialising..."); latch.await(); System.out.println("JavaFX is initialised in " + (System.currentTimeMillis() - timeMillis) + "ms"); } } }