package scotch.compiler.syntax.type; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toSet; import static me.qmx.jitescript.util.CodegenUtils.p; import static me.qmx.jitescript.util.CodegenUtils.sig; import static scotch.util.Pair.pair; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.Function; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import lombok.EqualsAndHashCode; import me.qmx.jitescript.CodeBlock; import scotch.symbol.Symbol; import scotch.symbol.type.SumTypeDescriptor; import scotch.symbol.type.TypeDescriptors; import scotch.util.Pair; @EqualsAndHashCode(callSuper = false) public class SumType extends Type { private static void shouldBeSumName(Symbol symbol) { if (!symbol.isSumName()) { throw new IllegalArgumentException("Sum type should have upper-case name, be tuple, or list: got '" + symbol.getMemberName() + "'"); } } private static List<Pair<Type, Type>> zip(List<Type> left, List<Type> right) { List<Pair<Type, Type>> result = new ArrayList<>(); Iterator<Type> leftIterator = left.iterator(); Iterator<Type> rightIterator = right.iterator(); while (leftIterator.hasNext()) { result.add(pair(leftIterator.next(), rightIterator.next())); } return result; } private final Symbol symbol; private final List<Type> parameters; SumType(Symbol symbol, List<Type> parameters) { shouldBeSumName(symbol); this.symbol = symbol; this.parameters = ImmutableList.copyOf(parameters); } @Override public <T> T accept(Visitor<T> visitor) { return visitor.visit(this); } @Override public HeadApplication apply(Type head, TypeScope scope) { return head.applyWith(this, scope); } public HeadZip applyZip(Type head, TypeScope scope) { return head.applyZipWith(this, scope); } @Override public Type flatten() { return new SumType(symbol, parameters.stream() .map(Type::flatten) .collect(toList())); } @Override public CodeBlock generateBytecode() { return new CodeBlock() {{ ldc(symbol.getCanonicalName()); invokestatic(p(Symbol.class), "symbol", sig(Symbol.class, String.class)); newobj(p(ArrayList.class)); dup(); invokespecial(p(ArrayList.class), "<init>", sig(void.class)); parameters.forEach(parameter -> { dup(); append(parameter.generateBytecode()); invokeinterface(p(List.class), "add", sig(boolean.class, Object.class)); pop(); }); invokestatic(p(TypeDescriptors.class), "sum", sig(SumTypeDescriptor.class, Symbol.class, List.class)); }}; } @Override public Map<String, Type> getContexts(Type type, TypeScope scope) { return ImmutableMap.of(); } public List<Type> getParameters() { return parameters; } @Override public Type mapVariables(Function<VariableType, Type> mapper) { return withParameters(parameters.stream() .map(parameter -> parameter.mapVariables(mapper)) .collect(toList())); } public Symbol getSymbol() { return symbol; } @Override public Type qualifyNames(TypeQualifier qualifier) { return withSymbol(qualifier.qualifyType(symbol)) .withParameters(parameters.stream() .map(argument -> argument.qualifyNames(qualifier)) .collect(toList())); } @Override public SumTypeDescriptor toDescriptor() { return TypeDescriptors.sum(symbol, parameters.stream() .map(Type::toDescriptor) .collect(toList())); } public SumType withParameters(List<Type> arguments) { return new SumType(symbol, arguments); } public SumType withSymbol(Symbol symbol) { return new SumType(symbol, parameters); } @Override protected boolean contains(VariableType type) { return parameters.stream() .map(Type::simplify) .anyMatch(argument -> argument.equals(type)); } @Override protected Type flatten(List<Type> types) { return withParameters(new ArrayList<Type>() {{ addAll(parameters); addAll(types); }}); } @Override protected List<Type> flatten_() { return ImmutableList.of(flatten()); } @Override protected Set<Pair<VariableType, Symbol>> gatherContext_() { return parameters.stream() .flatMap(parameter -> parameter.gatherContext_().stream()) .collect(toSet()); } @Override protected Type generate(TypeScope scope, Set<Type> visited) { return withParameters(parameters.stream() .map(parameter -> parameter.generate(scope, visited)) .collect(toList())).flatten(); } @Override protected Type genericCopy(TypeScope scope, Map<Type, Type> mappings) { return new SumType(symbol, parameters.stream() .map(parameter -> parameter.genericCopy(scope, mappings)) .collect(toList())); } @Override protected String toParenthesizedString() { return toString_(); } @Override protected String toString_() { if (symbol.isTuple()) { return "(" + parameters.stream().map(Type::toString_).collect(joining(", ")) + ")"; } else if (symbol.isList()) { return "[" + parameters.stream().map(Type::toString_).collect(joining(", ")) + "]"; } else if (parameters.isEmpty()) { return symbol.getSimpleName(); } else { return symbol.getSimpleName() + " " + parameters.stream().map(Type::toString_).collect(joining(" ")); } } @Override protected Unification unifyWith(ConstructorType target, TypeScope scope) { return target.apply(this, scope); } @Override protected Unification unifyWith(SumType target, TypeScope scope) { if (symbol.equals(target.symbol)) { if (parameters.size() == target.parameters.size()) { List<Pair<Type, Type>> zip = zip(target.parameters, parameters); for (Pair<Type, Type> pair : zip) { Unification result = pair.into((left, right) -> left.unify(right, scope)); if (!result.isUnified()) { return result; } } return Unification.unified(target); } else { return Unification.mismatch(target, this); } } else { return Unification.mismatch(target, this); } } @Override protected Unification unifyWith(FunctionType target, TypeScope scope) { return Unification.mismatch(target, this); } @Override protected Unification unifyWith(VariableType target, TypeScope scope) { if (contains(target)) { return Unification.circular(target, this); } else { return Types.unifyVariable(this, target, scope); } } @Override protected Unification unify_(Type type, TypeScope scope) { return type.unifyWith(this, scope); } @Override protected Optional<List<Pair<Type, Type>>> zipWith(ConstructorType target, TypeScope scope) { return target.applyZip(this, scope); } @Override protected Optional<List<Pair<Type, Type>>> zipWith(SumType target, TypeScope scope) { if (equals(target)) { return Optional.of(ImmutableList.of(pair(target, this))); } else { return Optional.empty(); } } @Override protected Optional<List<Pair<Type, Type>>> zip_(Type other, TypeScope scope) { return other.zipWith(this, scope); } }