package scotch.compiler.syntax.type;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;
import static me.qmx.jitescript.util.CodegenUtils.p;
import static me.qmx.jitescript.util.CodegenUtils.sig;
import static scotch.util.Pair.pair;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import lombok.EqualsAndHashCode;
import me.qmx.jitescript.CodeBlock;
import scotch.symbol.Symbol;
import scotch.symbol.type.SumTypeDescriptor;
import scotch.symbol.type.TypeDescriptors;
import scotch.util.Pair;
/**
 * A type formed by a named type symbol applied to zero or more type parameters: a tuple,
 * a list, or an upper-case named type such as {@code Foo a b}.
 *
 * <p>Both fields are {@code final} and the parameter list is stored as an immutable copy, so every
 * transformation ({@code with*}, {@code flatten}, {@code mapVariables}, ...) returns a new instance.
 */
@EqualsAndHashCode(callSuper = false)
public class SumType extends Type {

    /**
     * Validates that {@code symbol} may name a sum type.
     *
     * @throws IllegalArgumentException if {@link Symbol#isSumName()} is false
     */
    private static void shouldBeSumName(Symbol symbol) {
        if (!symbol.isSumName()) {
            throw new IllegalArgumentException("Sum type should have upper-case name, be tuple, or list: got '" + symbol.getMemberName() + "'");
        }
    }

    /**
     * Pairs up the elements of two lists positionally: {@code (left[i], right[i])}.
     *
     * @throws IllegalArgumentException if the lists differ in size (previously this surfaced as a
     *     {@code NoSuchElementException} or silently dropped trailing elements)
     */
    private static List<Pair<Type, Type>> zip(List<Type> left, List<Type> right) {
        if (left.size() != right.size()) {
            throw new IllegalArgumentException(
                "Cannot zip type parameter lists of different sizes: " + left.size() + " vs " + right.size());
        }
        List<Pair<Type, Type>> result = new ArrayList<>(left.size());
        Iterator<Type> leftIterator = left.iterator();
        Iterator<Type> rightIterator = right.iterator();
        while (leftIterator.hasNext()) {
            result.add(pair(leftIterator.next(), rightIterator.next()));
        }
        return result;
    }

    private final Symbol symbol;
    private final List<Type> parameters;

