package de.psi.alloy4smt.ast; import edu.mit.csail.sdg.alloy4.Err; import edu.mit.csail.sdg.alloy4compiler.ast.Command; import edu.mit.csail.sdg.alloy4compiler.ast.Sig; import edu.mit.csail.sdg.alloy4compiler.parser.CompModule; import edu.mit.csail.sdg.alloy4compiler.parser.CompUtil; import edu.mit.csail.sdg.alloy4compiler.translator.A4Options; import kodkod.ast.Relation; import kodkod.instance.TupleSet; import org.junit.Before; import org.junit.Test; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Vector; import static org.junit.Assert.*; public class SmtPreprocessorTest { public static final String docPrelude = "open util/smtint\n"; public static final String smtintFacts = "(all a,b | AND[" + "OR[" + "b in a . (smtint/SintRef <: equals), " + "a in b . (smtint/SintRef <: equals)" + "] <=> a . (smtint/SintRef <: aqclass) = b . (smtint/SintRef <: aqclass), " + "b in a . (smtint/SintRef <: equals) => OR[" + "b . (smtint/SintRef <: aqclass) = a, " + "b . (smtint/SintRef <: aqclass) in a . (smtint/SintRef <: equals)" + "]" + "])"; private CompModule module; private List<SmtPreprocessor.Result> commands; @Before public void setUp() { module = null; commands = null; } private void parseModule(String doc) throws Err { Map<String,String> fm = new HashMap<String, String>(); fm.put("/tmp/x", docPrelude + doc); module = CompUtil.parseEverything_fromFile(null, fm, "/tmp/x"); assertTrue(module.getAllCommands().size() > 0); commands = new Vector<SmtPreprocessor.Result>(); for (Command c : module.getAllCommands()) { commands.add(SmtPreprocessor.build(c, module.getAllReachableSigs())); } } private Sig getOldSig(String name) { return Helpers.getSigByName(module.getAllReachableSigs(), name); } private Sig.Field getOldField(String signame, String field) { Sig sig = getOldSig(signame); return Helpers.getFieldByName(sig.getFields(), field); } private Sig getNewSig(String name) { return Helpers.getSigByName(commands.get(0).sigs, name); } private Sig.Field getNewField(String signame, String field) { Sig sig = getNewSig(signame); return Helpers.getFieldByName(sig.getFields(), field); } private void assertFields(String signame, String fieldname, String oldrepr, String newrepr) { Sig.Field oldf = getOldField(signame, fieldname); Sig.Field newf = getNewField(signame + "_c", fieldname + "_c"); assertEquals(oldrepr, oldf.decl().expr.toString()); assertEquals(newrepr, newf.decl().expr.toString()); } @Test public void simpleTest() throws Err { parseModule( "sig X { v: Y }\n" + "sig Y { w: X ->one X }\n" + "pred show {}\n" + "run show for 2 X, 2 Y\n"); assertFields("this/X", "v", "one this/Y", "one this/Y_c"); assertFields("this/Y", "w", "this/X ->one this/X", "this/X_c ->one this/X_c"); } private void assertDeclConversion(String decl, String newDecl) throws Err { parseModule("sig A {}\nsig X { v: " + decl + " }\npred show{}\nrun show for 2 A, 2 X\n"); assertFields("this/X", "v", decl, newDecl); } @Test public void oneIntegerTest() throws Err { assertDeclConversion("one this/A", "one this/A_c"); assertDeclConversion("one smtint/Sint", "one this/X_v_SintRef"); assertDeclConversion("univ ->one smtint/Sint", "univ ->one this/X_v_SintRef"); assertDeclConversion("this/A ->one smtint/Sint", "this/A_c ->one this/X_v_SintRef"); assertDeclConversion("this/A ->lone smtint/Sint", "this/A_c ->lone this/X_v_SintRef"); } @Test public void intRefBounds() throws Err { parseModule( "sig X { v: Sint }\n" + "pred show{}\n" + "run show for 4 X\n"); assertEquals("Run show for 4 X", module.getAllCommands().get(0).toString()); assertEquals("Run show for 4 X_c, exactly 4 X_v_SintRef", commands.get(0).command.toString()); } @Test public void intRefBounds2() throws Err { parseModule( "sig X { v: Y ->one Sint, w: X -> Y ->one Sint }\n" + "sig Y {}\n" + "pred show {}\n" + "run show for 4 X, 3 Y\n" ); assertEquals("Run show for 4 X, 3 Y", module.getAllCommands().get(0).toString()); assertEquals("Run show for 4 X_c, 3 Y_c, exactly 12 X_v_SintRef, exactly 48 X_w_SintRef", commands.get(0).command.toString()); } @Test public void intRefBounds3() throws Err { parseModule( "sig X { v: Y ->one Sint -> Y }\n" + "sig Y {}\n" + "pred show {}\n" + "run show for 3 X, 4 Y\n" ); assertEquals("Run show for 3 X, 4 Y", module.getAllCommands().get(0).toString()); assertEquals("Run show for 3 X_c, 4 Y_c, exactly 48 X_v_SintRef", commands.get(0).command.toString()); } @Test public void oneSigBounds() throws Err { parseModule( "one sig X { u: Sint, v: Y ->one Sint, w: Z ->one Sint }\n" + "sig Y {}\n" + "one sig Z {}\n" + "pred show {}\n" + "run show for 4 Y\n"); assertEquals("Run show for 4 Y", module.getAllCommands().get(0).toString()); assertEquals("Run show for 4 Y_c, exactly 1 X_u_SintRef, exactly 4 X_v_SintRef, exactly 1 X_w_SintRef", commands.get(0).command.toString()); Sig sigXold = Helpers.getSigByName(module.getAllReachableSigs(), "this/X"); Sig sigXnew = Helpers.getSigByName(commands.get(0).sigs, "this/X_c"); Sig sigYold = Helpers.getSigByName(module.getAllReachableSigs(), "this/Y"); Sig sigYnew = Helpers.getSigByName(commands.get(0).sigs, "this/Y_c"); assertNotNull(sigXold.isOne); assertNotNull(sigXnew.isOne); assertNull(sigYold.isOne); assertNull(sigYnew.isOne); } @Test public void implicitSigBounds() throws Err { parseModule( "sig X { v: Y ->one Sint }\n" + "sig Y {}\n" + "one sig Z { u: Y ->one Sint }\n" + "pred show {}\n" + "run show for 3 but 4 Y\n" + "run show for 3\n"); assertEquals("Run show for 3 but 4 Y", module.getAllCommands().get(0).toString()); assertEquals("Run show for 3 but 4 Y_c, exactly 12 X_v_SintRef, exactly 4 Z_u_SintRef", commands.get(0).command.toString()); assertEquals("Run show for 3", module.getAllCommands().get(1).toString()); assertEquals("Run show for 3 but exactly 9 X_v_SintRef, exactly 3 Z_u_SintRef", commands.get(1).command.toString()); } /* @Test public void unchangedFacts() throws Err { parseModule( "sig A {}\n" + "sig B { m: A -> A}\n" + "pred testeq[a: A, b: B] { let a2 = b.m[a] | a2 != a }\n" + "fact { all b: B, a: A { let a2 = a | testeq[a2, b] } }\n" + "fact { all b: B, a: A { b.m[a] = a implies b.m[a] = a else b.m[a] != a } }\n" + "pred show {}\n" + "run show for 4"); assertEquals(module.getAllReachableFacts().toString(), commands.get(0).command.formula.toString()); } */ private static A4Options makeA4Options() { final A4Options opt = new A4Options(); opt.recordKodkod = true; opt.tempDirectory = "/tmp"; opt.solverDirectory = "/tmp"; opt.solver = A4Options.SatSolver.SAT4J; opt.skolemDepth = 4; return opt; } private void assertEqualsTupleSet(String tuplesetstr) { final SmtPreprocessor.Result command = commands.get(0); final Relation rel = (Relation) command.solution.a2k(command.equalsf); final TupleSet lb = command.solution.getBounds().lowerBound(rel); final TupleSet ub = command.solution.getBounds().upperBound(rel); assertEquals("[]", lb.toString()); assertEquals(tuplesetstr, ub.toString()); } private void assertSintexprBounds(int exprid, String tuplesetstr) { final SmtPreprocessor.Result command = commands.get(0); final Sig.PrimSig sintexpr = (Sig.PrimSig) Helpers.getSigByName(command.sigs, "SintExpr" + exprid); final Sig.Field sintmap = Helpers.getFieldByName(sintexpr.getFields(), "map"); final Relation rel = (Relation) command.solution.a2k(sintmap); final TupleSet lb = command.solution.getBounds().lowerBound(rel); final TupleSet ub = command.solution.getBounds().upperBound(rel); assertEquals(tuplesetstr, lb.toString()); assertEquals(tuplesetstr, ub.toString()); } @Test public void rewriteFactsAndExtractIntExprs() throws Err { parseModule( "one sig A { v: Sint }\n" + "fact { A.v.plus[const[2]] = const[4] }\n" + "fact { A.v.gt[const[0]] }\n" + "pred show {}\n" + "run show for 3\n"); assertEquals( "AND[" + "smtint/plus[this/A . (this/A <: v), smtint/const[Int[2]]] = smtint/const[Int[4]], " + "smtint/gt[this/A . (this/A <: v), smtint/const[Int[0]]], " + smtintFacts + "]", module.getAllReachableFacts().toString()); assertEquals("[smtint/SintRef, univ, Int, seq/Int, String, none, this/A_c, this/A_v_SintRef, " + "SintExpr0, SintExpr1, SintExpr2, SintExpr3]", commands.get(0).sigs.toString()); assertEquals("Run show for 3 but exactly 1 A_v_SintRef, exactly 1 SintExpr0, exactly 1 SintExpr1, exactly 1 SintExpr2, exactly 1 SintExpr3", commands.get(0).command.toString()); assertEquals( "AND[" + "SintExpr0 . (smtint/SintRef <: aqclass) = this/A_c . (this/A_c <: v_c) . (smtint/SintRef <: aqclass), " + "SintExpr1 . (smtint/SintRef <: aqclass) = SintExpr2 . (smtint/SintRef <: aqclass), " + "SintExpr3 . (smtint/SintRef <: aqclass) = this/A_c . (this/A_c <: v_c) . (smtint/SintRef <: aqclass), " + smtintFacts + "]", commands.get(0).command.formula.toString()); assertEqualsTupleSet("[" + "[A_v_SintRef$0, SintExpr0$0], " + "[A_v_SintRef$0, SintExpr1$0], " + "[A_v_SintRef$0, SintExpr2$0], " + "[A_v_SintRef$0, SintExpr3$0], " + "[SintExpr0$0, SintExpr1$0], " + "[SintExpr0$0, SintExpr2$0], " + "[SintExpr0$0, SintExpr3$0], " + "[SintExpr1$0, SintExpr2$0], " + "[SintExpr1$0, SintExpr3$0], " + "[SintExpr2$0, SintExpr3$0]" + "]"); assertEquals("[(= SintExpr1_0 (+ SintExpr0_0 2)), (= SintExpr2_0 4), (> SintExpr3_0 0)]", commands.get(0).smtExprs.toString()); } @Test public void rewriteFactsAndExtractIntExprsInQuantifiedFormula() throws Err { parseModule( "sig A { v: Sint }\n" + "fact { all a: A | a.v.plus[const[2]].eq[const[4]] }\n" + "pred show {}\n" + "run show for 3 A\n"); assertEquals("AND[" + "(all a | smtint/eq[smtint/plus[a . (this/A <: v), smtint/const[Int[2]]], smtint/const[Int[4]]]), " + smtintFacts + "]", module.getAllReachableFacts().toString()); assertEquals("AND[" + "(all a | (SintExpr0 <: map) . a . (smtint/SintRef <: aqclass) = " + "a . (this/A_c <: v_c) . (smtint/SintRef <: aqclass)), " + smtintFacts + "]", commands.get(0).command.formula.toString()); assertEquals("Run show for 3 A_c, exactly 3 A_v_SintRef, exactly 3 SintExpr0", commands.get(0).command.toString()); assertSintexprBounds(0, "[" + "[SintExpr0$0, A_c$0], " + "[SintExpr0$1, A_c$1], " + "[SintExpr0$2, A_c$2]" + "]"); assertEquals("[(= (+ SintExpr0_0 2) 4), (= (+ SintExpr0_1 2) 4), (= (+ SintExpr0_2 2) 4)]", commands.get(0).smtExprs.toString()); } }