package scotch.compiler.syntax.value; import static java.util.Collections.reverse; import static java.util.stream.Collectors.toList; import static scotch.compiler.syntax.builder.BuilderUtil.require; import static scotch.compiler.syntax.definition.Definitions.scopeDef; import static scotch.compiler.syntax.reference.DefinitionReference.scopeRef; import static scotch.compiler.syntax.value.Values.matcher; import static scotch.compiler.syntax.type.Types.fn; import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.function.Function; import java.util.function.Supplier; import com.google.common.collect.ImmutableList; import lombok.EqualsAndHashCode; import lombok.ToString; import scotch.compiler.analyzer.DependencyAccumulator; import scotch.compiler.analyzer.NameAccumulator; import scotch.compiler.analyzer.OperatorAccumulator; import scotch.compiler.analyzer.PrecedenceParser; import scotch.compiler.analyzer.PrecedenceParser.ArityMismatch; import scotch.compiler.analyzer.NameQualifier; import scotch.compiler.analyzer.TypeChecker; import scotch.compiler.intermediate.IntermediateGenerator; import scotch.compiler.intermediate.IntermediateValue; import scotch.compiler.syntax.Scoped; import scotch.compiler.syntax.builder.SyntaxBuilder; import scotch.compiler.syntax.definition.Definition; import scotch.compiler.syntax.pattern.PatternCase; import scotch.compiler.syntax.pattern.PatternReducer; import scotch.compiler.syntax.reference.DefinitionReference; import scotch.compiler.text.SourceLocation; import scotch.symbol.Symbol; import scotch.compiler.syntax.type.Type; @EqualsAndHashCode(callSuper = false) @ToString(exclude = "sourceLocation", doNotUseGetters = true) public class PatternMatcher extends Value implements Scoped { public static Builder builder() { return new Builder(); } private final SourceLocation sourceLocation; private final Symbol symbol; private final List<Argument> arguments; private final List<PatternCase> patternCases; private final Type type; PatternMatcher(SourceLocation sourceLocation, Symbol symbol, List<Argument> arguments, List<PatternCase> patternCases, Type type) { this.sourceLocation = sourceLocation; this.symbol = symbol; this.arguments = ImmutableList.copyOf(arguments); this.patternCases = ImmutableList.copyOf(patternCases); this.type = type; } @Override public Value accumulateDependencies(DependencyAccumulator state) { throw new UnsupportedOperationException(); } @Override public Value accumulateNames(NameAccumulator state) { return state.scoped(this, () -> map( argument -> argument.accumulateNames(state), patternCase -> patternCase.accumulateNames(state) )); } @Override public IntermediateValue generateIntermediateCode(IntermediateGenerator state) { throw new UnsupportedOperationException(); } @Override public Value bindMethods(TypeChecker typeChecker) { throw new UnsupportedOperationException(); } @Override public Value bindTypes(TypeChecker typeChecker) { throw new UnsupportedOperationException(); } @Override public Value checkTypes(TypeChecker typeChecker) { throw new UnsupportedOperationException(); } @Override public Value defineOperators(OperatorAccumulator state) { return state.scoped(this, () -> map( argument -> argument, patternCase -> patternCase.defineOperators(state) )); } public List<Argument> getArguments() { return arguments; } @Override public Definition getDefinition() { return scopeDef(sourceLocation, symbol); } @Override public DefinitionReference getReference() { return scopeRef(symbol); } @Override public SourceLocation getSourceLocation() { return sourceLocation; } public Symbol getSymbol() { return symbol; } @Override public Type getType() { return type; } @Override public Value parsePrecedence(PrecedenceParser state) { Symbol s = symbol.getMemberNames().size() == 1 ? state.reserveSymbol() : symbol; patternCases.stream() .filter(pattern -> pattern.getArity() != arguments.size()) .map(pattern -> new ArityMismatch(s, arguments.size(), pattern.getArity(), pattern.getSourceLocation())) .forEach(state::error); return state.named(s, () -> state.scoped(this, () -> withSymbol(s).withPatternCases(patternCases.stream() .map(matcher -> matcher.parsePrecedence(state)) .collect(toList())))); } @Override public Value qualifyNames(NameQualifier state) { return state.named(symbol, () -> state.scoped(this, () -> map( argument -> argument.qualifyNames(state), patternCase -> patternCase.qualifyNames(state) ))); } @Override public Value reducePatterns(PatternReducer reducer) { reducer.beginPattern(this); try { patternCases.forEach(patternCase -> patternCase.reducePatterns(reducer)); return reducer.reducePattern(); } finally { reducer.endPattern(); } } @Override public Value unwrap() { return withPatternCases( patternCases.stream() .map(matcher -> matcher.withBody(matcher.getBody().unwrap())) .collect(toList()) ); } @Override public WithArguments withArguments() { return WithArguments.withArguments(this); } public PatternMatcher withArguments(List<Argument> arguments) { return new PatternMatcher(sourceLocation, symbol, arguments, patternCases, type); } public PatternMatcher withPatternCases(List<PatternCase> patternCases) { return new PatternMatcher(sourceLocation, symbol, arguments, patternCases, type); } public PatternMatcher withSourceLocation(SourceLocation sourceLocation) { return new PatternMatcher(sourceLocation, symbol, arguments, patternCases, type); } public PatternMatcher withSymbol(Symbol symbol) { return new PatternMatcher(sourceLocation, symbol, arguments, patternCases, type); } @Override public PatternMatcher withType(Type type) { return new PatternMatcher(sourceLocation, symbol, arguments, patternCases, type); } private Type calculateType(Type returnType) { List<Argument> args = new ArrayList<>(arguments); reverse(args); return args.stream() .map(Argument::getType) .reduce(returnType, (result, arg) -> fn(arg, result)); } private PatternMatcher encloseArguments(TypeChecker state, Supplier<PatternMatcher> supplier) { return state.enclose(this, () -> { arguments.stream() .map(Argument::getType) .forEach(state::specialize); arguments.stream() .map(Argument::getSymbol) .forEach(state::addLocal); try { return supplier.get(); } finally { arguments.stream() .map(Argument::getType) .forEach(state::generalize); } }); } private PatternMatcher map(Function<Argument, Argument> argumentMapper, Function<PatternCase, PatternCase> patternCaseMapper) { return new PatternMatcher( sourceLocation, symbol, arguments.stream().map(argumentMapper).collect(toList()), patternCases.stream().map(patternCaseMapper).collect(toList()), type ); } public static class Builder implements SyntaxBuilder<PatternMatcher> { private Optional<SourceLocation> sourceLocation = Optional.empty(); private Optional<List<Argument>> arguments = Optional.empty(); private Optional<List<PatternCase>> patternCases = Optional.empty(); private Optional<Type> type = Optional.empty(); private Optional<Symbol> symbol = Optional.empty(); private Builder() { // intentionally empty } @Override public PatternMatcher build() { return matcher( require(sourceLocation, "Source location"), require(symbol, "Symbol"), require(type, "Pattern type"), require(arguments, "Arguments"), require(patternCases, "Pattern cases") ); } public Builder withArguments(List<Argument> arguments) { this.arguments = Optional.of(arguments); return this; } public Builder withPatterns(List<PatternCase> patterns) { this.patternCases = Optional.of(patterns); return this; } @Override public Builder withSourceLocation(SourceLocation sourceLocation) { this.sourceLocation = Optional.of(sourceLocation); return this; } public Builder withSymbol(Symbol symbol) { this.symbol = Optional.of(symbol); return this; } public Builder withType(Type type) { this.type = Optional.of(type); return this; } } }