package scotch.compiler.syntax.definition;
import static java.util.Spliterators.spliterator;
import static java.util.stream.Collectors.toList;
import static lombok.AccessLevel.PRIVATE;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import com.google.common.collect.ImmutableList;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import scotch.compiler.error.SyntaxError;
import scotch.compiler.syntax.util.DefaultSymbolGenerator;
import scotch.compiler.syntax.util.SymbolGenerator;
import scotch.compiler.syntax.type.Type;
import scotch.compiler.syntax.reference.DefinitionReference;
import scotch.compiler.syntax.reference.ValueReference;
import scotch.compiler.syntax.scope.Scope;
/**
 * A collection of {@link DefinitionEntry} objects keyed by their {@link DefinitionReference},
 * bundled with the syntax errors gathered so far and the {@link SymbolGenerator} in use.
 *
 * <p>Entries are held in a {@link LinkedHashMap}, so iteration follows the order in which the
 * entries were supplied. Instances are created through {@link #createGraph(Collection)} and the
 * returned {@link DefinitionGraphBuilder}.
 */
public class DefinitionGraph {

    /**
     * Starts building a graph from the given entries.
     *
     * @param entries the entries the graph will contain
     * @return a builder with no errors and no symbol generator configured yet
     */
    public static DefinitionGraphBuilder createGraph(Collection<DefinitionEntry> entries) {
        return new DefinitionGraphBuilder(entries);
    }

    /**
     * Wraps a dependency cycle in a reportable syntax error.
     *
     * @param cycle the cycle to report
     * @return a {@link CyclicDependencyError} delegating its output to {@code cycle}
     */
    public static SyntaxError cyclicDependency(DependencyCycle cycle) {
        return new CyclicDependencyError(cycle);
    }

    /** Collects every given node into a single {@link DependencyCycle}. */
    private static DependencyCycle fromNodes(Collection<DefinitionNode> nodes) {
        DependencyCycle.Builder builder = DependencyCycle.builder();
        nodes.forEach(builder::addNode);
        return builder.build();
    }

    /** Entries by reference, in the order they were supplied to the constructor. */
    public final Map<DefinitionReference, DefinitionEntry> definitions;
    private final SymbolGenerator symbolGenerator;
    /** Immutable snapshot of the errors passed to the constructor. */
    private final List<SyntaxError> errors;

    private DefinitionGraph(Collection<DefinitionEntry> entries, SymbolGenerator symbolGenerator, List<SyntaxError> errors) {
        this.symbolGenerator = symbolGenerator;
        this.errors = ImmutableList.copyOf(errors);
        this.definitions = new LinkedHashMap<>();
        // An entry whose reference is already present replaces the earlier entry.
        entries.forEach(entry -> definitions.put(entry.getReference(), entry));
    }

    /**
     * Starts a builder for a new graph holding {@code entries} while carrying over this
     * graph's errors and symbol generator.
     *
     * @param entries the entries of the new graph
     * @return a builder pre-populated with this graph's errors and symbol generator
     */
    public DefinitionGraphBuilder copyWith(Collection<DefinitionEntry> entries) {
        return createGraph(entries)
            .withErrors(errors)
            .withSequence(symbolGenerator);
    }

    /**
     * Looks up the definition registered under a value reference.
     *
     * @param reference the value reference to look up
     * @return the definition cast to {@link ValueDefinition}, or empty if none is registered
     * @throws ClassCastException if the registered definition is not a {@link ValueDefinition}
     */
    public Optional<ValueDefinition> getDefinition(ValueReference reference) {
        return getDefinition((DefinitionReference) reference).map(definition -> (ValueDefinition) definition);
    }

    /**
     * Looks up the definition registered under a reference.
     *
     * @param reference the reference to look up
     * @return the entry's definition, or empty if there is no entry for {@code reference}
     */
    public Optional<Definition> getDefinition(DefinitionReference reference) {
        return Optional.ofNullable(definitions.get(reference)).map(DefinitionEntry::getDefinition);
    }

    /** Returns the immutable list of errors attached to this graph. */
    public List<SyntaxError> getErrors() {
        return errors;
    }

