package scotch.compiler.syntax.pattern; import static java.util.Collections.reverse; import static java.util.stream.Collectors.toList; import static scotch.compiler.syntax.value.Values.apply; import static scotch.compiler.syntax.value.Values.conditional; import static scotch.compiler.syntax.value.Values.fn; import static scotch.compiler.syntax.value.Values.id; import static scotch.compiler.syntax.value.Values.raise; import static scotch.compiler.syntax.value.Values.scope; import static scotch.symbol.Symbol.symbol; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Deque; import java.util.List; import scotch.compiler.syntax.value.FunctionValue; import scotch.compiler.syntax.value.IsConstructor; import scotch.compiler.syntax.value.PatternMatcher; import scotch.compiler.syntax.value.Value; import scotch.compiler.text.SourceLocation; import scotch.compiler.syntax.type.VariableType; import scotch.compiler.syntax.util.SymbolGenerator; public class DefaultPatternReducer implements PatternReducer { private final SymbolGenerator generator; private final Deque<PatternState> patterns; public DefaultPatternReducer(SymbolGenerator generator) { this.generator = generator; patterns = new ArrayDeque<>(); } @Override public void addAssignment(CaptureMatch capture) { pattern().addAssignment(capture); } @Override public void addCondition(Value argument, Value value) { pattern().addCondition(argument, value); } @Override public void addCondition(IsConstructor constructor) { pattern().addCondition(constructor); } @Override public void addTaggedArgument(Value taggedArgument) { pattern().addTaggedArgument(taggedArgument); } @Override public void beginPattern(PatternMatcher matcher) { patterns.push(new PatternState(matcher)); } @Override public void beginPatternCase(PatternCase patternCase) { pattern().beginPatternCase(patternCase); } @Override public void endPattern() { patterns.pop(); } @Override public void endPatternCase() { pattern().endPatternCase(); } @Override public Value getTaggedArgument(Value argument) { return pattern().getTaggedArgument(argument); } @Override public void markFunction(FunctionValue function) { // intentionally empty } @Override public Value reducePattern() { return pattern().reducePattern(); } private PatternState pattern() { return patterns.peek(); } private VariableType reserveType() { return generator.reserveType(); } private final class CaseState implements ArgumentMap { private final PatternCase patternCase; private final List<Value> conditions; private final List<CaptureMatch> assignments; private final List<Value> taggedArguments; public CaseState(PatternCase patternCase) { this.patternCase = patternCase; this.conditions = new ArrayList<>(); this.assignments = new ArrayList<>(); this.taggedArguments = new ArrayList<>(); } public void addAssignment(CaptureMatch capture) { assignments.add(capture); } public void addCondition(Value argument, Value value) { conditions.add(apply( apply( id(value.getSourceLocation(), symbol("scotch.data.eq.(==)"), generator.reserveType()), argument, generator.reserveType() ), value, generator.reserveType() )); } public void addCondition(IsConstructor constructor) { conditions.add(constructor); } public void addTaggedArgument(Value taggedArgument) { taggedArguments.add(taggedArgument); } public SourceLocation getSourceLocation() { return SourceLocation.extent(new ArrayList<SourceLocation>() {{ addAll(conditions.stream().map(Value::getSourceLocation).collect(toList())); addAll(assignments.stream().map(CaptureMatch::getSourceLocation).collect(toList())); add(patternCase.getBody().getSourceLocation()); }}); } @Override public Value getTaggedArgument(Value argument) { return argument.mapTags(value -> { for (Value taggedArgument : taggedArguments) { if (value.equalsBeta(taggedArgument)) { return value.reTag(taggedArgument); } } return value; }); } public boolean isDefaultCase() { return conditions.isEmpty(); } public Value reducePattern() { return reduceBody(); } public Value reducePattern(Value result) { if (conditions.isEmpty()) { return reduceBody(); } else { Value resultCondition = conditions.get(0); for (Value condition : conditions.subList(1, conditions.size())) { resultCondition = apply( apply(id(condition.getSourceLocation(), symbol("scotch.data.bool.(&&)"), reserveType()), resultCondition, reserveType()), condition, reserveType() ); } return conditional( SourceLocation.extent(conditions.stream().map(Value::getSourceLocation).collect(toList())), resultCondition, reduceBody(), result, reserveType() ); } } private Value reduceBody() { Value result = patternCase.getBody(); List<CaptureMatch> reverseAssignments = new ArrayList<>(assignments); reverse(reverseAssignments); for (CaptureMatch match : reverseAssignments) { result = match.reducePattern(this, generator, result); } return scope(patternCase.getSourceLocation(), patternCase.getSymbol(), result); } } private final class PatternState { private final PatternMatcher matcher; private final List<CaseState> cases; private CaseState currentCase; public PatternState(PatternMatcher matcher) { this.matcher = matcher; this.cases = new ArrayList<>(); } public void addAssignment(CaptureMatch capture) { currentCase.addAssignment(capture); } public void addCondition(Value argument, Value value) { currentCase.addCondition(argument, value); } public void addCondition(IsConstructor constructor) { currentCase.addCondition(constructor); } public void addTaggedArgument(Value taggedArgument) { currentCase.addTaggedArgument(taggedArgument); } public void beginPatternCase(PatternCase patternCase) { currentCase = new CaseState(patternCase); } public void endPatternCase() { cases.add(currentCase); currentCase = null; } public Value getTaggedArgument(Value argument) { return currentCase.getTaggedArgument(argument); } public Value reducePattern() { Value result = calculateDefaultCase(); List<CaseState> reverseCases = new ArrayList<>(cases); int count = 0; reverse(reverseCases); for (CaseState patternCase : reverseCases) { if (count++ > 0 && patternCase.isDefaultCase()) { throw new PatternReductionException("Non-terminal default pattern case", patternCase.getSourceLocation()); // TODO message } else { result = patternCase.reducePattern(result); } } return fn( matcher.getSourceLocation(), matcher.getSymbol(), matcher.getArguments(), result ); } private Value calculateDefaultCase() { CaseState lastCase = cases.get(cases.size() - 1); if (lastCase.isDefaultCase()) { return lastCase.reducePattern(); } else { return raise(lastCase.getSourceLocation().getEndPoint(), "Incomplete match", generator.reserveType()); } } } }