package scotch.compiler;
import static java.util.Arrays.asList;
import static java.util.stream.Collectors.joining;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static scotch.compiler.Compiler.compiler;
import static scotch.compiler.syntax.reference.DefinitionReference.signatureRef;
import static scotch.util.TestUtil.classDef;
import static scotch.util.TestUtil.classRef;
import static scotch.util.TestUtil.ctorDef;
import static scotch.util.TestUtil.dataDef;
import static scotch.util.TestUtil.dataRef;
import static scotch.util.TestUtil.valueRef;
import static scotch.symbol.Symbol.symbol;
import static scotch.compiler.syntax.type.Types.sum;
import static scotch.compiler.text.TextUtil.quote;
import java.util.List;
import java.util.function.Function;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import scotch.compiler.error.SyntaxError;
import scotch.compiler.syntax.definition.DataFieldDefinition;
import scotch.compiler.syntax.definition.DefinitionGraph;
import scotch.compiler.syntax.definition.ValueDefinition;
import scotch.compiler.syntax.definition.ValueSignature;
import scotch.compiler.syntax.reference.DefinitionReference;
import scotch.compiler.syntax.scope.Scope;
import scotch.compiler.syntax.value.Value;
import scotch.symbol.SymbolResolver;
import scotch.compiler.syntax.type.Type;
/**
 * Base class for compiler tests: compiles source lines into a {@link DefinitionGraph} and offers
 * assertion helpers for inspecting the resulting graph.
 *
 * <p>Subclasses supply the symbol resolver via {@link #initResolver()} and the compilation step via
 * {@link #compile()}; tests then call {@link #compile(String...)} and the {@code should*} helpers.
 *
 * @param <Resolver> the symbol resolver type created by {@link #initResolver()}
 */
public abstract class CompilerTest<Resolver extends SymbolResolver> {

    /** Maximum number of debug-trace entries printed per error by {@link #shouldNotHaveErrors()}. */
    private static final int MAX_TRACE_ENTRIES = 10;

    /** JUnit rule exposing the name of the currently running test method. */
    @Rule
    public final TestName testName = new TestName();

    protected final Type intType = sum("scotch.data.int.Int");
    protected final Type doubleType = sum("scotch.data.double.Double");
    protected final Type boolType = sum("scotch.data.bool.Bool");
    protected final Type stringType = sum("scotch.data.string.String");

    /** Graph produced by the most recent {@link #compile(String...)} call; {@code null} before the first call. */
    protected DefinitionGraph graph;
    /** Builds a {@link Compiler} from source lines; created in {@link #setUp()}. */
    protected Function<String[], Compiler> compilerFactory;
    /** Resolver created in {@link #setUp()} and handed to the default compiler factory. */
    protected Resolver resolver;

    /**
     * Creates the resolver and then the compiler factory. The resolver is created first so that an
     * overridden {@link #initCompilerFactory()} may already rely on it.
     */
    @Before
    public void setUp() {
        resolver = initResolver();
        compilerFactory = initCompilerFactory();
    }

    /**
     * Compiles the given source lines and stores the result in {@link #graph}.
     *
     * @param lines source code, one element per line
     */
    protected void compile(String... lines) {
        graph = compile().apply(compilerFactory.apply(lines));
    }

    /**
     * Supplies the step that turns a {@link Compiler} into the {@link DefinitionGraph} under test.
     * Presumably each subclass picks the compiler phase it exercises — confirm in subclasses.
     *
     * @return function from compiler to resulting definition graph
     */
    protected abstract Function<Compiler, DefinitionGraph> compile();

    /**
     * Looks up the scope associated with a definition in the compiled graph.
     *
     * @param reference the definition whose scope is wanted
     * @return the scope returned by {@link DefinitionGraph#getScope}
     */
    protected Scope getScope(DefinitionReference reference) {
        return graph.getScope(reference);
    }

    /**
     * Returns the value definition named {@code name}.
     *
     * @param name value name as accepted by {@code valueRef}
     * @return the value definition
     * @throws AssertionError if the graph contains no such definition
     */
    protected ValueDefinition getValueDefinition(String name) {
        return graph.getDefinition(valueRef(name))
            .orElseThrow(() -> new AssertionError("Graph did not define value definition " + quote(name)));
    }

    /**
     * Creates the default compiler factory, which compiles lines with {@link #resolver}.
     * The lambda reads the {@code resolver} field when it is invoked, not when it is created.
     *
     * @return factory mapping source lines to a compiler
     */
    protected Function<String[], Compiler> initCompilerFactory() {
        return lines -> compiler(resolver, lines);
    }

