package scotch.compiler.syntax.type;
import static me.qmx.jitescript.util.CodegenUtils.p;
import static me.qmx.jitescript.util.CodegenUtils.sig;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSortedSet;
import lombok.EqualsAndHashCode;
import me.qmx.jitescript.CodeBlock;
import scotch.symbol.Symbol;
import scotch.symbol.type.FunctionTypeDescriptor;
import scotch.symbol.type.TypeDescriptor;
import scotch.symbol.type.TypeDescriptors;
import scotch.util.Pair;
/**
 * A function type {@code argument -> result}.
 *
 * <p>Instances are immutable: every transforming operation returns a new {@code FunctionType}
 * rather than modifying this one.
 */
@EqualsAndHashCode(callSuper = false)
public class FunctionType extends Type {

    private final Type argument;
    private final Type result;

    FunctionType(Type argument, Type result) {
        this.argument = argument;
        this.result = result;
    }

    @Override
    public <T> T accept(Visitor<T> visitor) {
        return visitor.visit(this);
    }

    /**
     * Returns a new function type whose argument and result have each been flattened.
     */
    @Override
    public Type flatten() {
        return new FunctionType(argument.flatten(), result.flatten());
    }

    /**
     * Emits bytecode that pushes the argument and result descriptors and then combines them via
     * {@code TypeDescriptors.fn(TypeDescriptor, TypeDescriptor)}.
     */
    @Override
    public CodeBlock generateBytecode() {
        // Double-brace initialization is jitescript's conventional CodeBlock builder style.
        return new CodeBlock() {{
            append(argument.generateBytecode());
            append(result.generateBytecode());
            invokestatic(p(TypeDescriptors.class), "fn", sig(FunctionTypeDescriptor.class, TypeDescriptor.class, TypeDescriptor.class));
        }};
    }

    /**
     * @return the type on the left of the arrow
     */
    public Type getArgument() {
        return argument;
    }

    /**
     * Collects contexts by matching this type's argument and result against those of {@code type}.
     *
     * @param type  the type to match against; only a {@code FunctionType} contributes entries
     * @param scope the scope used to resolve contexts
     * @return a new mutable map; empty when {@code type} is not a function type. On a key
     *     collision, the entry from the result side overwrites the one from the argument side.
     */
    @Override
    public Map<String, Type> getContexts(Type type, TypeScope scope) {
        Map<String, Type> map = new HashMap<>();
        if (type instanceof FunctionType) {
            FunctionType other = (FunctionType) type;
            map.putAll(argument.getContexts(other.getArgument(), scope));
            map.putAll(result.getContexts(other.getResult(), scope));
        }
        return map;
    }

    /**
     * @return the argument's instance map followed by the result's, as a new mutable list
     */
    @Override
    public List<Pair<VariableType, Symbol>> getInstanceMap() {
        List<Pair<VariableType, Symbol>> instances = new ArrayList<>();
        instances.addAll(argument.getInstanceMap());
        instances.addAll(result.getInstanceMap());
        return instances;
    }

    /**
     * @return the type on the right of the arrow
     */
    public Type getResult() {
        return result;
    }

    /**
     * Applies {@code mapper} to the variables of both the argument and the result.
     */
    @Override
    public Type mapVariables(Function<VariableType, Type> mapper) {
        return new FunctionType(argument.mapVariables(mapper), result.mapVariables(mapper));
    }

    /**
     * Qualifies the names in both the argument and the result.
     */
    @Override
    public Type qualifyNames(TypeQualifier qualifier) {
        // Build the qualified copy directly instead of via two intermediate with* copies.
        return new FunctionType(argument.qualifyNames(qualifier), result.qualifyNames(qualifier));
    }

    /**
     * @return {@code TypeDescriptors.fn} applied to the argument and result descriptors
     */
    @Override
    public TypeDescriptor toDescriptor() {
        return TypeDescriptors.fn(argument.toDescriptor(), result.toDescriptor());
    }

    /**
     * @param newArgument the replacement argument type
     * @return a copy of this function type with its argument replaced
     */
    public FunctionType withArgument(Type newArgument) {
        return new FunctionType(newArgument, result);
    }

    /**
     * @param newResult the replacement result type
     * @return a copy of this function type with its result replaced
     */
    public FunctionType withResult(Type newResult) {
        return new FunctionType(argument, newResult);
    }

    /**
     * @return whether {@code type} occurs in either the argument or the result
     */
    @Override
    protected boolean contains(VariableType type) {
        return argument.contains(type) || result.contains(type);
    }

