package scotch.compiler.syntax.type; import static me.qmx.jitescript.util.CodegenUtils.p; import static me.qmx.jitescript.util.CodegenUtils.sig; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.Function; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSortedSet; import lombok.EqualsAndHashCode; import me.qmx.jitescript.CodeBlock; import scotch.symbol.Symbol; import scotch.symbol.type.FunctionTypeDescriptor; import scotch.symbol.type.TypeDescriptor; import scotch.symbol.type.TypeDescriptors; import scotch.util.Pair; @EqualsAndHashCode(callSuper = false) public class FunctionType extends Type { private final Type argument; private final Type result; FunctionType(Type argument, Type result) { this.argument = argument; this.result = result; } @Override public <T> T accept(Visitor<T> visitor) { return visitor.visit(this); } @Override public Type flatten() { return new FunctionType(argument.flatten(), result.flatten()); } @Override public CodeBlock generateBytecode() { return new CodeBlock() {{ append(argument.generateBytecode()); append(result.generateBytecode()); invokestatic(p(TypeDescriptors.class), "fn", sig(FunctionTypeDescriptor.class, TypeDescriptor.class, TypeDescriptor.class)); }}; } public Type getArgument() { return argument; } @Override public Map<String, Type> getContexts(Type type, TypeScope scope) { Map<String, Type> map = new HashMap<>(); if (type instanceof FunctionType) { map.putAll(argument.getContexts(((FunctionType) type).getArgument(), scope)); map.putAll(result.getContexts(((FunctionType) type).getResult(), scope)); } return map; } @Override public List<Pair<VariableType, Symbol>> getInstanceMap() { List<Pair<VariableType, Symbol>> instances = new ArrayList<>(); instances.addAll(argument.getInstanceMap()); instances.addAll(result.getInstanceMap()); return instances; } public Type getResult() { return result; } @Override public Type mapVariables(Function<VariableType, Type> mapper) { return new FunctionType(argument.mapVariables(mapper), result.mapVariables(mapper)); } @Override public Type qualifyNames(TypeQualifier qualifier) { return withArgument(argument.qualifyNames(qualifier)).withResult(result.qualifyNames(qualifier)); } @Override public TypeDescriptor toDescriptor() { return TypeDescriptors.fn(argument.toDescriptor(), result.toDescriptor()); } public FunctionType withArgument(Type argument) { return new FunctionType(argument, result); } public FunctionType withResult(Type result) { return new FunctionType(argument, result); } @Override protected boolean contains(VariableType type) { return argument.contains(type) || result.contains(type); } @Override protected List<Type> flatten_() { return ImmutableList.of(flatten()); } @Override protected Set<Pair<VariableType, Symbol>> gatherContext_() { Set<Pair<VariableType, Symbol>> context = new HashSet<>(); context.addAll(argument.gatherContext_()); context.addAll(result.gatherContext_()); return ImmutableSortedSet.copyOf(Types::sort, context); } @Override protected Type generate(TypeScope scope, Set<Type> visited) { return new FunctionType(argument.generate(scope), result.generate(scope)).flatten(); } @Override protected Type genericCopy(TypeScope scope, Map<Type, Type> mappings) { return new FunctionType( argument.genericCopy(scope, mappings), result.genericCopy(scope, mappings) ); } @Override protected String toParenthesizedString() { return "(" + argument.toParenthesizedString() + " -> " + result.toString_() + ")"; } @Override protected String toString_() { return argument.toParenthesizedString() + " -> " + result.toString_(); } @Override protected Unification unifyWith(ConstructorType target, TypeScope scope) { return Unification.mismatch(target, this); } @Override protected Unification unifyWith(SumType target, TypeScope scope) { return Unification.mismatch(target, this); } @Override protected Unification unifyWith(FunctionType target, TypeScope scope) { return target.argument.unify_(argument, scope).map( argumentResult -> target.result.unify_(result, scope).map( resultResult -> Unification.unified(Types.fn(argumentResult, resultResult)) ) ); } @Override protected Unification unifyWith(VariableType target, TypeScope scope) { if (contains(target)) { return Unification.circular(target, this); } else { return Types.unifyVariable(this, target, scope); } } @Override protected Unification unify_(Type type, TypeScope scope) { return type.unifyWith(this, scope); } @Override protected Optional<List<Pair<Type, Type>>> zipWith(FunctionType target, TypeScope scope) { return target.argument.zip_(argument, scope).flatMap( argumentList -> target.result.zip_(result, scope).map( resultList -> new ArrayList<Pair<Type, Type>>() {{ addAll(argumentList); addAll(resultList); }})); } @Override protected Optional<List<Pair<Type, Type>>> zip_(Type other, TypeScope scope) { return other.zipWith(this, scope); } }