    /**
     * Returns the scope of the entry registered under a reference.
     *
     * @param reference the reference to look up
     * @return the entry's scope
     * @throws IllegalArgumentException if {@link #tryGetScope(DefinitionReference)} finds nothing
     */
    public Scope getScope(DefinitionReference reference) {
        return tryGetScope(reference).orElseThrow(() -> new IllegalArgumentException("No scope found for reference: " + reference));
    }

    /**
     * Lists all references with every non-value reference ahead of every
     * {@link ValueReference}; within each group the map's iteration order is kept.
     *
     * @return a new list of all references in this graph
     */
    public List<DefinitionReference> getSortedReferences() {
        List<DefinitionReference> references = definitions.keySet().stream()
            .filter(reference -> !(reference instanceof ValueReference))
            .collect(toList());
        List<DefinitionReference> values = definitions.keySet().stream()
            .filter(reference -> reference instanceof ValueReference)
            .collect(toList());
        references.addAll(values);
        return references;
    }

    /**
     * Resolves the type of the definition registered under a reference: the value's type when
     * {@code asValue()} yields one, otherwise the signature's type when {@code asSignature()}
     * yields one.
     *
     * <p>NOTE(review): {@code asValue()} and {@code asSignature()} are declared outside this
     * file; their fallback callbacks receive one argument, which is ignored here.
     *
     * @param reference the reference to look up
     * @return the resolved type, or empty if no definition is registered under {@code reference}
     * @throws IllegalArgumentException if the definition is neither a value nor a signature
     */
    public Optional<Type> getValue(DefinitionReference reference) {
        return getDefinition(reference).map(definition -> definition.asValue()
            .map(ValueDefinition::getType)
            .orElseGet(def1 -> definition.asSignature()
                .map(ValueSignature::getType)
                .orElseThrow(def2 -> new IllegalArgumentException("Can't get type of " + definition.getClass().getSimpleName()))));
    }

    /** Returns every key that is a {@link ValueReference}, in the map's iteration order. */
    public List<ValueReference> getValues() {
        return definitions.keySet().stream()
            .filter(reference -> reference instanceof ValueReference)
            .map(reference -> (ValueReference) reference)
            .collect(toList());
    }

    /** Returns {@code true} when at least one error is attached to this graph. */
    public boolean hasErrors() {
        return !errors.isEmpty();
    }

    /**
     * Builds a new graph whose entries are reordered: non-value entries first, then value
     * entries in the order produced by the dependency sort. Errors found while sorting are
     * appended to this graph's existing errors.
     *
     * @return the reordered graph
     */
    public DefinitionGraph sort() {
        // Kept separate from the field of the same meaning; sort_ fills this list before
        // appendErrors reads it because the argument is evaluated first.
        List<SyntaxError> sortErrors = new ArrayList<>();
        return copyWith(sort_(sortErrors))
            .appendErrors(sortErrors)
            .build();
    }

    /**
     * Streams the entries sequentially in the map's iteration order.
     *
     * @return a sequential stream over all entries
     */
    public Stream<DefinitionEntry> stream() {
        // Use the collection's own stream. Spliterators.spliterator(Collection, int) expects a
        // characteristics bit mask as its second argument, so passing the size there would
        // advertise arbitrary flags (e.g. SORTED or DISTINCT) depending on the entry count.
        return definitions.values().stream();
    }

    /**
     * Looks up the scope of the entry registered under a reference.
     *
     * @param reference the reference to look up
     * @return the entry's scope, or empty if there is no entry for {@code reference}
     */
    public Optional<Scope> tryGetScope(DefinitionReference reference) {
        return Optional.ofNullable(definitions.get(reference)).map(DefinitionEntry::getScope);
    }

