/* * Copyright 2003-2011 JetBrains s.r.o. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package jetbrains.mps.newTypesystem.state; import gnu.trove.THashMap; import gnu.trove.THashSet; import jetbrains.mps.newTypesystem.TypesUtil; import jetbrains.mps.newTypesystem.operation.block.RemoveBlockOperation; import jetbrains.mps.newTypesystem.relations.AbstractRelation; import jetbrains.mps.newTypesystem.relations.ComparableRelation; import jetbrains.mps.newTypesystem.relations.SubTypingRelation; import jetbrains.mps.newTypesystem.state.blocks.*; import jetbrains.mps.smodel.SNodeId; import org.jetbrains.mps.openapi.model.SNode; import jetbrains.mps.util.containers.ManyToManyMap; import jetbrains.mps.util.Pair; import java.util.*; public class Inequalities { private final State myState; private ManyToManyMap<SNode, SNode> myInputsToOutputsInc = new ManyToManyMap<SNode, SNode>(); private ManyToManyMap<SNode, RelationBlock> myNodesToBlocksInc = new ManyToManyMap<SNode, RelationBlock>(); private Set<SNode> myNodesInc = new THashSet<SNode>(); private Set<SNode> mySolvableLeft = new THashSet<SNode>(); private Set<SNode> mySolvableRight = new THashSet<SNode>(); private boolean mySolveOnlyRight = false; private static final ComparableRelation comparableRelation = new ComparableRelation(); private static final SubTypingRelation subTypingRelation = new SubTypingRelation(); private boolean solvingInProcess = false; public void setSolvingInProcess(boolean solvingInProcess) { this.solvingInProcess = solvingInProcess; } public boolean isSolvingInProcess() { return solvingInProcess; } public Inequalities(State state) { myState = state; } protected State getState() { return myState; } public void printAll() { System.out.println("Relations"); for (Block node : getRelationsToSolve()) { System.out.println(node.getExpandedPresentation(myState)); } } private void printMMMap(ManyToManyMap<SNode, SNode> map) { for (SNode node :map.getFirst()) { System.out.print(node + " <- " ); for (SNode second : map.getByFirst(node)) { System.out.print(" "+ second); } System.out.println(); } } private SNode getNodeWithNoInput(Iterable<SNode> sorted, Set<SNode> used) { for (SNode node : sorted) { if (used.containsAll(myInputsToOutputsInc.getBySecond(node))) { return node; } } //if no absolutely independent nodes - than try more complicated way SNode minNode = null; for (SNode node : sorted) { if (isIndependent(used, node)) { return node; } if (minNode == null) { minNode = node; } } //otherwise choose by name to be deterministic return minNode; } private boolean isIndependent(Set<SNode> used, SNode var) { Queue<SNode> dependsOn = new LinkedList<SNode>(); Set<SNode> passed = new HashSet<SNode>(); dependsOn.addAll(myInputsToOutputsInc.getBySecond(var)); while (!dependsOn.isEmpty()) { SNode node = dependsOn.remove(); if (used.contains(node) || passed.contains(node)) { continue; } passed.add(node); if (!mySolvableLeft.contains(node) && !mySolvableRight.contains(node)) { dependsOn.addAll(myInputsToOutputsInc.getBySecond(node)); continue; } return false; } return true; } public List<RelationBlock> getRelationsToSolve() { List<RelationBlock> result = new LinkedList<RelationBlock>(); for (Block block : myState.getBlocks()) { if (block.getBlockKind() != BlockKind.WHEN_CONCRETE && block.getBlockKind() != BlockKind.TARGET) { RelationBlock relationBlock = (RelationBlock) block; if (!relationBlock.isCheckOnly()) { result.add(relationBlock); } } } return result; } public void solveRelations() { solvingInProcess = true; List<RelationBlock> inequalities = getRelationsToSolve(); initializeMapsInc(inequalities); while (iteration(inequalities)) { inequalities = getRelationsToSolve(); } solvingInProcess = false; } private void addVariablesLinkInc(SNode input, SNode output) { if (!TypesUtil.isVariable(input)) return; if (!TypesUtil.isVariable(output)) return; if (input == output) return; myInputsToOutputsInc.addLink(input, output); } private void initializeMapsInc(List<RelationBlock> inequalities) { myInputsToOutputsInc.clear(); myNodesToBlocksInc.clear(); myNodesInc.clear(); mySolvableLeft.clear(); mySolvableRight.clear(); for (RelationBlock inequality : inequalities) { onInequalityAdded(inequality); } } private void substituteVarInSet(SNode oldVar, SNode newVar, Set<SNode> varSet) { varSet.remove(oldVar); if (TypesUtil.isVariable(newVar)) { myNodesInc.add(newVar); } } public void onEquationAdded(SNode child, SNode parent) { if (!solvingInProcess) return; for (RelationBlock block : new ArrayList<RelationBlock>(myNodesToBlocksInc.getByFirst(child))) { myNodesToBlocksInc.removeLink(child, block); if (TypesUtil.isVariable(parent)) { myNodesToBlocksInc.addLink(parent, block); } } substituteVarInSet(child, parent, myNodesInc); substituteVarInSet(child, parent, mySolvableLeft); substituteVarInSet(child, parent, mySolvableRight); List<SNode> variables = TypesUtil.getVariables(parent, myState); for (SNode outputVar : new ArrayList<SNode>(myInputsToOutputsInc.getByFirst(child))) { for (SNode inputVar : variables) { addVariablesLinkInc(inputVar, outputVar); } myInputsToOutputsInc.removeLink(child, outputVar); } for (SNode inputVar : new ArrayList<SNode>(myInputsToOutputsInc.getBySecond(child))) { for (SNode outputVar : variables) { addVariablesLinkInc(inputVar, outputVar); } myInputsToOutputsInc.removeLink(inputVar, child); } } public void onInequalityAdded(RelationBlock inequality) { if (!solvingInProcess) return; if (inequality.isCheckOnly()) { return; } for (Pair<SNode, SNode> pair : inequality.getInputsAndOutputs()) { SNode input = myState.getRepresentative(pair.o1); SNode output = myState.getRepresentative(pair.o2); if (input == null || output == null) continue; final List<SNode> invars = TypesUtil.getVariables(input, myState); for (SNode inputVar : invars) { if (TypesUtil.isVariable(inputVar)) { myNodesInc.add(inputVar); myNodesToBlocksInc.addLink(inputVar, inequality); } } final List<SNode> outvars = TypesUtil.getVariables(output, myState); for (SNode outputVar : outvars) { if (TypesUtil.isVariable(outputVar)) { myNodesInc.add(outputVar); myNodesToBlocksInc.addLink(outputVar, inequality); } } if (input != output) { for (SNode inputVar : invars) { for (SNode outputVar : outvars) { addVariablesLinkInc(myState.getRepresentative(inputVar), myState.getRepresentative(outputVar)); } } } } SNode left = myState.getRepresentative(inequality.getLeftNode()); SNode right = myState.getRepresentative(inequality.getRightNode()); if (TypesUtil.isVariable(left)) { mySolvableLeft.add(left); } if (TypesUtil.isVariable(right)) { mySolvableRight.add(right); } } private boolean chooseVarAndSolve(Set<SNode> nodes) { //Solves relation for an independent node //first tries to solve for when concrete waiting node if (nodes.isEmpty()) return false; for (Block block : myState.getBlocks(BlockKind.WHEN_CONCRETE)) { SNode node = myState.getRepresentative(((WhenConcreteBlock) block).getArgument()); for (SNode var : TypesUtil.getVariables(node, myState)) { if (nodes.contains(var) && myInputsToOutputsInc.getBySecond(var).isEmpty()) { if (solveRelationsForNode(var)) { return true; } } } } Set<SNode> usedNodes = new HashSet<SNode>(); LinkedList<SNode> tempNodes = new LinkedList<SNode>(nodes); // sort once to avoid n^2 Collections.sort(tempNodes, new Comparator<SNode>() { @Override public int compare(SNode a, SNode b) { return ((SNodeId)a.getNodeId()).compareTo((SNodeId)b.getNodeId()); } }); while (tempNodes.size() > 0) { SNode current = getNodeWithNoInput(tempNodes, usedNodes); if (solveRelationsForNode(current)) { return true; } tempNodes.remove(current); usedNodes.add(current); } return false; } protected boolean iteration(List<RelationBlock> inequalities) { if (myNodesInc.size() == 0) { return false; } mySolveOnlyRight = true; if (chooseVarAndSolve(mySolvableRight)) return true; mySolveOnlyRight = false; if (chooseVarAndSolve(mySolvableLeft)) return true; // recursive relations have to be eliminated *before* we attempt to pick a type for a var // but this slows down inequations elimination substantially if (trySolvingRecursive(inequalities)) return true; if (lastChance(inequalities)) return true; return false; } private boolean isRecursive(RelationBlock inequality) { if (TypesUtil.isVariable(inequality.getLeftNode()) || !TypesUtil.isVariable(inequality.getRightNode())) return false; final SNode rightRep = myState.getRepresentative(inequality.getRightNode()); if (!TypesUtil.isVariable(rightRep)) return false; final List<SNode> leftVars = TypesUtil.getVariables(inequality.getLeftNode(), myState); if (leftVars.isEmpty()) return false; return leftVars.contains(rightRep); } private boolean trySolvingRecursive(List<RelationBlock> inequalities) { for (RelationBlock inequality : inequalities) { if (isRecursive(inequality) && myState.getBlocks().contains(inequality)) { myState.executeOperation(new RemoveBlockOperation(inequality)); return true; } } return false; } private boolean lastChance(List<RelationBlock> inequalities) { for (RelationBlock inequality : inequalities) { if (!(TypesUtil.isVariable(inequality.getLeftNode()) && TypesUtil.isVariable(inequality.getRightNode())) && myState.getBlocks().contains(inequality)) { myState.executeOperation(new RemoveBlockOperation(inequality)); return true; } } return false; } private void collectNodesTransitive(SNode node, Set<SNode> collected, boolean isLeft, Map<SNode, RelationBlock> typesToBlocks, AbstractRelation relation, Set<SNode> alreadyPassed) { // Patching a deficiency of this algorithm: we're listening to equation/inequation adding, but not removing // TODO: update the incremental maps on equation/inequation removal Set<RelationBlock> blocks = new THashSet<RelationBlock>(myNodesToBlocksInc.getByFirst(node)); final Set<Block> stateBlocks = myState.getBlocks(); for(Iterator<RelationBlock> it = blocks.iterator(); it.hasNext();) { final RelationBlock next = it.next(); if(!stateBlocks.contains(next) || isRecursive(next)) { // recursive relations are solved at the end it.remove(); } } alreadyPassed.add(node); blocks = getRelationBlocks(blocks, relation); for (RelationBlock block : blocks) { if (block.isCheckOnly()) { continue; } SNode left = myState.getRepresentative(block.getLeftNode()); SNode right = myState.getRepresentative(block.getRightNode()); if (right == left) { continue; } SNode cur = isLeft ? left : right; SNode other = isLeft ? right : left; if (cur == node) { if (!TypesUtil.isVariable(other)) { SNode type = myState.expand(other); collected.add(type); typesToBlocks.put(type, block); } else { if (!alreadyPassed.contains(other)){ collectNodesTransitive(other, collected, isLeft, typesToBlocks, relation, alreadyPassed); } } } } } private boolean solveRelationsForNode(SNode node) { if (solveRelationForNode(node, subTypingRelation)) { return true; } if (!mySolveOnlyRight) { return solveRelationForNode(node, comparableRelation); } return false; } private Set<RelationBlock> getRelationBlocks(Set<RelationBlock> blocks, AbstractRelation relation) { Set<RelationBlock> result = new THashSet<RelationBlock>(); for (RelationBlock block : blocks) { if (relation.accept(block.getRelationKind())) { result.add(block); } } return result; } private boolean solveRelationForNode(SNode node, AbstractRelation relation) { Map<SNode, RelationBlock> typesToBlocks = new THashMap<SNode, RelationBlock>(); assert TypesUtil.isVariable(node); Set<SNode> rightTypes = new LinkedHashSet<SNode>(); Set<SNode> leftTypes = new LinkedHashSet<SNode>(); collectNodesTransitive(node, leftTypes, false, typesToBlocks, relation, new HashSet<SNode>()); if (!mySolveOnlyRight) { collectNodesTransitive(node, rightTypes, true, typesToBlocks, relation, new HashSet<SNode>()); } return relation.solve(node, leftTypes, rightTypes, myState, typesToBlocks); } public Map<Set<SNode>, Set<InequalityBlock>> getInequalityGroups(Set<Block> inequalities) { Map<SNode, Set<SNode>> components = new HashMap<SNode, Set<SNode>>(1); Map<Set<SNode>, Set<InequalityBlock>> groupsToInequalities = new HashMap<Set<SNode>, Set<InequalityBlock>>(); Set<SNode> emptySet = new HashSet<SNode>(1); for (Block block : inequalities) { InequalityBlock inequality = (InequalityBlock) block; List<SNode> variables = TypesUtil.getVariables(inequality.getRightNode(), myState); variables.addAll(TypesUtil.getVariables(inequality.getLeftNode(), myState)); if (variables.size() == 0) { Set<InequalityBlock> emptyBlocks = groupsToInequalities.get(emptySet); if (emptyBlocks == null) { emptyBlocks = new HashSet<InequalityBlock>(1); groupsToInequalities.put(emptySet, emptyBlocks); } emptyBlocks.add(inequality); continue; } Set<SNode> currentResult = new HashSet<SNode>(); Set<InequalityBlock> currentInequalities = new HashSet<InequalityBlock>(); currentInequalities.add(inequality); for (SNode var : variables) { var = myState.getRepresentative(var); currentResult.add(var); Set<SNode> varGroup = components.remove(var); if (varGroup != null) { currentResult.addAll(varGroup); for (SNode var2 : varGroup) { if (!variables.contains(var2)) { components.put(var2, currentResult); } } } components.put(var, currentResult); Set<InequalityBlock> remove = groupsToInequalities.remove(varGroup); if (remove != null) { currentInequalities.addAll(remove); } } groupsToInequalities.put(currentResult, currentInequalities); } return groupsToInequalities; } public void clear() { } }