package de.psi.alloy4smt.ast; import de.psi.alloy4smt.smt.SExpr; import de.psi.alloy4smt.smt.SMTSolver; import edu.mit.csail.sdg.alloy4.*; import edu.mit.csail.sdg.alloy4compiler.ast.*; import edu.mit.csail.sdg.alloy4compiler.translator.*; import kodkod.ast.Formula; import kodkod.ast.Relation; import kodkod.engine.fol2sat.Translation; import kodkod.engine.fol2sat.Translator; import kodkod.engine.fol2sat.TrivialFormulaException; import kodkod.engine.satlab.SATFactory; import kodkod.engine.satlab.SATSolver; import kodkod.instance.Tuple; import kodkod.instance.TupleSet; import java.io.File; import java.io.FileWriter; import java.io.IOException; import java.util.*; public class SmtPreprocessor { public static Result build(Command c, ConstList<Sig> allReachableSigs) throws Err { ConversionInput input = new ConversionInput(c, allReachableSigs); FieldRewritePhase.Result frpr = FieldRewritePhase.run(input); FormulaRewritePhase.Result formr = FormulaRewritePhase.run(frpr); ComputeScopePhase.Result cspr = ComputeScopePhase.run(formr); SmtTranslationPhase.run(cspr); return new Result(formr, cspr); } public static class Result { public final Sig.PrimSig sintref; public final Sig.Field equalsf; public final ConstList<Sig> sigs; public final Command command; public final A4Solution solution; public final ConstList<SExpr<String>> smtExprs; public Result(FormulaRewritePhase.Result formr, ComputeScopePhase.Result csp) { this.sintref = formr.frp.sigSintref; this.equalsf = formr.frp.equalsf; this.sigs = formr.allsigs; this.command = csp.command; this.solution = csp.solution; this.smtExprs = csp.smtExprs; } } private static class ConversionInput { public final Sig.PrimSig sigSint; public final Sig.PrimSig sigSintref; public final Sig.Field aqclass; public final Sig.Field equalsf; public final ConstList<Sig> allReachableSigs; public final int defaultScope; public final Command command; public final String nameSuffix; ConversionInput(Command command, ConstList<Sig> allReachableSigs) { sigSint = (Sig.PrimSig) Helpers.getSigByName(allReachableSigs, "smtint/Sint"); sigSintref = (Sig.PrimSig) Helpers.getSigByName(allReachableSigs, "smtint/SintRef"); aqclass = Helpers.getFieldByName(sigSintref.getFields(), "aqclass"); equalsf = Helpers.getFieldByName(sigSintref.getFields(), "equals"); if (sigSint == null || sigSintref == null || aqclass == null || equalsf == null) throw new AssertionError(); this.allReachableSigs = allReachableSigs; defaultScope = command.overall < 0 ? 1 : command.overall; this.command = command; this.nameSuffix = "_c"; } } private static class FieldRewritePhase { public final ConversionInput in; private final Map<Sig, Sig> newsigmap = new HashMap<Sig, Sig>(); private final Map<Sig.Field, Sig.Field> newfieldmap = new HashMap<Sig.Field, Sig.Field>(); private final ConstList.TempList<Sig> allsigs = new ConstList.TempList<Sig>(); private final ConstList.TempList<FieldRewriter.Result> newrefs = new ConstList.TempList<FieldRewriter.Result>(); public static class Result { public final Sig.PrimSig sigSint; public final Sig.PrimSig sigSintref; public final Sig.Field equalsf; public final Sig.Field aqclass; public final ConstList<Sig> allsigs; public final ConstList<FieldRewriter.Result> sigrefs; public final ConstMap<Sig, Sig> sigmap; public final ConstMap<Sig.Field, Sig.Field> fieldmap; public final ConversionInput input; private Result(ConstList<Sig> allsigs, ConstList<FieldRewriter.Result> sigrefs, ConstMap<Sig, Sig> sigmap, ConstMap<Sig.Field, Sig.Field> fieldmap, ConversionInput input) { this.sigSint = input.sigSint; this.sigSintref = input.sigSintref; this.equalsf = input.equalsf; this.aqclass = input.aqclass; this.allsigs = allsigs; this.sigrefs = sigrefs; this.sigmap = sigmap; this.fieldmap = fieldmap; this.input = input; } public Sig mapSig(Sig old) throws Err { if (old.builtin) return old; Sig res = sigmap.get(old); if (res == null) throw new AssertionError(); return res; } public Sig.Field mapField(Sig.Field old) { Sig.Field field = fieldmap.get(old); if (field == null) throw new AssertionError(); return field; } } private void addSigMapping(Sig oldsig, Sig newsig) { if (oldsig == in.sigSint) throw new AssertionError(); newsigmap.put(oldsig, newsig); allsigs.add(newsig); } public Sig mapSig(Sig old) throws Err { Sig result; if (!newsigmap.containsKey(old)) { if (old.builtin) { result = old; addSigMapping(old, old); } else if (old instanceof Sig.PrimSig) { Attr[] attrs = new Attr[1]; result = new Sig.PrimSig(old.label + in.nameSuffix, old.attributes.toArray(attrs)); addSigMapping(old, result); for (Sig.Field field : old.getFields()) { final FieldRewriter.Result rewriteResult = FieldRewriter.rewrite(this, old, field); final Sig.Field[] newField = result.addTrickyField(field.pos, field.isPrivate, null, null, field.isMeta, new String[] {field.label + in.nameSuffix}, rewriteResult.field); newfieldmap.put(field, newField[0]); if (rewriteResult.ref != null) { newrefs.add(rewriteResult); allsigs.add(rewriteResult.ref); } } } else if (old instanceof Sig.SubsetSig) { throw new AssertionError("not handled yet"); } else { throw new AssertionError(); } } else { result = newsigmap.get(old); } return result; } public Sig.Field mapField(Sig.Field old) { Sig.Field field = newfieldmap.get(old); if (field == null) throw new AssertionError(); return field; } private FieldRewritePhase(ConversionInput in) throws Err { this.in = in; addSigMapping(in.sigSintref, in.sigSintref); for (Sig.Field f : in.sigSintref.getFields()) newfieldmap.put(f, f); for (Sig s : in.allReachableSigs) if (s != in.sigSint) mapSig(s); } public static Result run(ConversionInput in) throws Err { FieldRewritePhase p = new FieldRewritePhase(in); return new Result(p.allsigs.makeConst(), p.newrefs.makeConst(), ConstMap.make(p.newsigmap), ConstMap.make(p.newfieldmap), in); } } private static class FieldRewriter extends VisitReturn<Expr> { private final FieldRewritePhase ctx; private final Sig sig; private final Sig.Field field; private final ConstList.TempList<Type> visitedsigs = new ConstList.TempList<Type>(); private Sig.PrimSig ref = null; public static class Result { public final Sig.PrimSig ref; public final Expr field; public final ConstList<Type> refdeps; public Result(Sig.PrimSig ref, Expr field, ConstList<Type> refdeps) { this.ref = ref; this.field = field; this.refdeps = refdeps; } } public static Result rewrite(FieldRewritePhase ctx, Sig sig, Sig.Field field) throws Err { FieldRewriter rewriter = new FieldRewriter(ctx, sig, field); Expr expr = rewriter.visitThis(field.decl().expr); return new Result(rewriter.ref, expr, rewriter.visitedsigs.makeConst()); } private FieldRewriter(FieldRewritePhase ctx, Sig sig, Sig.Field field) throws Err { this.ctx = ctx; this.sig = sig; this.field = field; visitedsigs.add(ctx.mapSig(sig).type()); } private Expr unexpected() { throw new AssertionError("Unexpected field expression!"); } @Override public Expr visit(ExprList x) throws Err { return unexpected(); } @Override public Expr visit(ExprCall x) throws Err { return unexpected(); } @Override public Expr visit(ExprConstant x) throws Err { return unexpected(); } @Override public Expr visit(ExprITE x) throws Err { return unexpected(); } @Override public Expr visit(ExprLet x) throws Err { return unexpected(); } @Override public Expr visit(ExprQt x) throws Err { return unexpected(); } @Override public Expr visit(ExprVar x) throws Err { return unexpected(); } @Override public Expr visit(ExprUnary x) throws Err { return x.op.make(x.pos, visitThis(x.sub)); } @Override public Expr visit(ExprBinary x) throws Err { // FIXME: Handle cases like A+B -> Sint (compared to A->B->Sint) if (!x.op.isArrow) throw new AssertionError(); return x.op.make(x.pos, x.closingBracket, visitThis(x.left), visitThis(x.right)); } @Override public Expr visit(Sig x) throws Err { Sig s; if (x == ctx.in.sigSint) { if (ref != null) throw new AssertionError(); String label = sig.label + "_" + field.label + "_SintRef"; ref = new Sig.PrimSig(label, ctx.in.sigSintref); s = ref; } else { s = ctx.mapSig(x); visitedsigs.add(s.type()); } return s; } @Override public Expr visit(Sig.Field x) throws Err { return ctx.mapField(x); } } private static class FormulaRewritePhase { public final FieldRewritePhase.Result in; private final ConstList.TempList<SintExprDef> sintExprDefs = new ConstList.TempList<SintExprDef>(); private final ConstList.TempList<SExpr<Sig>> sexprs = new ConstList.TempList<SExpr<Sig>>(); private final List<Sig> allsigs; private final Map<ExprVar, ExprVar> freevarmap = new HashMap<ExprVar, ExprVar>(); private int exprcnt = 0; public static class SintExprDef { public final Sig.PrimSig sig; public final Sig.Field mapField; public final Iterable<Type> dependencies; public SintExprDef(Sig.PrimSig sig, Sig.Field mapField, Iterable<Type> dependencies) { this.sig = sig; this.mapField = mapField; this.dependencies = dependencies; } } public static class Result { public final ConstList<SintExprDef> sintExprDefs; public final ConstList<Sig> allsigs; public final ConstList<SExpr<Sig>> sexprs; public final Expr newformula; public final FieldRewritePhase.Result frp; public Result(ConstList<SintExprDef> sintExprDefs, ConstList<Sig> allsigs, ConstList<SExpr<Sig>> sexprs, Expr newformula, FieldRewritePhase.Result frp) { this.sintExprDefs = sintExprDefs; this.allsigs = allsigs; this.sexprs = sexprs; this.newformula = newformula; this.frp = frp; } } private FormulaRewritePhase(FieldRewritePhase.Result in) { this.in = in; this.allsigs = new Vector<Sig>(in.allsigs); } public static Result run(FieldRewritePhase.Result in) throws Err { FormulaRewritePhase p = new FormulaRewritePhase(in); Expr expr = FormulaRewriter.rewrite(p, in.input.command.formula); if (in.sigSintref.children().isEmpty()) { p.allsigs.remove(in.sigSintref); } return new Result(p.sintExprDefs.makeConst(), ConstList.make(p.allsigs), p.sexprs.makeConst(), expr, in); } public Sig mapSig(Sig old) throws Err { return in.mapSig(old); } public Sig.Field mapField(Sig.Field old) { return in.mapField(old); } public ExprVar mapVar(ExprVar var) { ExprVar result = freevarmap.get(var); if (result == null) throw new AssertionError(); return result; } public void addVarMapping(ExprVar old, ExprVar newvar) { freevarmap.put(old, newvar); } public void addRefSig(Sig.PrimSig ref, Sig.Field mapField, Iterable<Type> dependencies) throws Err { allsigs.add(ref); sintExprDefs.add(new SintExprDef(ref, mapField, dependencies)); } public void addGlobalFact(SExpr<Sig> sexpr) { sexprs.add(sexpr); } public Expr makeRefSig(SExpr<Sig> sexpr) throws Err { StringBuilder sb = new StringBuilder(); sb.append("SintExpr"); sb.append(exprcnt++); Sig.PrimSig ref = new Sig.PrimSig(sb.toString(), in.sigSintref); addRefSig(ref, null, new Vector<Type>()); SExpr<Sig> symb = SExpr.<Sig>leaf(ref); addGlobalFact(SExpr.eq(symb, sexpr)); return ref; } /** * Creates an alias for an arbitrary complex SintRef expression w.r.t. * free variables. * @param expr Alloy Expression which must be of type SintRef * @return A pair (smtvar, subst) where smtvar contains a reference to the * SMT variable as a SExpr. subst is the substitution for expr which * references the newly generated SintExpr signature relation. * @throws Err */ public Pair<SExpr<Sig>, Expr> makeAlias(Expr expr) throws Err { if (!isSintRefExpr(expr)) throw new AssertionError(); final Set<ExprVar> usedquantifiers = FreeVarFinder.find(expr); final List<Type> dependencies = new Vector<Type>(); Sig.PrimSig ref = new Sig.PrimSig("SintExpr" + exprcnt, in.sigSintref); Sig.Field mapField = null; exprcnt++; Expr left; if (usedquantifiers.isEmpty()) { left = ref; } else { Type type = null; for (ExprVar var : usedquantifiers) { if (!var.type().hasArity(1)) throw new AssertionError("Quantified variables with arity > 1 are not supported"); dependencies.add(var.type()); if (type == null) type = var.type(); else type = var.type().product(type); } left = mapField = ref.addField("map", type.toExpr()); for (ExprVar var : usedquantifiers) { left = left.join(var); } } addRefSig(ref, mapField, dependencies); SExpr<Sig> var = SExpr.<Sig>leaf(ref); return new Pair<SExpr<Sig>, Expr>(var, left.join(in.aqclass).equal(expr.join(in.aqclass))); } public boolean isSintRefExpr(Expr expr) { return expr.type().isSubtypeOf(in.sigSintref.type()); } public boolean isSintExpr(Expr expr) { return expr.type().equals(in.sigSint.type()); } } private static class FreeVarFinder extends VisitQuery<Object> { private Set<ExprVar> freeVars = new LinkedHashSet<ExprVar>(); @Override public Object visit(ExprVar x) throws Err { freeVars.add(x); return super.visit(x); } private FreeVarFinder() { } public static Set<ExprVar> find(Expr x) throws Err { FreeVarFinder finder = new FreeVarFinder(); finder.visitThis(x); return finder.freeVars; } } private static class ExprRewriter extends VisitReturn<Expr> { public static Pair<Expr, AndExpr> rewrite(FormulaRewritePhase ctx, Expr expr) throws Err { ExprRewriter rewriter = new ExprRewriter(ctx); return new Pair<Expr, AndExpr>(rewriter.apply(expr), rewriter.result); } private final FormulaRewritePhase ctx; private final AndExpr result = new AndExpr(); private ExprRewriter(FormulaRewritePhase ctx) { this.ctx = ctx; } private Expr apply(Expr expr) throws Err { if (ctx.isSintExpr(expr)) { Pair<Expr, AndExpr> rewrite = SintExprRewriter.rewrite(ctx, expr); result.add(rewrite.b); return rewrite.a; } else { return visitThis(expr); } } private Expr unexpected() { throw new AssertionError("unexpected node"); } @Override public Expr visit(ExprBinary x) throws Err { return x.op.make(x.pos, x.closingBracket, apply(x.left), apply(x.right)); } @Override public Expr visit(ExprUnary x) throws Err { return x.op.make(x.pos, apply(x.sub)); } @Override public Expr visit(ExprITE x) throws Err { return ExprITE.make(x.pos, apply(x.cond), apply(x.left), apply(x.right)); } @Override public Expr visit(ExprList x) throws Err { ConstList.TempList<Expr> args = new ConstList.TempList<Expr>(); for (Expr e: x.args) { args.add(apply(e)); } return ExprList.make(x.pos, x.closingBracket, x.op, args.makeConst()); } @Override public Expr visit(ExprConstant x) throws Err { return x; } @Override public Expr visit(ExprVar x) throws Err { return ctx.mapVar(x); } @Override public Expr visit(ExprLet x) throws Err { return ExprLet.make(x.pos, x.var, apply(x.expr), apply(x.sub)); } @Override public Expr visit(Sig x) throws Err { return ctx.mapSig(x); } @Override public Expr visit(Sig.Field x) throws Err { return ctx.mapField(x); } @Override public Expr visit(ExprCall x) throws Err { ConstList.TempList<Expr> args = new ConstList.TempList<Expr>(); for (Expr e : x.args) { args.add(visitThis(e)); } return ExprCall.make(x.pos, x.closingBracket, x.fun, args.makeConst(), x.extraWeight); } @Override public Expr visit(ExprQt x) throws Err { return unexpected(); } } private static class AndExpr { private final ConstList.TempList<Expr> result = new ConstList.TempList<Expr>(); public void add(Expr expr) { if (expr.equals(ExprConstant.TRUE)) return; result.add(expr); } public void add(AndExpr andExpr) { result.addAll(andExpr.result.makeConst()); } public Expr getExpr() { if (result.size() == 0) return ExprConstant.TRUE; if (result.size() == 1) return result.get(0); Expr last = result.get(result.size() - 1); return ExprList.make(last.pos, last.closingBracket, ExprList.Op.AND, result.makeConst()); } } private static class FormulaRewriter extends VisitReturn<Expr> { public static Expr rewrite(FormulaRewritePhase ctx, Expr formula) throws Err { if (!formula.type().is_bool) throw new AssertionError(); FormulaRewriter rewriter = new FormulaRewriter(ctx); // We don't use `applyFormula` here, because FormulaRewriter is also used by // SintExprRewriter to rewrite subexpressions. return rewriter.visitThis(formula); } private final FormulaRewritePhase ctx; private FormulaRewriter(FormulaRewritePhase ctx) { this.ctx = ctx; } private Expr unexpected() { throw new AssertionError("unexpected node"); } private Expr applyFormula(Expr expr) throws Err { if (!expr.type().is_bool) throw new AssertionError(); return visitThis(expr); } @Override public Expr visit(ExprBinary x) throws Err { Pair<Expr, AndExpr> left = ExprRewriter.rewrite(ctx, x.left); Pair<Expr, AndExpr> right = ExprRewriter.rewrite(ctx, x.right); Expr newx = x.op.make(x.pos, x.closingBracket, left.a, right.a); AndExpr result = new AndExpr(); result.add(left.b); result.add(right.b); result.add(newx); return result.getExpr(); } @Override public Expr visit(ExprUnary x) throws Err { if (x.sub.type().is_bool) return x.op.make(x.pos, applyFormula(x.sub)); else { Pair<Expr, AndExpr> rewritten = ExprRewriter.rewrite(ctx, x.sub); AndExpr result = new AndExpr(); result.add(rewritten.b); result.add(x.op.make(x.pos, rewritten.a)); return result.getExpr(); } } @Override public Expr visit(ExprITE x) throws Err { return ExprITE.make(x.pos, applyFormula(x.cond), applyFormula(x.left), applyFormula(x.right)); } @Override public Expr visit(ExprList x) throws Err { ConstList.TempList<Expr> args = new ConstList.TempList<Expr>(); for (Expr e: x.args) { args.add(applyFormula(e)); } return ExprList.make(x.pos, x.closingBracket, x.op, args.makeConst()); } @Override public Expr visit(ExprConstant x) throws Err { return x; } @Override public Expr visit(ExprVar x) throws Err { return ctx.mapVar(x); } @Override public Expr visit(ExprLet x) throws Err { Pair<Expr, AndExpr> rewritten = ExprRewriter.rewrite(ctx, x.expr); AndExpr result = new AndExpr(); result.add(rewritten.b); result.add(ExprLet.make(x.pos, x.var, rewritten.a, applyFormula(x.sub))); return result.getExpr(); } @Override public Expr visit(Sig x) throws Err { return unexpected(); } @Override public Expr visit(Sig.Field x) throws Err { return unexpected(); } @Override public Expr visit(ExprCall x) throws Err { if (x.fun.label.equals("smtint/gt")) { return SintExprRewriter.rewriteFun(ctx, x, ">"); } else if (x.fun.label.equals("smtint/eq")) { return SintExprRewriter.rewriteFun(ctx, x, "="); } else { ConstList.TempList<Expr> args = new ConstList.TempList<Expr>(); for (Expr e : x.args) { args.add(visitThis(e)); } return ExprCall.make(x.pos, x.closingBracket, x.fun, args.makeConst(), x.extraWeight); } } @Override public Expr visit(ExprQt x) throws Err { AndExpr result = new AndExpr(); ConstList.TempList<Decl> decls = new ConstList.TempList<Decl>(); for (Decl d : x.decls) { Pair<Expr, AndExpr> rewritten = ExprRewriter.rewrite(ctx, d.expr); Expr expr = rewritten.a; result.add(rewritten.b); ConstList.TempList<ExprHasName> names = new ConstList.TempList<ExprHasName>(); for (ExprHasName ehn : d.names) { ExprVar var = ExprVar.make(ehn.pos, ehn.label, expr.type()); ctx.addVarMapping((ExprVar) ehn, var); names.add(var); } decls.add(new Decl(d.isPrivate, d.disjoint, d.disjoint2, names.makeConst(), expr)); } result.add(x.op.make(x.pos, x.closingBracket, decls.makeConst(), rewrite(ctx, x.sub))); return result.getExpr(); } } private static class SintExprRewriter extends VisitReturn<SExpr<Sig>> { public static Pair<Expr, AndExpr> rewrite(FormulaRewritePhase ctx, Expr expr) throws Err { SintExprRewriter rewriter = new SintExprRewriter(ctx); SExpr<Sig> result = rewriter.visitThis(expr); return new Pair<Expr, AndExpr>(ctx.makeRefSig(result).join(ctx.in.aqclass), rewriter.result); } public static Expr rewriteFun(FormulaRewritePhase ctx, ExprCall x, String smtOp) throws Err { SintExprRewriter rewriter = new SintExprRewriter(ctx); ConstList.TempList<SExpr<Sig>> sexprs = new ConstList.TempList<SExpr<Sig>>(); sexprs.add(SExpr.<Sig>sym(smtOp)); for (Expr arg : x.args) { sexprs.add(rewriter.visitThis(arg)); } ctx.addGlobalFact(new SExpr.SList<Sig>(sexprs.makeConst())); return rewriter.result.getExpr(); } private final FormulaRewritePhase ctx; private final AndExpr result = new AndExpr(); private SintExprRewriter(FormulaRewritePhase ctx) { this.ctx = ctx; } private SExpr<Sig> unexpected() { throw new AssertionError("unexpected node"); } @Override public SExpr<Sig> visit(ExprCall x) throws Err { SExpr<Sig> result; if (x.fun.label.equals("smtint/gt")) { result = SExpr.<Sig>call(">", visitThis(x.args.get(0)), visitThis(x.args.get(1))); } else if (x.fun.label.equals("smtint/plus")) { result = SExpr.<Sig>call("+", visitThis(x.args.get(0)), visitThis(x.args.get(1))); } else if (x.fun.label.equals("smtint/const")) { Expr arg = x.args.get(0); int c; if (arg instanceof ExprConstant) c = ((ExprConstant) arg).num(); else if (arg instanceof ExprUnary) { ExprUnary cast = (ExprUnary) arg; if (cast.op != ExprUnary.Op.CAST2SIGINT) throw new AssertionError(); c = ((ExprConstant) cast.sub).num(); } else { throw new AssertionError(); } result = SExpr.<Sig>num(c); } else { throw new AssertionError("User defined Sint functions not yet supported"); } return result; } @Override public SExpr<Sig> visit(ExprBinary x) throws Err { // Relational expression in alloy which results in a Sint, e.g. a . (this/A <: v) Pair<Expr, AndExpr> left = ExprRewriter.rewrite(ctx, x.left); Pair<Expr, AndExpr> right = ExprRewriter.rewrite(ctx, x.right); Pair<SExpr<Sig>, Expr> alias = ctx.makeAlias(x.op.make(x.pos, x.closingBracket, left.a, right.a)); result.add(left.b); result.add(right.b); result.add(alias.b); return alias.a; } @Override public SExpr<Sig> visit(ExprList x) throws Err { return unexpected(); } @Override public SExpr<Sig> visit(ExprConstant x) throws Err { return unexpected(); } @Override public SExpr<Sig> visit(ExprITE x) throws Err { return unexpected(); } @Override public SExpr<Sig> visit(ExprLet x) throws Err { return unexpected(); } @Override public SExpr<Sig> visit(ExprQt x) throws Err { return unexpected(); } @Override public SExpr<Sig> visit(ExprUnary x) throws Err { return unexpected(); } @Override public SExpr<Sig> visit(ExprVar x) throws Err { return unexpected(); } @Override public SExpr<Sig> visit(Sig x) throws Err { return unexpected(); } @Override public SExpr<Sig> visit(Sig.Field x) throws Err { return unexpected(); } } private static class ComputeScopePhase { private final FormulaRewritePhase.Result frpr; private final ConstList.TempList<CommandScope> scopes = new ConstList.TempList<CommandScope>(); private final Map<Sig, CommandScope> scopemap = new HashMap<Sig, CommandScope>(); private final Command command; private final A4Solution solution; private final List<SExpr<String>> sexprs = new Vector<SExpr<String>>(); private final A4Options options; private final Relation equalsrel; private final ConstList.TempList<SExpr.Leaf<String>> smtvars = new ConstList.TempList<SExpr.Leaf<String>>(); private final Map<Sig.PrimSig, List<SExpr.Leaf<String>>> sig2smtvars = new HashMap<Sig.PrimSig, List<SExpr.Leaf<String>>>(); private final Map<Object, SExpr.Leaf<String>> atom2smtvar = new HashMap<Object, SExpr.Leaf<String>>(); private final List<Pair<SExpr.Leaf<String>, SExpr.Leaf<String>>> equalsSmtBounds = new Vector<Pair<SExpr.Leaf<String>, SExpr.Leaf<String>>>(); public static class Result { public final Command command; public final A4Solution solution; public final A4Options options; public final Relation equalsrel; public final ConstList<SExpr<String>> smtExprs; public final ConstList<SExpr.Leaf<String>> smtVars; public final ConstList<Pair<SExpr.Leaf<String>, SExpr.Leaf<String>>> equalsSmtBounds; public Result(Command command, A4Solution solution, A4Options options, Relation equalsrel, ConstList<SExpr<String>> smtExprs, ConstList<SExpr.Leaf<String>> smtVars, ConstList<Pair<SExpr.Leaf<String>, SExpr.Leaf<String>>> equalsSmtBounds) { this.command = command; this.solution = solution; this.options = options; this.equalsrel = equalsrel; this.smtExprs = smtExprs; this.smtVars = smtVars; this.equalsSmtBounds = equalsSmtBounds; } } private void addScope(CommandScope scope) { scopemap.put(scope.sig, scope); scopes.add(scope); } private int computeScope(Iterable<Type> dependencies) throws Err { int result = 1; for (Type type : dependencies) { if (!type.hasArity(1)) throw new AssertionError(); int unionscope = 0; for (List<Sig.PrimSig> l : type.fold()) { if (l.size() != 1) throw new AssertionError(); Sig.PrimSig depsig = l.get(0); CommandScope scope = scopemap.get(depsig); if (scope != null) { unionscope += scope.endingScope; } else if (depsig.isOne != null || depsig.isLone != null) { unionscope++; } else { unionscope += frpr.frp.input.defaultScope; } } result *= unionscope; } return result; } private static A4Options makeA4Options() { final A4Options opt = new A4Options(); opt.recordKodkod = true; opt.tempDirectory = "/tmp"; opt.solverDirectory = "/tmp"; opt.solver = A4Options.SatSolver.SAT4J; opt.skolemDepth = 4; return opt; } private ComputeScopePhase(FormulaRewritePhase.Result in) throws Err { this.frpr = in; final List<Sig.PrimSig> sintrefs = new Vector<Sig.PrimSig>(); // Handle scopes of original signatures for (CommandScope scope : in.frp.input.command.scope) { addScope(new CommandScope(in.frp.mapSig(scope.sig), scope.isExact, scope.endingScope)); } // Handle scopes of SintRef fields for (FieldRewriter.Result rr : in.frp.sigrefs) { sintrefs.add(rr.ref); addScope(new CommandScope(rr.ref, true, computeScope(rr.refdeps))); } // Handle scopes of SintExprs for (FormulaRewritePhase.SintExprDef sed : in.sintExprDefs) { sintrefs.add(sed.sig); addScope(new CommandScope(sed.sig, true, computeScope(sed.dependencies))); } // Build A4Solution command = in.frp.input.command.change(scopes.makeConst()).change(in.newformula); options = makeA4Options(); Pair<A4Solution, ScopeComputer> solsc = ScopeComputer.compute(A4Reporter.NOP, options, in.allsigs, command); solution = solsc.a; BoundsComputer.compute(A4Reporter.NOP, solution, solsc.b, in.allsigs); // Populate SintRef atom -> SMT variable mapping for (Sig.PrimSig sig : sintrefs) { List<SExpr.Leaf<String>> vars = new Vector<SExpr.Leaf<String>>(); for (Object atom : getAtoms(sig)) { SExpr.Leaf<String> var = new SExpr.Leaf<String>(atom.toString().replace("$", "_")); vars.add(var); atom2smtvar.put(atom, var); smtvars.add(var); } sig2smtvars.put(sig, vars); } // set bounds for equals field equalsrel = (Relation) solution.a2k(in.frp.equalsf); if (equalsrel != null) { final List<Object> sintrefAtoms = new Vector<Object>(); for (Sig.PrimSig s : sintrefs) { sintrefAtoms.addAll(getAtoms(s)); } final TupleSet equalsBound = new TupleSet(solution.getBounds().universe(), 2); for (int i = 0; i < sintrefAtoms.size(); ++i) { for (int j = i + 1; j < sintrefAtoms.size(); ++j) { Object atomA = sintrefAtoms.get(i); Object atomB = sintrefAtoms.get(j); SExpr.Leaf<String> varA = atom2smtvar.get(atomA); SExpr.Leaf<String> varB = atom2smtvar.get(atomB); equalsBound.add(solution.getFactory().tuple(atomA, atomB)); equalsSmtBounds.add(new Pair<SExpr.Leaf<String>, SExpr.Leaf<String>>(varA, varB)); } } solution.shrink(equalsrel, new TupleSet(solution.getBounds().universe(), 2), equalsBound); } // set bounds for SintExpr maps for (FormulaRewritePhase.SintExprDef sed : in.sintExprDefs) { if (sed.mapField == null) continue; final List<List<Object>> depAtoms = getDependentAtoms(sed.dependencies); final int depSize = depAtoms.size(); final List<Object[]> sourceTuples = new Vector<Object[]>(); buildMapTupleSet(sourceTuples, depAtoms, new Object[depSize+1], 0); final List<Object> sintExprAtoms = getAtoms(sed.sig); if (sintExprAtoms.size() != sourceTuples.size()) throw new AssertionError(); final Iterator<Object> atomIt = sintExprAtoms.iterator(); final TupleSet mapTuple = new TupleSet(solution.getBounds().universe(), depSize+1); for (Object[] tpl : sourceTuples) { tpl[0] = atomIt.next(); mapTuple.add(solution.getFactory().tuple(tpl)); } final Relation rel = (Relation) solution.a2k(sed.mapField); solution.shrink(rel, mapTuple, mapTuple); } // convert sexpr-sig tree to a sexpr-string tree, which consists of every combination // of atoms of the leaf signature nodes. SExprConverter sec = new SExprConverter(this); for (SExpr<Sig> sexpr : in.sexprs) { sexprs.addAll(sec.visitThis(sexpr)); } } private static class SExprConverter extends SExpr.Visitor<Sig, List<SExpr<String>>> { public SExprConverter(ComputeScopePhase csp) { this.csp = csp; } private final ComputeScopePhase csp; @Override public List<SExpr<String>> visit(SExpr.Symbol<Sig> sigSymbol) { final Vector<SExpr<String>> result = new Vector<SExpr<String>>(); result.add(new SExpr.Symbol<String>(sigSymbol.getName())); return result; } @Override public List<SExpr<String>> visit(SExpr.Leaf<Sig> sigLeaf) { return new Vector<SExpr<String>>(csp.sig2smtvars.get((Sig.PrimSig) sigLeaf.getValue())); } @Override public List<SExpr<String>> visit(SExpr.SList<Sig> sigSList) { final Vector<SExpr<String>> result = new Vector<SExpr<String>>(); final List<List<SExpr<String>>> converted = new Vector<List<SExpr<String>>>(); for (SExpr<Sig> sub : sigSList.getItems()) { converted.add(visitThis(sub)); } build(result, converted, new SExpr[converted.size()], 0); return result; } static void build(List<SExpr<String>> result, List<List<SExpr<String>>> input, SExpr<String>[] selected, int depth) { if (depth == input.size()) { result.add(new SExpr.SList<String>(Arrays.asList(selected.clone()))); } else { for (SExpr<String> expr : input.get(depth)) { selected[depth] = expr; build(result, input, selected, depth + 1); } } } } private static void buildMapTupleSet(List<Object[]> output, List<List<Object>> sourceAtoms, Object[] selected, int depth) { if (depth == sourceAtoms.size()) { output.add(selected.clone()); } else { for (Object obj : sourceAtoms.get(depth)) { selected[depth+1] = obj; buildMapTupleSet(output, sourceAtoms, selected, depth + 1); } } } private List<Object> getAtoms(Sig.PrimSig sig) { final List<Object> result = new Vector<Object>(); final Relation srel = (Relation) solution.a2k(sig); final TupleSet tuples = solution.getBounds().upperBound(srel); if (tuples.arity() != 1) throw new AssertionError(); for (Tuple t : tuples) { result.add(t.atom(0)); } return result; } private List<List<Object>> getDependentAtoms(Iterable<Type> dependencies) { List<List<Object>> result = new Vector<List<Object>>(); for (Type dep : dependencies) { if (!dep.hasArity(1)) throw new AssertionError(); for (List<Sig.PrimSig> l : dep.fold()) { if (l.size() != 1) throw new AssertionError(); result.add(getAtoms(l.get(0))); } } return result; } public static Result run(FormulaRewritePhase.Result in) throws Err { ComputeScopePhase p = new ComputeScopePhase(in); return new Result(p.command, p.solution, p.options, p.equalsrel, ConstList.make(p.sexprs), p.smtvars.makeConst(), ConstList.make(p.equalsSmtBounds)); } } private static class SmtTranslationPhase extends TranslateAlloyToKodkod { protected SmtTranslationPhase(A4Reporter rep, A4Options opt, A4Solution frame, Command cmd) { super(rep, opt, frame, cmd); } public static void run(ComputeScopePhase.Result csp) throws Err { final SMTSolver solver = new SMTSolver(); csp.solution.solver.options().setSolver(new SATFactory() { @Override public SATSolver instance() { return solver; } @Override public boolean prover() { return false; } @Override public boolean minimizer() { return false; } @Override public boolean incremental() { return false; } @Override public String toString() { return "SMT Backend"; } }); SmtTranslationPhase stp = new SmtTranslationPhase(A4Reporter.NOP, csp.options, csp.solution, csp.command); stp.makeFacts(csp.command.formula); final Formula kformula = stp.frame.makeFormula(A4Reporter.NOP, new Simplifier()); final Translation tl; try { tl = Translator.translate(kformula, stp.frame.getBounds(), stp.frame.solver.options()); } catch (TrivialFormulaException e) { e.printStackTrace(); throw new ErrorFatal(e.toString()); } for (SExpr.Leaf<String> var : csp.smtVars) { solver.addIntVariable(var.getValue()); } if (csp.equalsrel != null) { int[] relvars = tl.primaryVariables(csp.equalsrel).toArray(); for (int i = 0; i < relvars.length; ++i) { final Pair<SExpr.Leaf<String>, SExpr.Leaf<String>> pair = csp.equalsSmtBounds.get(i); solver.addEquality(relvars[i], SExpr.<String>eq(pair.a, pair.b)); } } kodkodDebug(csp, stp, kformula, solver); } private static void kodkodDebug(ComputeScopePhase.Result csp, SmtTranslationPhase stp, Formula kformula, SMTSolver solver) { // KODKOD DEBUG OUTPUT List<String> kkatoms = new Vector<String>(); for (Object atom : stp.frame.getFactory().universe()) { kkatoms.add((String) atom); } String kkout = TranslateKodkodToJava.convert(kformula, stp.frame.getBitwidth(), kkatoms, stp.frame.getBounds(), null); try { File tmpout = File.createTempFile("kodkodout", ".txt"); FileWriter writer = new FileWriter(tmpout); writer.write(csp.command.formula.toString()); writer.write("\n=======================================\n"); writer.write(kkout); writer.write("\n=======================================\n"); writer.write(solver.makeSMTFormula().toString()); writer.close(); } catch (IOException e) { e.printStackTrace(); } } } }