    /**
     * Orders value nodes so that a node's entry is emitted only once the node reports no
     * remaining dependencies. Nodes that start without dependencies are emitted first; each
     * emitted node is then removed as a dependency from the nodes still waiting.
     *
     * <p>If nodes remain once no dependency-free node is left, a single cyclic-dependency
     * error covering all of them is added to {@code errors} and their entries are appended to
     * the result in their remaining order. The nodes in {@code input} are mutated through
     * {@code removeDependency}.
     *
     * @param errors receives a cyclic-dependency error when unresolved nodes remain
     * @param input the value nodes to order
     * @return the entries of all input nodes in emission order
     */
    private List<DefinitionEntry> sortValues(List<SyntaxError> errors, List<DefinitionNode> input) {
        List<DefinitionNode> roots = new ArrayList<>();
        List<DefinitionNode> nodes = new ArrayList<>();
        List<DefinitionEntry> output = new ArrayList<>();
        input.forEach(node -> {
            if (node.hasDependencies()) {
                nodes.add(node);
            } else {
                roots.add(node);
            }
        });
        while (!roots.isEmpty()) {
            DefinitionNode root = roots.remove(0);
            output.add(root.getEntry());
            Iterator<DefinitionNode> iterator = nodes.iterator();
            while (iterator.hasNext()) {
                DefinitionNode node = iterator.next();
                if (node.isDependentOn(root)) {
                    node.removeDependency(root);
                    if (!node.hasDependencies()) {
                        // Promote the node; it will be emitted on a later pass of the outer loop.
                        roots.add(node);
                        iterator.remove();
                    }
                }
            }
        }
        if (!nodes.isEmpty()) {
            // Whatever is left could not be freed of its dependencies.
            errors.add(cyclicDependency(fromNodes(nodes)));
            nodes.forEach(node -> output.add(node.getEntry()));
        }
        return output;
    }

    /**
     * Produces the entry order used by {@link #sort()}: all non-value entries in their current
     * order, followed by the value entries as ordered by
     * {@link #sortValues(List, List)}.
     *
     * @param errors receives any errors found while ordering the value entries
     * @return all entries in their new order
     */
    private List<DefinitionEntry> sort_(List<SyntaxError> errors) {
        List<DefinitionEntry> entries = stream()
            .filter(entry -> !(entry.getReference() instanceof ValueReference))
            .collect(toList());
        List<DefinitionNode> values = stream()
            .filter(entry -> entry.getReference() instanceof ValueReference)
            .map(DefinitionNode::new)
            .collect(toList());
        entries.addAll(sortValues(errors, values));
        return entries;
    }

    /** Syntax error describing a dependency cycle; all output is delegated to the cycle. */
    @AllArgsConstructor(access = PRIVATE)
    @EqualsAndHashCode(callSuper = false)
    @ToString
    public static class CyclicDependencyError extends SyntaxError {

        private final DependencyCycle cycle;

        @Override
        public String prettyPrint() {
            return cycle.prettyPrint();
        }

        @Override
        public String report(String indent, int indentLevel) {
            return cycle.report(indent, indentLevel);
        }
    }

    /**
     * Mutable builder for {@link DefinitionGraph}. Errors default to an empty list and the
     * symbol generator defaults to a new {@link DefaultSymbolGenerator} when not supplied.
     */
    public static class DefinitionGraphBuilder {

        private final Collection<DefinitionEntry> definitions;
        private Optional<SymbolGenerator> optionalSequence;
        private Optional<List<SyntaxError>> optionalErrors;

        private DefinitionGraphBuilder(Collection<DefinitionEntry> definitions) {
            this.definitions = definitions;
            this.optionalSequence = Optional.empty();
            this.optionalErrors = Optional.empty();
        }

        /**
         * Adds errors after any that are already configured on this builder.
         *
         * @param errors the errors to append
         * @return this builder
         */
        public DefinitionGraphBuilder appendErrors(List<SyntaxError> errors) {
            ImmutableList.Builder<SyntaxError> combined = ImmutableList.builder();
            optionalErrors.ifPresent(existing -> combined.addAll(existing));
            combined.addAll(errors);
            optionalErrors = Optional.of(combined.build());
            return this;
        }

        /**
         * Creates the graph from the configured entries, symbol generator and errors.
         *
         * @return the new graph
         */
        public DefinitionGraph build() {
            return new DefinitionGraph(definitions, optionalSequence.orElseGet(DefaultSymbolGenerator::new), optionalErrors.orElse(ImmutableList.of()));
        }

        /**
         * Replaces any configured errors with the given list.
         *
         * @param errors the errors the graph should carry
         * @return this builder
         */
        public DefinitionGraphBuilder withErrors(List<SyntaxError> errors) {
            optionalErrors = Optional.of(errors);
            return this;
        }

        /**
         * Sets the symbol generator the graph should carry.
         *
         * @param symbolGenerator the generator to use
         * @return this builder
         */
        public DefinitionGraphBuilder withSequence(SymbolGenerator symbolGenerator) {
            optionalSequence = Optional.of(symbolGenerator);
            return this;
        }
    }
}