    @Override
    protected List<Type> flatten_() {
        return ImmutableList.of(flatten());
    }

    /**
     * Gathers the argument's and result's context pairs, de-duplicates them, and returns them as
     * an immutable set ordered by {@code Types::sort}.
     */
    @Override
    protected Set<Pair<VariableType, Symbol>> gatherContext_() {
        Set<Pair<VariableType, Symbol>> context = new HashSet<>();
        context.addAll(argument.gatherContext_());
        context.addAll(result.gatherContext_());
        return ImmutableSortedSet.copyOf(Types::sort, context);
    }

    /**
     * Generates the argument and result within {@code scope}, then flattens the resulting function
     * type.
     */
    @Override
    protected Type generate(TypeScope scope, Set<Type> visited) {
        // NOTE(review): `visited` is not forwarded to the children; each child call goes through the
        // single-argument generate(scope). Confirm this is intentional and that no cycle-detection
        // state is expected to flow through function types.
        return new FunctionType(argument.generate(scope), result.generate(scope)).flatten();
    }

    /**
     * Copies the argument and the result, sharing {@code mappings} so that both sides map the same
     * original type to the same copy.
     */
    @Override
    protected Type genericCopy(TypeScope scope, Map<Type, Type> mappings) {
        return new FunctionType(
            argument.genericCopy(scope, mappings),
            result.genericCopy(scope, mappings)
        );
    }

    /**
     * @return this type rendered as {@code (arg -> result)}, with the argument itself
     *     parenthesized as needed
     */
    @Override
    protected String toParenthesizedString() {
        return "(" + argument.toParenthesizedString() + " -> " + result.toString_() + ")";
    }

    /**
     * @return this type rendered as {@code arg -> result}. The argument is parenthesized so that a
     *     function-typed argument is visually distinguished.
     */
    @Override
    protected String toString_() {
        return argument.toParenthesizedString() + " -> " + result.toString_();
    }

    @Override
    protected Unification unifyWith(ConstructorType target, TypeScope scope) {
        return Unification.mismatch(target, this);
    }

    @Override
    protected Unification unifyWith(SumType target, TypeScope scope) {
        return Unification.mismatch(target, this);
    }

    /**
     * Unifies argument against argument and result against result, chaining the two steps through
     * {@code Unification.map}, and combines the unified parts into a new function type.
     */
    @Override
    protected Unification unifyWith(FunctionType target, TypeScope scope) {
        return target.argument.unify_(argument, scope).map(
            argumentResult -> target.result.unify_(result, scope).map(
                resultResult -> Unification.unified(Types.fn(argumentResult, resultResult))
            )
        );
    }

    /**
     * Binds {@code target} to this type, unless {@code target} occurs inside this type. That occurs
     * check reports a circular unification instead of building an infinite type.
     */
    @Override
    protected Unification unifyWith(VariableType target, TypeScope scope) {
        if (contains(target)) {
            return Unification.circular(target, this);
        } else {
            return Types.unifyVariable(this, target, scope);
        }
    }

    /**
     * Double-dispatches to {@code type.unifyWith(this, scope)}.
     */
    @Override
    protected Unification unify_(Type type, TypeScope scope) {
        return type.unifyWith(this, scope);
    }

    /**
     * Zips the arguments and then the results of two function types.
     *
     * @return the argument pairs followed by the result pairs, or empty if either side fails to zip
     */
    @Override
    protected Optional<List<Pair<Type, Type>>> zipWith(FunctionType target, TypeScope scope) {
        return target.argument.zip_(argument, scope).flatMap(
            argumentList -> target.result.zip_(result, scope).map(
                resultList -> concat(argumentList, resultList)));
    }

    /**
     * Double-dispatches to {@code other.zipWith(this, scope)}.
     */
    @Override
    protected Optional<List<Pair<Type, Type>>> zip_(Type other, TypeScope scope) {
        return other.zipWith(this, scope);
    }

    /**
     * Concatenates two pair lists into a new mutable list.
     *
     * <p>This replaces a double-brace anonymous {@code ArrayList} subclass, which needlessly created
     * an extra class and held a reference to the enclosing instance.
     */
    private static List<Pair<Type, Type>> concat(List<Pair<Type, Type>> first, List<Pair<Type, Type>> second) {
        List<Pair<Type, Type>> combined = new ArrayList<>(first.size() + second.size());
        combined.addAll(first);
        combined.addAll(second);
        return combined;
    }
}