package de.psi.alloy4smt.ast; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.Vector; import edu.mit.csail.sdg.alloy4.ConstList; import edu.mit.csail.sdg.alloy4.ConstList.TempList; import edu.mit.csail.sdg.alloy4.Err; import edu.mit.csail.sdg.alloy4.ErrorFatal; import edu.mit.csail.sdg.alloy4.ErrorSyntax; import edu.mit.csail.sdg.alloy4.Pair; import edu.mit.csail.sdg.alloy4compiler.ast.Attr; import edu.mit.csail.sdg.alloy4compiler.ast.Command; import edu.mit.csail.sdg.alloy4compiler.ast.CommandScope; import edu.mit.csail.sdg.alloy4compiler.ast.Decl; import edu.mit.csail.sdg.alloy4compiler.ast.Expr; import edu.mit.csail.sdg.alloy4compiler.ast.ExprBinary; import edu.mit.csail.sdg.alloy4compiler.ast.ExprCall; import edu.mit.csail.sdg.alloy4compiler.ast.ExprConstant; import edu.mit.csail.sdg.alloy4compiler.ast.ExprHasName; import edu.mit.csail.sdg.alloy4compiler.ast.ExprITE; import edu.mit.csail.sdg.alloy4compiler.ast.ExprLet; import edu.mit.csail.sdg.alloy4compiler.ast.ExprList; import edu.mit.csail.sdg.alloy4compiler.ast.ExprQt; import edu.mit.csail.sdg.alloy4compiler.ast.ExprUnary; import edu.mit.csail.sdg.alloy4compiler.ast.ExprVar; import edu.mit.csail.sdg.alloy4compiler.ast.Sig; import edu.mit.csail.sdg.alloy4compiler.ast.Sig.Field; import edu.mit.csail.sdg.alloy4compiler.ast.Sig.PrimSig; import edu.mit.csail.sdg.alloy4compiler.ast.Type; import edu.mit.csail.sdg.alloy4compiler.ast.VisitQuery; import edu.mit.csail.sdg.alloy4compiler.ast.VisitReturn; import edu.mit.csail.sdg.alloy4compiler.parser.CompModule; public class IntRefPreprocessor { public final Sig.PrimSig intref; public final ConstList<PreparedCommand> commands; public final ConstList<Sig> sigs; public static String atomize(Sig sig, int id) { String label = sig.label; if (label.startsWith("this/")) { label = label.substring(5); } return label + "$" + id; } private IntRefPreprocessor(Computer computer) throws Err { intref = computer.intref; sigs = computer.sigs.makeConst(); TempList<PreparedCommand> tmpCommands = new TempList<PreparedCommand>(); for (int i = 0; i < computer.commands.size(); ++i) { final IntexprSigBuilder isbuilder = new IntexprSigBuilder(computer.commands.get(i), intref); final FactRewriter.Result rewriteres = FactRewriter.rewrite(computer.commands.get(i).formula, isbuilder); final Command command = isbuilder.getModifiedCommand().change(rewriteres.newformula); final TempList<PreparedCommand.IntrefSigRecord> l = new TempList<PreparedCommand.IntrefSigRecord>(); l.addAll(computer.intrefRecords.get(i)); l.addAll(isbuilder.getIntexprRecords()); final TempList<Sig> esigs = new TempList<Sig>(); esigs.addAll(sigs); esigs.addAll(isbuilder.getIntExprSigs()); tmpCommands.add(new PreparedCommand(command, rewriteres.hysatexprs, l.makeConst(), esigs.makeConst(), intref)); } commands = tmpCommands.makeConst(); } private IntRefPreprocessor(CompModule module) { intref = null; sigs = module.getAllReachableSigs(); TempList<PreparedCommand> tmpCommands = new TempList<PreparedCommand>(); for (Command c : module.getAllCommands()) { tmpCommands.add(new PreparedCommand(c, sigs, null, null)); } commands = tmpCommands.makeConst(); } private static class Computer { private ConstList<Command> oldcommands; private Map<Command, List<CommandScope>> newscopes; private Map<Command, TempList<PreparedCommand.IntrefSigRecord>> tmpIntrefRecords; private Map<PrimSig, PrimSig> old2newsigs; private Map<Field, Field> old2newfields; public TempList<Sig> sigs; public Sig.PrimSig intref; public TempList<Command> commands; public TempList<ConstList<PreparedCommand.IntrefSigRecord>> intrefRecords; public Computer(CompModule module, Sig.PrimSig intref) throws Err { this.intref = intref; sigs = new TempList<Sig>(); old2newsigs = new HashMap<PrimSig, PrimSig>(); old2newfields = new HashMap<Field, Field>(); oldcommands = module.getAllCommands(); commands = new TempList<Command>(); newscopes = new HashMap<Command, List<CommandScope>>(); tmpIntrefRecords = new HashMap<Command, TempList<PreparedCommand.IntrefSigRecord>>(); intrefRecords = new TempList<ConstList<PreparedCommand.IntrefSigRecord>>(); for (Command c: oldcommands) { newscopes.put(c, new Vector<CommandScope>()); tmpIntrefRecords.put(c, new TempList<PreparedCommand.IntrefSigRecord>()); } // Initialize old2newsigs map. All Sigs except builtin and intref are cloned. // old2newsigs is a map which gets used for the later step convertSig // TODO: What about Subsetsigs? for (Sig s: module.getAllReachableSigs()) { if (s == intref) { old2newsigs.put(intref, intref); for (Field f : s.getFields()) old2newfields.put(f, f); } else if (isSmtInt(s)) { Attr[] attrs = new Attr[1]; PrimSig newsig = new PrimSig(s.label, s.attributes.toArray(attrs)); old2newsigs.put((PrimSig) s, newsig); } } // Converts all Sigs. The mapping between old and new sigs is already in place. // Now the new Sigs have to be populated with data/fields. for (Sig s: module.getAllReachableSigs()) { if (s.builtin || s == intref) { sigs.add(s); } else { sigs.add(convertSig(s)); } } // Rewrite step of the facts. Each Sig and its fields have a new instance. Now the facts have // to get updated and added to the new Sigs. for (Sig s: module.getAllReachableSigs()) { if (!s.builtin) convertSigFacts(s); } // Rewrite of the formula generated by each command. for (Command c: oldcommands) { TempList<CommandScope> scopes = new TempList<CommandScope>(); for (CommandScope scope : c.scope) { scopes.add(new CommandScope(old2newsigs.get(scope.sig), scope.isExact, scope.endingScope)); } scopes.addAll(newscopes.get(c)); ExprRewriter rewriter = new ExprRewriter(intref, old2newsigs, old2newfields); commands.add(c.change(scopes.makeConst()).change(rewriter.visitThis(c.formula))); intrefRecords.add(tmpIntrefRecords.get(c).makeConst()); } } private boolean isSmtInt(Sig s) { return !s.builtin && s instanceof PrimSig; //return !s.builtin && s.label.equals("Sint"); } private Sig convertSig(Sig sig) throws Err { final Sig newSig = old2newsigs.get(sig); for (Sig.Field field: sig.getFields()) { final ExprRewriter rewriter = new ExprRewriter(intref, sig, field, old2newsigs, old2newfields); final Expr newExpr = rewriter.visitThis(field.decl().expr); final Field[] newField = newSig.addTrickyField( field.pos, field.isPrivate, null, null, field.isMeta, new String[] { field.label }, newExpr); old2newfields.put(field, newField[0]); for (PrimSig primSig: rewriter.newintrefs) { for (Command c: oldcommands) { final int scope = rewriter.getFactor(c); newscopes.get(c).add(new CommandScope(primSig, true, scope)); tmpIntrefRecords.get(c).add(new PreparedCommand.IntrefSigRecord(primSig, null, null, scope)); } sigs.add(primSig); } } return newSig; } private void convertSigFacts(Sig sig) throws Err { final Sig newSig = old2newsigs.get(sig); for (Expr fact : sig.getFacts()) { newSig.addFact(new ExprRewriter(intref, old2newsigs, old2newfields).visitThis(fact)); } } private static class ExprRewriter extends VisitReturn<Expr> { private final PrimSig intref; private final Sig currentSig; private final Field currentField; private final Map<PrimSig, PrimSig> old2newsigs; private final Map<Field, Field> old2newfields; private Field lastfield = null; public List<PrimSig> newintrefs; public List<Sig> factors; ExprRewriter(PrimSig intref, Map<PrimSig, PrimSig> old2newsigs, Map<Field, Field> old2newfields) { this.intref = intref; this.currentSig = null; this.currentField = null; this.old2newsigs = old2newsigs; this.old2newfields = old2newfields; this.newintrefs = null; this.factors = new Vector<Sig>(); } ExprRewriter(PrimSig intref, Sig currentSig, Field currentField, Map<PrimSig, PrimSig> old2newsigs, Map<Field, Field> old2newfields) { this.intref = intref; this.currentSig = currentSig; this.currentField = currentField; this.old2newsigs = old2newsigs; this.old2newfields = old2newfields; this.newintrefs = new Vector<PrimSig>(); this.factors = new Vector<Sig>(); addFactor(currentSig); } private Sig makeSig() throws Err { if (lastfield != currentField) { lastfield = currentField; } else { throw new ErrorFatal(currentField.pos, "unsupported decl"); } String label = currentSig.label + "_" + currentField.label + "_IntRef"; PrimSig sig = new Sig.PrimSig(label, intref); newintrefs.add(sig); return sig; } private void addFactor(Sig x) { this.factors.add(x); } public int getFactor(Command c) { int result = 1; for (Sig s : factors) { result *= Helpers.getScope(c, s); } return result; } @Override public Expr visit(ExprBinary x) throws Err { return x.op.make(x.pos, x.closingBracket, visitThis(x.left), visitThis(x.right)); } @Override public Expr visit(ExprList x) throws Err { List<Expr> args = new Vector<Expr>(); for (Expr a : x.args) args.add(visitThis(a)); return ExprList.make(x.pos, x.closingBracket, x.op, args); } @Override public Expr visit(ExprCall x) throws Err { return x; } @Override public Expr visit(ExprConstant x) throws Err { return x; } @Override public Expr visit(ExprITE x) throws Err { return ExprITE.make(x.pos, visitThis(x.cond), visitThis(x.left), visitThis(x.right)); } @Override public Expr visit(ExprLet x) throws Err { return ExprLet.make(x.pos, x.var, visitThis(x.expr), visitThis(x.sub)); } @Override public Expr visit(ExprQt x) throws Err { List<Decl> decls = new Vector<Decl>(); for (Decl d : x.decls) { decls.add(new Decl(d.isPrivate, d.disjoint, d.disjoint2, d.names, visitThis(d.expr))); } return x.op.make(x.pos, x.closingBracket, decls, visitThis(x.sub)); } @Override public Expr visit(ExprUnary x) throws Err { return x.op.make(x.pos, visitThis(x.sub)); } @Override public Expr visit(ExprVar x) throws Err { return x; } @Override public Expr visit(Sig x) throws Err { Sig s; if (x == Sig.SIGINT) { s = makeSig(); } else { s = old2newsigs.get(x); if (s == null) s = x; addFactor(x); } return s; } @Override public Expr visit(Field x) throws Err { Field f = old2newfields.get(x); if (f == null) throw new AssertionError(); return f; } } } private static class IntexprSigBuilder { private LinkedHashMap<ExprVar, Expr> freeVars; private Context ctx; private static class Context { public int id = 0; public PrimSig intref; public Sig.Field aqclass; public List<PreparedCommand.IntrefSigRecord> records = new Vector<PreparedCommand.IntrefSigRecord>(); public Command command; } public IntexprSigBuilder(Command command, PrimSig intref) { ctx = new Context(); ctx.command = command; ctx.intref = intref; ctx.aqclass = Helpers.getFieldByName(intref.getFields(), "aqclass"); freeVars = new LinkedHashMap<ExprVar, Expr>(); } private IntexprSigBuilder(IntexprSigBuilder other) { ctx = other.ctx; freeVars = new LinkedHashMap<ExprVar, Expr>(other.freeVars); } public Pair<PreparedCommand.IntrefSigRecord, Expr> make(Expr intrefExpr) throws Err { final PrimSig intexprsig = new PrimSig("IntExpr" + ctx.id++, ctx.intref); final Set<ExprVar> usedFreeVars = FreeVarFinder.find(intrefExpr); final Expr right = ExprBinary.Op.JOIN.make(null, null, intrefExpr, ctx.aqclass); Expr left; Sig.Field mapfield = null; List<Sig> dependencies = new Vector<Sig>(); int instances = 1; if (!usedFreeVars.isEmpty()) { Type type = null; for (ExprVar var : usedFreeVars) { final Expr e = freeVars.get(var); if (type == null) { type = e.type(); } else { type = e.type().product(type); } } mapfield = intexprsig.addField("map", type.toExpr()); for (List<PrimSig> ss : type.fold()) { int ssinst = 1; for (PrimSig sig : ss) { ssinst *= Helpers.getScope(ctx.command, sig); dependencies.add(0, sig); } instances *= ssinst; } Expr mapjoin = mapfield; for (ExprVar var : usedFreeVars) { mapjoin = ExprBinary.Op.JOIN.make(null, null, mapjoin, var); } left = ExprBinary.Op.JOIN.make(null, null, mapjoin, ctx.aqclass); } else { left = ExprBinary.Op.JOIN.make(null, null, intexprsig, ctx.aqclass); } PreparedCommand.IntrefSigRecord result = new PreparedCommand.IntrefSigRecord(intexprsig, mapfield, ConstList.make(dependencies), instances); ctx.command = ctx.command.change(intexprsig, true, instances); ctx.records.add(result); return new Pair<PreparedCommand.IntrefSigRecord, Expr>(result, ExprBinary.Op.EQUALS.make(null, null, left, right)); } public IntexprSigBuilder addFreeVariables(ConstList<Decl> decls) { IntexprSigBuilder result = new IntexprSigBuilder(this); for (Decl d : decls) { for (ExprHasName ehn : d.names) { result.freeVars.put((ExprVar) ehn, d.expr); } } return result; } public Command getModifiedCommand() { return ctx.command; } public List<PreparedCommand.IntrefSigRecord> getIntexprRecords() { return ctx.records; } public List<PrimSig> getIntExprSigs() { List<PrimSig> result = new Vector<PrimSig>(); for (PreparedCommand.IntrefSigRecord record : ctx.records) { result.add(record.sig); } return result; } } private static class FreeVarFinder extends VisitQuery<Object> { private Set<ExprVar> freeVars = new LinkedHashSet<ExprVar>(); @Override public Object visit(ExprVar x) throws Err { freeVars.add(x); return super.visit(x); } private FreeVarFinder() { } public static Set<ExprVar> find(Expr x) throws Err { FreeVarFinder finder = new FreeVarFinder(); finder.visitThis(x); return finder.freeVars; } } private static class IntExprHandler extends VisitReturn<TempList<String>> { private List<Expr> facts; private IntexprSigBuilder builder; private boolean cast2intSeen; public IntExprHandler(IntexprSigBuilder isb) { this.facts = new Vector<Expr>(); this.builder = isb; this.cast2intSeen = false; } public Expr getFacts() { return ExprList.make(null, null, ExprList.Op.AND, facts); } private void throwUnsupportedOperator(Expr x) throws Err { throw new ErrorSyntax(x.pos, "HySAT does not support this operator."); } @Override public TempList<String> visit(ExprUnary x) throws Err { TempList<String> result; if (x.op == ExprUnary.Op.CAST2INT) { cast2intSeen = true; result = visitThis(x.sub); cast2intSeen = false; } else { final TempList<String> sub = visitThis(x.sub); switch (x.op) { case NOOP: result = sub; break; default: throw new AssertionError(); } } return result; } @Override public TempList<String> visit(ExprBinary x) throws Err { if (cast2intSeen && x.op == ExprBinary.Op.JOIN) { Pair<PreparedCommand.IntrefSigRecord, Expr> result = builder.make(x); facts.add(result.b); return new TempList<String>(result.a.getAtoms()); } else { final TempList<String> left = visitThis(x.left); final TempList<String> right = visitThis(x.right); String op = null; switch (x.op) { case GT: op = ">"; break; case LT: op = "<"; break; case GTE: op = ">="; break; case LTE: op = "<="; break; case EQUALS: op = "="; break; case MINUS: op = "-"; break; case MUL: op = "*"; break; case NOT_EQUALS: op = "!="; break; case NOT_GT: op = "<="; break; case NOT_LT: op = ">="; break; case NOT_GTE: op = "<"; break; case NOT_LTE: op = ">"; break; case PLUS: op = "+"; break; default: throwUnsupportedOperator(x); } TempList<String> result = new TempList<String>(); for (String l : left.makeConst()) { for (String r : right.makeConst()) { result.add("(" + l + " " + op + " " + r + ")"); } } return result; } } @Override public TempList<String> visit(ExprConstant x) throws Err { if (x.op == ExprConstant.Op.NUMBER) { return new TempList<String>(String.valueOf(x.num)); } else { throw new ErrorSyntax(x.pos, "Constant not convertible to HySAT"); } } @Override public TempList<String> visit(ExprList x) throws Err { throw new AssertionError(); } @Override public TempList<String> visit(ExprCall x) throws Err { throw new AssertionError(); } @Override public TempList<String> visit(ExprITE x) throws Err { throw new AssertionError(); } @Override public TempList<String> visit(ExprLet x) throws Err { throw new AssertionError(); } @Override public TempList<String> visit(ExprQt x) throws Err { throw new AssertionError(); } @Override public TempList<String> visit(ExprVar x) throws Err { throw new AssertionError(); } @Override public TempList<String> visit(Sig x) throws Err { throw new AssertionError(); } @Override public TempList<String> visit(Field x) throws Err { throw new AssertionError(); } } private static class FactRewriter extends VisitReturn<Expr> { public static class Result { public final Expr newformula; public final ConstList<String> hysatexprs; public Result(Expr newformula, ConstList<String> hysatexprs) { this.newformula = newformula; this.hysatexprs = hysatexprs; } } private IntexprSigBuilder intexprBuilder; private TempList<String> hysatexprs; private FactRewriter(IntexprSigBuilder builder) { hysatexprs = new TempList<String>(); intexprBuilder = builder; } public static Result rewrite(Expr expr, IntexprSigBuilder builder) throws Err { FactRewriter rewriter = new FactRewriter(builder); Expr rewritten = rewriter.visitThis(expr); return new Result(rewritten, rewriter.hysatexprs.makeConst()); } @Override public Expr visit(ExprBinary x) throws Err { Expr result = null; // TODO: change this to check for our custom SMTInt datatype if (x.left.type().is_int() && x.right.type().is_int()) { final IntExprHandler ieh = new IntExprHandler(intexprBuilder); final TempList<String> hexpr = ieh.visitThis(x); hysatexprs.addAll(hexpr.makeConst()); result = ieh.getFacts(); } else { final Expr left = visitThis(x.left); final Expr right = visitThis(x.right); result = x.op.make(x.pos, x.closingBracket, left, right); } return result; } @Override public Expr visit(ExprList x) throws Err { TempList<Expr> args = new TempList<Expr>(); for (Expr e: x.args) { args.add(visitThis(e)); } return ExprList.make(x.pos, x.closingBracket, x.op, args.makeConst()); } @Override public Expr visit(ExprCall x) throws Err { TempList<Expr> args = new TempList<Expr>(); for (Expr e: x.args) { args.add(visitThis(e)); } return ExprCall.make(x.pos, x.closingBracket, x.fun, args.makeConst(), x.extraWeight); } @Override public Expr visit(ExprConstant x) throws Err { return x; } @Override public Expr visit(ExprITE x) throws Err { Expr cond = visitThis(x.cond); Expr left = visitThis(x.left); Expr right = visitThis(x.right); return ExprITE.make(x.pos, cond, left, right); } @Override public Expr visit(ExprLet x) throws Err { Expr sub = visitThis(x.sub); return ExprLet.make(x.pos, x.var, x.expr, sub); } @Override public Expr visit(ExprQt x) throws Err { final IntexprSigBuilder tmpold = intexprBuilder; intexprBuilder = tmpold.addFreeVariables(x.decls); Expr sub = visitThis(x.sub); intexprBuilder = tmpold; return x.op.make(x.pos, x.closingBracket, x.decls, sub); } @Override public Expr visit(ExprUnary x) throws Err { final Expr sub = visitThis(x.sub); if (x.op == ExprUnary.Op.CAST2INT) { throw new AssertionError(); } else { return x.op.make(x.pos, sub); } } @Override public Expr visit(ExprVar x) throws Err { return x; } @Override public Expr visit(Sig x) throws Err { return x; } @Override public Expr visit(Field x) throws Err { return x; } } public static IntRefPreprocessor processModule(CompModule module) throws Err { final Sig.PrimSig intref = (Sig.PrimSig) Helpers.getSigByName(module.getAllReachableSigs(), "intref/IntRef"); if (intref != null) { final Computer computer = new Computer(module, intref); return new IntRefPreprocessor(computer); } else { return new IntRefPreprocessor(module); } } }