package org.enumerable.lambda.enumerable.jruby;
import static java.lang.System.*;
import static org.enumerable.lambda.exception.UncheckedException.*;
import static org.jruby.javasupport.JavaEmbedUtils.*;
import java.io.InputStreamReader;
import java.io.StringWriter;
import java.io.Writer;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import javax.script.ScriptEngine;
import javax.script.ScriptException;
import org.enumerable.lambda.Fn0;
import org.enumerable.lambda.Fn1;
import org.enumerable.lambda.Fn2;
import org.enumerable.lambda.support.jruby.JRubyTest;
import org.enumerable.lambda.weaving.Debug;
import org.jruby.Ruby;
import org.jruby.RubyArray;
import org.jruby.RubyHash;
import org.jruby.RubyProc;
import org.jruby.runtime.builtin.IRubyObject;
import org.junit.Before;
public class JRubyTestBase {
public static boolean debug = Debug.debug;
public ScriptEngine rb = JRubyTest.getJRubyEngine();
@Before
public void monkeyPatchJRubyEnumerableToUseEnumerableJava() throws ScriptException {
require(enumerableJava());
}
public String enumerableJava() {
return "enumerable_java";
}
public void load(String test) throws ScriptException {
rb.eval(new InputStreamReader(JRubyTestBase.class.getResourceAsStream(test)));
}
public void require(String file) throws ScriptException {
rb.eval("require '" + file + "'");
}
public Object eval(String script) throws ScriptException {
return rb.eval(script);
}
void testUnit(String file, String testClass) throws ScriptException {
StringWriter writer = new StringWriter();
Writer originalWriter = rb.getContext().getWriter();
rb.getContext().setWriter(writer);
try {
require(file);
require("test/unit/ui/console/testrunner");
beforeRunningTestUnit();
eval("r = Test::Unit::UI::Console::TestRunner.run(" + testClass + ")");
eval("raise r.to_s unless r.passed?");
if (debug)
out.println(writer);
} catch (ScriptException e) {
out.println(writer);
throw uncheck(e);
} finally {
rb.getContext().setWriter(originalWriter);
}
}
protected void beforeRunningTestUnit() throws ScriptException {
}
public static void debug(String msg) {
if (debug)
out.println(msg);
}
@SuppressWarnings("serial")
public static Fn0<Object> toFn0(final RubyProc proc) {
return new Fn0<Object>() {
public Object call() {
Ruby ruby = proc.getRuntime();
return rubyToJava(proc.call(ruby.getThreadService().getCurrentContext(), new IRubyObject[0]));
}
};
}
@SuppressWarnings("serial")
public static Fn1<Object, Object> toFn1(final RubyProc proc) {
return new Fn1<Object, Object>() {
public Object call(Object a1) {
Ruby ruby = proc.getRuntime();
return rubyToJava(proc.call(ruby.getThreadService().getCurrentContext(),
new IRubyObject[] { javaInRubyToRuby(ruby, a1) }));
}
};
}
@SuppressWarnings("serial")
public static Fn2<Object, Object, Object> toFn2(final RubyProc proc) {
return new Fn2<Object, Object, Object>() {
public Object call(Object a1, Object a2) {
Ruby ruby = proc.getRuntime();
return rubyToJava(proc.call(ruby.getThreadService().getCurrentContext(), new IRubyObject[] {
javaInRubyToRuby(ruby, a1), javaInRubyToRuby(ruby, a2) }));
}
};
}
protected static IRubyObject javaInRubyToRuby(Ruby ruby, Object value) {
if (value instanceof List<?> && !(value instanceof RubyArray)) {
RubyArray array = RubyArray.newArray(ruby);
array.addAll((Collection<?>) value);
return array;
}
if (value instanceof Map<?, ?> && !(value instanceof RubyHash)) {
RubyHash hash = RubyHash.newHash(ruby);
hash.putAll((Map<?, ?>) value);
return hash;
}
return javaToRuby(ruby, value);
}
}