    /**
     * Creates the symbol resolver used by the compiler under test.
     *
     * @return a fresh resolver for the current test
     */
    protected abstract Resolver initResolver();

    /**
     * Asserts that {@code name} is defined in the scope of {@code reference}.
     *
     * @param reference definition whose scope is checked
     * @param name      symbol name expected to be defined there
     */
    protected void shouldBeDefined(DefinitionReference reference, String name) {
        assertThat(
            "Symbol " + quote(name) + " is not defined in scope " + reference,
            getScope(reference).isDefined(symbol(name)),
            is(true)
        );
    }

    /**
     * Asserts that the graph holds a class definition equal to one built from the given parts.
     *
     * @param className fully qualified class name
     * @param arguments expected class type arguments
     * @param members   expected member references
     * @throws AssertionError if the class is missing or differs
     */
    protected void shouldHaveClass(String className, List<Type> arguments, List<DefinitionReference> members) {
        assertThat(
            graph.getDefinition(classRef(className))
                .orElseThrow(() -> new AssertionError("Graph did not define class " + quote(className))),
            is(classDef(className, arguments, members))
        );
    }

    /**
     * Asserts that compilation produced no errors and that the graph holds a data definition with a
     * single constructor sharing the data type's name.
     *
     * @param ordinal    expected constructor ordinal
     * @param name       data type (and constructor) name
     * @param parameters expected type parameters
     * @param fields     expected constructor fields
     * @throws AssertionError if there are errors, or the definition is missing or differs
     */
    protected void shouldHaveData(int ordinal, String name, List<Type> parameters, List<DataFieldDefinition> fields) {
        assertThat(graph.hasErrors(), is(false));
        assertThat(
            graph.getDefinition(dataRef(name))
                .orElseThrow(() -> new AssertionError("Graph did not define data type " + quote(name))),
            is(dataDef(name, parameters, asList(ctorDef(ordinal, name, name, fields))))
        );
    }

    /**
     * Asserts that the graph reported errors and that they match {@code errors} in order.
     *
     * @param errors the expected errors, in reporting order
     */
    protected void shouldHaveErrors(SyntaxError... errors) {
        assertThat(graph.hasErrors(), is(true));
        assertThat(graph.getErrors(), contains(errors));
    }

    /**
     * Asserts that the graph holds a signature for {@code name} declaring {@code type}.
     *
     * @param name value name
     * @param type expected declared type
     * @throws AssertionError if the signature is missing or its type differs
     */
    protected void shouldHaveSignature(String name, Type type) {
        ValueSignature signature = (ValueSignature) graph.getDefinition(signatureRef(symbol(name)))
            .orElseThrow(() -> new AssertionError("Graph did not define signature for " + quote(name)));
        assertThat(signature.getType(), is(type));
    }

    /**
     * Asserts that value {@code name} is defined and that {@code graph.getValue} reports {@code type}.
     *
     * @param name value name
     * @param type expected type
     */
    protected void shouldHaveValue(String name, Type type) {
        shouldHaveValue(name);
        // Presence was asserted on the line above, so get() cannot fail here.
        assertThat(graph.getValue(valueRef(name)).get(), is(type));
    }

    /**
     * Asserts that the graph defines value {@code name}.
     *
     * @param name value name
     */
    protected void shouldHaveValue(String name) {
        assertThat("Graph did not define value " + quote(name), graph.getValue(valueRef(name)).isPresent(), is(true));
    }

    /**
     * Asserts that value {@code name} is defined and that its definition body equals {@code body}.
     *
     * @param name value name
     * @param body expected body
     */
    protected void shouldHaveValue(String name, Value body) {
        shouldHaveValue(name);
        assertThat(getValueDefinition(name).getBody(), is(body));
    }

    /**
     * Asserts that compilation produced no errors. On failure the message lists each error's pretty
     * form plus the first {@value #MAX_TRACE_ENTRIES} entries of its debug trace.
     */
    protected void shouldNotHaveErrors() {
        assertThat(describeErrors(), graph.hasErrors(), is(false));
    }

    /** Builds the failure message used by {@link #shouldNotHaveErrors()}. */
    private String describeErrors() {
        return "Definition graph has " + graph.getErrors().size() + " errors!\n\t" + graph.getErrors().stream()
            .map(error -> error.prettyPrint() + "\n\tDebugTrace:\n\t\t" + error.getStackTrace().stream()
                .limit(MAX_TRACE_ENTRIES)
                .map(Object::toString)
                .collect(joining("\n\t\t")) + "\n\t\t...")
            .collect(joining("\n\t"));
    }
}