    /**
     * Creates a sum type; the parameter list is defensively copied into an {@link ImmutableList}.
     *
     * @throws IllegalArgumentException if {@code symbol} is not a valid sum-type name
     */
    SumType(Symbol symbol, List<Type> parameters) {
        shouldBeSumName(symbol);
        this.symbol = symbol;
        this.parameters = ImmutableList.copyOf(parameters);
    }

    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visit(this);
    }

    /** Double-dispatches to {@code head.applyWith(this, scope)}. */
    @Override
    public HeadApplication apply(Type head, TypeScope scope) {
        return head.applyWith(this, scope);
    }

    /** Double-dispatches to {@code head.applyZipWith(this, scope)}. */
    public HeadZip applyZip(Type head, TypeScope scope) {
        return head.applyZipWith(this, scope);
    }

    /** Returns a copy of this type with every parameter flattened. */
    @Override
    public Type flatten() {
        return new SumType(symbol, parameters.stream()
            .map(Type::flatten)
            .collect(toList()));
    }

    /**
     * Emits bytecode that rebuilds this type's descriptor at runtime.
     *
     * <p>Equivalent to {@code TypeDescriptors.sum(Symbol.symbol(canonicalName), list)}, where
     * {@code list} is an {@link ArrayList} filled with each parameter's generated value.
     */
    @Override
    public CodeBlock generateBytecode() {
        return new CodeBlock() {{
            ldc(symbol.getCanonicalName());
            invokestatic(p(Symbol.class), "symbol", sig(Symbol.class, String.class));
            newobj(p(ArrayList.class));
            dup();
            invokespecial(p(ArrayList.class), "<init>", sig(void.class));
            parameters.forEach(parameter -> {
                dup(); // keep the list on the stack for the next add
                append(parameter.generateBytecode());
                invokeinterface(p(List.class), "add", sig(boolean.class, Object.class));
                pop(); // discard List.add's boolean result
            });
            invokestatic(p(TypeDescriptors.class), "sum", sig(SumTypeDescriptor.class, Symbol.class, List.class));
        }};
    }

    /** Sum types contribute no contexts; always returns an empty map. */
    @Override
    public Map<String, Type> getContexts(Type type, TypeScope scope) {
        return ImmutableMap.of();
    }

    /** Returns the (immutable) list of type parameters. */
    public List<Type> getParameters() {
        return parameters;
    }

    /** Applies {@code mapper} to type variables within each parameter, keeping the symbol. */
    @Override
    public Type mapVariables(Function<VariableType, Type> mapper) {
        return withParameters(parameters.stream()
            .map(parameter -> parameter.mapVariables(mapper))
            .collect(toList()));
    }

    /** Returns the symbol naming this sum type. */
    public Symbol getSymbol() {
        return symbol;
    }

    /** Qualifies this type's symbol and, recursively, the names in every parameter. */
    @Override
    public Type qualifyNames(TypeQualifier qualifier) {
        return withSymbol(qualifier.qualifyType(symbol))
            .withParameters(parameters.stream()
                .map(argument -> argument.qualifyNames(qualifier))
                .collect(toList()));
    }

    /** Converts this type (and its parameters) into a {@link SumTypeDescriptor}. */
    @Override
    public SumTypeDescriptor toDescriptor() {
        return TypeDescriptors.sum(symbol, parameters.stream()
            .map(Type::toDescriptor)
            .collect(toList()));
    }

    /** Returns a copy of this type with the same symbol and the given parameters. */
    public SumType withParameters(List<Type> arguments) {
        return new SumType(symbol, arguments);
    }

    /**
     * Returns a copy of this type with the same parameters and a different symbol.
     *
     * @throws IllegalArgumentException if {@code symbol} is not a valid sum-type name
     */
    public SumType withSymbol(Symbol symbol) {
        return new SumType(symbol, parameters);
    }

    /**
     * Occurs check: reports whether {@code type} appears anywhere within this type's parameters.
     *
     * <p>The original version only compared direct parameters. It therefore missed nested occurrences,
     * such as {@code a} inside {@code List (Foo a)}, and let circular bindings through unification.
     */
    @Override
    protected boolean contains(VariableType type) {
        return parameters.stream()
            .map(Type::simplify)
            .anyMatch(argument -> argument.equals(type) || argument.contains(type));
    }

    /** Returns a copy of this type with {@code types} appended after the existing parameters. */
    @Override
    protected Type flatten(List<Type> types) {
        // Plain copy instead of double-brace initialization, which created an anonymous
        // ArrayList subclass capturing this instance.
        List<Type> combined = new ArrayList<>(parameters.size() + types.size());
        combined.addAll(parameters);
        combined.addAll(types);
        return withParameters(combined);
    }

    /** Returns a single-element list containing the flattened form of this type. */
    @Override
    protected List<Type> flatten_() {
        return ImmutableList.of(flatten());
    }

    /** Collects the contexts gathered from every parameter into one set. */
    @Override
    protected Set<Pair<VariableType, Symbol>> gatherContext_() {
        return parameters.stream()
            .flatMap(parameter -> parameter.gatherContext_().stream())
            .collect(toSet());
    }

    /** Generates each parameter within {@code scope}, then flattens the result. */
    @Override
    protected Type generate(TypeScope scope, Set<Type> visited) {
        return withParameters(parameters.stream()
            .map(parameter -> parameter.generate(scope, visited))
            .collect(toList())).flatten();
    }

    /** Returns a copy of this type whose parameters are each generic-copied using {@code mappings}. */
    @Override
    protected Type genericCopy(TypeScope scope, Map<Type, Type> mappings) {
        return withParameters(parameters.stream()
            .map(parameter -> parameter.genericCopy(scope, mappings))
            .collect(toList()));
    }

    /** Same as {@link #toString_()}; no extra parentheses are added. */
    @Override
    protected String toParenthesizedString() {
        return toString_();
    }

    /**
     * Renders tuples as {@code (a, b)}, lists as {@code [a]}, and other sum types as
     * {@code Name} or {@code Name a b}.
     */
    @Override
    protected String toString_() {
        if (symbol.isTuple()) {
            return "(" + joinParameters(", ") + ")";
        } else if (symbol.isList()) {
            return "[" + joinParameters(", ") + "]";
        } else if (parameters.isEmpty()) {
            return symbol.getSimpleName();
        } else {
            return symbol.getSimpleName() + " " + joinParameters(" ");
        }
    }

    /** Joins the string forms of all parameters with {@code delimiter}. */
    private String joinParameters(String delimiter) {
        return parameters.stream().map(Type::toString_).collect(joining(delimiter));
    }

    /** Delegates unification with a constructor type to {@code target.apply(this, scope)}. */
    @Override
    protected Unification unifyWith(ConstructorType target, TypeScope scope) {
        return target.apply(this, scope);
    }

    /**
     * Unifies with another sum type. Symbols and parameter counts must match, and each pair of
     * parameters must unify.
     *
     * <p>Returns the first failing parameter unification, or {@code unified(target)} if all succeed.
     */
    @Override
    protected Unification unifyWith(SumType target, TypeScope scope) {
        if (!symbol.equals(target.symbol) || parameters.size() != target.parameters.size()) {
            return Unification.mismatch(target, this);
        }
        for (Pair<Type, Type> pair : zip(target.parameters, parameters)) {
            Unification result = pair.into((left, right) -> left.unify(right, scope));
            if (!result.isUnified()) {
                return result;
            }
        }
        return Unification.unified(target);
    }

    /** A sum type never unifies with a function type. */
    @Override
    protected Unification unifyWith(FunctionType target, TypeScope scope) {
        return Unification.mismatch(target, this);
    }

    /**
     * Binds {@code target} to this type unless {@code target} occurs within it; in that case a
     * circular unification is reported.
     */
    @Override
    protected Unification unifyWith(VariableType target, TypeScope scope) {
        if (contains(target)) {
            return Unification.circular(target, this);
        } else {
            return Types.unifyVariable(this, target, scope);
        }
    }

    /** Double-dispatches to {@code type.unifyWith(this, scope)}. */
    @Override
    protected Unification unify_(Type type, TypeScope scope) {
        return type.unifyWith(this, scope);
    }

    /** Delegates to {@code target.applyZip(this, scope)}. */
    @Override
    protected Optional<List<Pair<Type, Type>>> zipWith(ConstructorType target, TypeScope scope) {
        return target.applyZip(this, scope);
    }

    /** Zips to a single {@code (target, this)} pair when the two types are equal; otherwise empty. */
    @Override
    protected Optional<List<Pair<Type, Type>>> zipWith(SumType target, TypeScope scope) {
        if (equals(target)) {
            return Optional.of(ImmutableList.of(pair(target, this)));
        } else {
            return Optional.empty();
        }
    }

    /** Double-dispatches to {@code other.zipWith(this, scope)}. */
    @Override
    protected Optional<List<Pair<Type, Type>>> zip_(Type other, TypeScope scope) {
        return other.zipWith(this, scope);
    }
}