/* * Copyright 2015 Red Hat, Inc. and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.drools.compiler.integrationtests; import org.drools.core.base.ClassObjectType; import org.drools.core.common.InternalWorkingMemory; import org.drools.core.impl.KnowledgeBaseImpl; import org.drools.core.reteoo.BetaMemory; import org.drools.core.reteoo.ConditionalBranchNode; import org.drools.core.reteoo.InitialFactImpl; import org.drools.core.reteoo.JoinNode; import org.drools.core.reteoo.LeftInputAdapterNode; import org.drools.core.reteoo.LeftInputAdapterNode.LiaNodeMemory; import org.drools.core.reteoo.NotNode; import org.drools.core.reteoo.ObjectTypeNode; import org.drools.core.reteoo.PathMemory; import org.drools.core.reteoo.RightInputAdapterNode; import org.drools.core.reteoo.RuleTerminalNode; import org.drools.core.reteoo.SegmentMemory; import org.junit.Test; import org.kie.api.io.ResourceType; import org.kie.api.runtime.rule.FactHandle; import org.kie.internal.KnowledgeBase; import org.kie.internal.KnowledgeBaseFactory; import org.kie.internal.builder.KnowledgeBuilder; import org.kie.internal.builder.KnowledgeBuilderFactory; import org.kie.internal.io.ResourceFactory; import java.util.List; import static org.junit.Assert.*; public class SegmentCreationTest { @Test public void testSingleEmptyLhs() throws Exception { KnowledgeBase kbase = buildKnowledgeBase(" "); InternalWorkingMemory wm = ((InternalWorkingMemory)kbase.newStatefulKnowledgeSession()); ObjectTypeNode aotn = getObjectTypeNode(kbase, InitialFactImpl.class ); LeftInputAdapterNode liaNode = (LeftInputAdapterNode) aotn.getObjectSinkPropagator().getSinks()[0]; RuleTerminalNode rtn = ( RuleTerminalNode) liaNode.getSinkPropagator().getSinks()[0]; wm.insert( new LinkingTest.A() ); wm.flushPropagations(); // LiaNode and Rule are in same segment LiaNodeMemory liaMem = ( LiaNodeMemory ) wm.getNodeMemory( liaNode ); SegmentMemory smem = liaMem.getSegmentMemory(); assertEquals( liaNode, smem.getRootNode() ); assertEquals( rtn, smem.getTipNode() ); assertNull( smem.getNext() ); assertNull( smem.getFirst() ); } @Test public void testSingleSharedEmptyLhs() throws Exception { KnowledgeBase kbase = buildKnowledgeBase( " ", " "); InternalWorkingMemory wm = ((InternalWorkingMemory)kbase.newStatefulKnowledgeSession()); ObjectTypeNode aotn = getObjectTypeNode(kbase, InitialFactImpl.class ); LeftInputAdapterNode liaNode = (LeftInputAdapterNode) aotn.getObjectSinkPropagator().getSinks()[0]; RuleTerminalNode rtn1 = ( RuleTerminalNode) liaNode.getSinkPropagator().getSinks()[0]; RuleTerminalNode rtn2 = ( RuleTerminalNode) liaNode.getSinkPropagator().getSinks()[1]; wm.insert( new LinkingTest.A() ); wm.flushPropagations(); // LiaNode is in it's own segment LiaNodeMemory liaMem = ( LiaNodeMemory ) wm.getNodeMemory( liaNode ); SegmentMemory smem = liaMem.getSegmentMemory(); assertEquals( liaNode, smem.getRootNode() ); assertEquals( liaNode, smem.getTipNode() ); // each RTN is in it's own segment SegmentMemory rtnSmem1 = smem.getFirst(); assertEquals( rtn1, rtnSmem1.getRootNode() ); assertEquals( rtn1, rtnSmem1.getTipNode() ); SegmentMemory rtnSmem2 = rtnSmem1.getNext(); assertEquals( rtn2, rtnSmem2.getRootNode() ); assertEquals( rtn2, rtnSmem2.getTipNode() ); } @Test public void testSinglePattern() throws Exception { KnowledgeBase kbase = buildKnowledgeBase(" A() \n"); InternalWorkingMemory wm = ((InternalWorkingMemory)kbase.newStatefulKnowledgeSession()); ObjectTypeNode aotn = getObjectTypeNode(kbase, LinkingTest.A.class ); LeftInputAdapterNode liaNode = (LeftInputAdapterNode) aotn.getObjectSinkPropagator().getSinks()[0]; RuleTerminalNode rtn = ( RuleTerminalNode) liaNode.getSinkPropagator().getSinks()[0]; wm.insert(new LinkingTest.A()); wm.flushPropagations(); // LiaNode and Rule are in same segment LiaNodeMemory liaMem = ( LiaNodeMemory ) wm.getNodeMemory( liaNode ); SegmentMemory smem = liaMem.getSegmentMemory(); assertEquals( liaNode, smem.getRootNode() ); assertEquals( rtn, smem.getTipNode() ); assertNull( smem.getNext() ); assertNull( smem.getFirst() ); } @Test public void testSingleSharedPattern() throws Exception { KnowledgeBase kbase = buildKnowledgeBase( " A() \n", " A() \n"); InternalWorkingMemory wm = ((InternalWorkingMemory)kbase.newStatefulKnowledgeSession()); ObjectTypeNode aotn = getObjectTypeNode(kbase, LinkingTest.A.class ); LeftInputAdapterNode liaNode = (LeftInputAdapterNode) aotn.getObjectSinkPropagator().getSinks()[0]; RuleTerminalNode rtn1 = ( RuleTerminalNode) liaNode.getSinkPropagator().getSinks()[0]; RuleTerminalNode rtn2 = ( RuleTerminalNode) liaNode.getSinkPropagator().getSinks()[1]; wm.insert(new LinkingTest.A()); wm.flushPropagations(); // LiaNode is in it's own segment LiaNodeMemory liaMem = ( LiaNodeMemory ) wm.getNodeMemory( liaNode ); SegmentMemory smem = liaMem.getSegmentMemory(); assertEquals( liaNode, smem.getRootNode() ); assertEquals( liaNode, smem.getTipNode() ); // each RTN is in it's own segment SegmentMemory rtnSmem1 = smem.getFirst(); assertEquals( rtn1, rtnSmem1.getRootNode() ); assertEquals( rtn1, rtnSmem1.getTipNode() ); SegmentMemory rtnSmem2 = rtnSmem1.getNext(); assertEquals( rtn2, rtnSmem2.getRootNode() ); assertEquals( rtn2, rtnSmem2.getTipNode() ); } @Test public void testMultiSharedPattern() throws Exception { KnowledgeBase kbase = buildKnowledgeBase( " A() \n", " A() B() \n", " A() B() C() \n"); InternalWorkingMemory wm = ((InternalWorkingMemory)kbase.newStatefulKnowledgeSession()); ObjectTypeNode aotn = getObjectTypeNode(kbase, LinkingTest.A.class ); LeftInputAdapterNode liaNode = (LeftInputAdapterNode) aotn.getObjectSinkPropagator().getSinks()[0]; RuleTerminalNode rtn1 = ( RuleTerminalNode) liaNode.getSinkPropagator().getSinks()[0]; JoinNode bNode = ( JoinNode ) liaNode.getSinkPropagator().getSinks()[1]; RuleTerminalNode rtn2 = ( RuleTerminalNode) bNode.getSinkPropagator().getSinks()[0]; JoinNode cNode = ( JoinNode ) bNode.getSinkPropagator().getSinks()[1]; RuleTerminalNode rtn3 = ( RuleTerminalNode) cNode.getSinkPropagator().getSinks()[0]; wm.insert( new LinkingTest.A() ); wm.insert( new LinkingTest.B() ); wm.insert(new LinkingTest.C()); wm.flushPropagations(); // LiaNode is in it's own segment LiaNodeMemory liaMem = ( LiaNodeMemory ) wm.getNodeMemory( liaNode ); SegmentMemory smem = liaMem.getSegmentMemory(); assertEquals( liaNode, smem.getRootNode() ); assertEquals( liaNode, smem.getTipNode() ); SegmentMemory rtnSmem1 = smem.getFirst(); assertEquals( rtn1, rtnSmem1.getRootNode() ); assertEquals( rtn1, rtnSmem1.getTipNode() ); SegmentMemory bSmem = rtnSmem1.getNext(); assertEquals( bNode, bSmem.getRootNode() ); assertEquals( bNode, bSmem.getTipNode() ); // child segment is not yet initialised, so null assertNull( bSmem.getFirst() ); // there is no next assertNull( bSmem.getNext() ); wm.fireAllRules(); // child segments should now be initialised wm.flushPropagations(); SegmentMemory rtnSmem2 = bSmem.getFirst(); assertEquals( rtn2, rtnSmem2.getRootNode() ); assertEquals( rtn2, rtnSmem2.getTipNode() ); SegmentMemory cSmem = rtnSmem2.getNext(); assertEquals( cNode, cSmem.getRootNode() ); assertEquals( rtn3, cSmem.getTipNode() ); // note rtn3 is in the same segment as C } @Test public void testSubnetworkNoSharing() throws Exception { KnowledgeBase kbase = buildKnowledgeBase( " A() not ( B() and C() ) \n" ); InternalWorkingMemory wm = ((InternalWorkingMemory)kbase.newStatefulKnowledgeSession()); ObjectTypeNode aotn = getObjectTypeNode(kbase, LinkingTest.A.class ); LeftInputAdapterNode liaNode = (LeftInputAdapterNode) aotn.getObjectSinkPropagator().getSinks()[0]; JoinNode bNode = ( JoinNode ) liaNode.getSinkPropagator().getSinks()[0]; JoinNode cNode = ( JoinNode ) bNode.getSinkPropagator().getSinks()[0]; RightInputAdapterNode riaNode = ( RightInputAdapterNode ) cNode.getSinkPropagator().getSinks()[0]; NotNode notNode = ( NotNode ) liaNode.getSinkPropagator().getSinks()[1]; RuleTerminalNode rtn1 = ( RuleTerminalNode) notNode.getSinkPropagator().getSinks()[0]; wm.insert( new LinkingTest.A() ); wm.insert( new LinkingTest.B() ); wm.insert( new LinkingTest.C() ); wm.flushPropagations(); // LiaNode is in it's own segment LiaNodeMemory liaMem = ( LiaNodeMemory ) wm.getNodeMemory( liaNode ); SegmentMemory smem = liaMem.getSegmentMemory(); assertEquals( liaNode, smem.getRootNode() ); assertEquals( liaNode, smem.getTipNode() ); assertNull( smem.getNext() ); smem = smem.getFirst(); SegmentMemory bSmem = wm.getNodeMemory( bNode ).getSegmentMemory(); // it's nested inside of smem, so lookup from wm assertEquals( smem, bSmem ); assertEquals( bNode, bSmem.getRootNode() ); assertEquals( riaNode, bSmem.getTipNode() ); BetaMemory bm = ( BetaMemory ) wm.getNodeMemory( notNode ); assertEquals( bm.getSegmentMemory(), smem.getNext() ); assertEquals(bSmem, bm.getRiaRuleMemory().getSegmentMemory() ); // check subnetwork ref was made } @Test public void tesSubnetworkAfterShare() throws Exception { KnowledgeBase kbase = buildKnowledgeBase( " A() \n", " A() not ( B() and C() ) \n" ); InternalWorkingMemory wm = ((InternalWorkingMemory)kbase.newStatefulKnowledgeSession()); ObjectTypeNode aotn = getObjectTypeNode(kbase, LinkingTest.A.class ); LeftInputAdapterNode liaNode = (LeftInputAdapterNode) aotn.getObjectSinkPropagator().getSinks()[0]; RuleTerminalNode rtn1 = ( RuleTerminalNode) liaNode.getSinkPropagator().getSinks()[0]; JoinNode bNode = ( JoinNode ) liaNode.getSinkPropagator().getSinks()[1]; JoinNode cNode = ( JoinNode ) bNode.getSinkPropagator().getSinks()[0]; RightInputAdapterNode riaNode = ( RightInputAdapterNode ) cNode.getSinkPropagator().getSinks()[0]; NotNode notNode = ( NotNode ) liaNode.getSinkPropagator().getSinks()[2]; RuleTerminalNode rtn2 = ( RuleTerminalNode) notNode.getSinkPropagator().getSinks()[0]; wm.insert( new LinkingTest.A() ); wm.insert( new LinkingTest.B() ); wm.insert( new LinkingTest.C() ); wm.flushPropagations(); // LiaNode is in it's own segment LiaNodeMemory liaMem = ( LiaNodeMemory ) wm.getNodeMemory( liaNode ); SegmentMemory smem = liaMem.getSegmentMemory(); assertEquals( liaNode, smem.getRootNode() ); assertEquals( liaNode, smem.getTipNode() ); SegmentMemory rtnSmem1 = smem.getFirst(); assertEquals( rtn1, rtnSmem1.getRootNode() ); assertEquals( rtn1, rtnSmem1.getTipNode() ); SegmentMemory bSmem = rtnSmem1.getNext(); assertEquals( bNode, bSmem.getRootNode() ); assertEquals( riaNode, bSmem.getTipNode() ); SegmentMemory notSmem = bSmem.getNext(); assertEquals( notNode, notSmem.getRootNode() ); assertEquals( rtn2, notSmem.getTipNode() ); // child segment is not yet initialised, so null assertNull( bSmem.getFirst() ); } @Test public void tesShareInSubnetwork() throws Exception { KnowledgeBase kbase = buildKnowledgeBase( " A() \n", " A() B() C() \n", " A() not ( B() and C() ) \n" ); InternalWorkingMemory wm = ((InternalWorkingMemory)kbase.newStatefulKnowledgeSession()); ObjectTypeNode aotn = getObjectTypeNode(kbase, LinkingTest.A.class ); LeftInputAdapterNode liaNode = (LeftInputAdapterNode) aotn.getObjectSinkPropagator().getSinks()[0]; RuleTerminalNode rtn1 = ( RuleTerminalNode) liaNode.getSinkPropagator().getSinks()[0]; JoinNode bNode = ( JoinNode ) liaNode.getSinkPropagator().getSinks()[1]; JoinNode cNode = ( JoinNode ) bNode.getSinkPropagator().getSinks()[0]; RuleTerminalNode rtn2 = ( RuleTerminalNode ) cNode.getSinkPropagator().getSinks()[0]; RightInputAdapterNode riaNode = ( RightInputAdapterNode ) cNode.getSinkPropagator().getSinks()[1]; NotNode notNode = ( NotNode ) liaNode.getSinkPropagator().getSinks()[2]; RuleTerminalNode rtn3 = ( RuleTerminalNode) notNode.getSinkPropagator().getSinks()[0]; wm.insert( new LinkingTest.A() ); wm.insert( new LinkingTest.B() ); wm.insert( new LinkingTest.C() ); wm.flushPropagations(); // LiaNode is in it's own segment LiaNodeMemory liaMem = ( LiaNodeMemory ) wm.getNodeMemory( liaNode ); SegmentMemory smem = liaMem.getSegmentMemory(); assertEquals( liaNode, smem.getRootNode() ); assertEquals( liaNode, smem.getTipNode() ); SegmentMemory rtnSmem1 = smem.getFirst(); assertEquals( rtn1, rtnSmem1.getRootNode() ); assertEquals( rtn1, rtnSmem1.getTipNode() ); SegmentMemory bSmem = rtnSmem1.getNext(); assertEquals( bNode, bSmem.getRootNode() ); assertEquals( cNode, bSmem.getTipNode() ); assertNull( bSmem.getFirst() ); // segment is not initialized yet wm.fireAllRules(); SegmentMemory rtn2Smem = bSmem.getFirst(); assertEquals( rtn2, rtn2Smem.getRootNode() ); assertEquals( rtn2, rtn2Smem.getTipNode() ); SegmentMemory riaSmem = rtn2Smem.getNext(); assertEquals( riaNode, riaSmem.getRootNode() ); assertEquals( riaNode, riaSmem.getTipNode() ); SegmentMemory notSmem = bSmem.getNext(); assertEquals( notNode, notSmem.getRootNode() ); assertEquals( rtn3, notSmem.getTipNode() ); } @Test public void testBranchCESingleSegment() throws Exception { KnowledgeBase kbase = buildKnowledgeBase( " $a : A() \n" + " if ( $a != null ) do[t1] \n" + " B() \n" ); InternalWorkingMemory wm = ((InternalWorkingMemory)kbase.newStatefulKnowledgeSession()); ObjectTypeNode aotn = getObjectTypeNode(kbase, LinkingTest.A.class ); LeftInputAdapterNode liaNode = (LeftInputAdapterNode) aotn.getObjectSinkPropagator().getSinks()[0]; ConditionalBranchNode cen1Node = ( ConditionalBranchNode ) liaNode.getSinkPropagator().getSinks()[0]; JoinNode bNode = ( JoinNode ) cen1Node.getSinkPropagator().getSinks()[0]; RuleTerminalNode rtn1 = ( RuleTerminalNode ) bNode.getSinkPropagator().getSinks()[0]; FactHandle bFh = wm.insert( new LinkingTest.B() ); wm.flushPropagations(); LiaNodeMemory liaMem = ( LiaNodeMemory ) wm.getNodeMemory( liaNode ); SegmentMemory smem = liaMem.getSegmentMemory(); assertEquals( 1, smem.getAllLinkedMaskTest() ); assertEquals( 4, smem.getLinkedNodeMask() ); // B links, but it will not trigger mask assertFalse( smem.isSegmentLinked() ); PathMemory pmem = ( PathMemory ) wm.getNodeMemory(rtn1); assertEquals( 1, pmem.getAllLinkedMaskTest() ); assertEquals( 0, pmem.getLinkedSegmentMask() ); assertFalse( pmem.isRuleLinked() ); wm.insert(new LinkingTest.A()); wm.flushPropagations(); assertEquals( 5, smem.getLinkedNodeMask() ); // A links in segment assertTrue( smem.isSegmentLinked() ); assertEquals( 1, pmem.getLinkedSegmentMask() ); assertTrue( pmem.isRuleLinked() ); wm.delete(bFh); // retract B does not unlink the rule wm.flushPropagations(); assertEquals( 1, pmem.getLinkedSegmentMask() ); assertTrue( pmem.isRuleLinked() ); } @Test public void testBranchCEMultipleSegments() throws Exception { KnowledgeBase kbase = buildKnowledgeBase( " $a : A() \n", // r1 " $a : A() \n" + " if ( $a != null ) do[t1] \n" + " B() \n", // r2 " $a : A() \n"+ " if ( $a != null ) do[t1] \n" + " B() \n" + " C() \n" // r3 ); InternalWorkingMemory wm = ((InternalWorkingMemory)kbase.newStatefulKnowledgeSession()); ObjectTypeNode aotn = getObjectTypeNode(kbase, LinkingTest.A.class ); LeftInputAdapterNode liaNode = (LeftInputAdapterNode) aotn.getObjectSinkPropagator().getSinks()[0]; ConditionalBranchNode cen1Node = ( ConditionalBranchNode ) liaNode.getSinkPropagator().getSinks()[1]; JoinNode bNode = ( JoinNode ) cen1Node.getSinkPropagator().getSinks()[0]; RuleTerminalNode rtn2 = ( RuleTerminalNode ) bNode.getSinkPropagator().getSinks()[0]; JoinNode cNode = ( JoinNode ) bNode.getSinkPropagator().getSinks()[1]; RuleTerminalNode rtn3 = ( RuleTerminalNode ) cNode.getSinkPropagator().getSinks()[0]; FactHandle bFh = wm.insert( new LinkingTest.B() ); FactHandle cFh = wm.insert( new LinkingTest.C() ); wm.flushPropagations(); BetaMemory bNodeBm = ( BetaMemory ) wm.getNodeMemory( bNode ); SegmentMemory bNodeSmem = bNodeBm.getSegmentMemory(); assertEquals( 0, bNodeSmem.getAllLinkedMaskTest() ); // no beta nodes before branch CE, so never unlinks assertEquals( 2, bNodeSmem.getLinkedNodeMask() ); PathMemory pmemr2 = ( PathMemory ) wm.getNodeMemory(rtn2); assertEquals( 1, pmemr2.getAllLinkedMaskTest() ); assertEquals( 2, pmemr2.getLinkedSegmentMask() ); assertEquals( 3, pmemr2.getSegmentMemories().length ); assertFalse( pmemr2.isRuleLinked() ); PathMemory pmemr3 = ( PathMemory ) wm.getNodeMemory(rtn3); assertEquals( 1, pmemr3.getAllLinkedMaskTest() ); // notice only the first segment links assertEquals( 3, pmemr3.getSegmentMemories().length ); assertFalse( pmemr3.isRuleLinked() ); BetaMemory cNodeBm = ( BetaMemory ) wm.getNodeMemory( cNode ); SegmentMemory cNodeSmem = cNodeBm.getSegmentMemory(); assertEquals( 1, cNodeSmem.getAllLinkedMaskTest() ); assertEquals( 1, cNodeSmem.getLinkedNodeMask() ); wm.insert(new LinkingTest.A()); wm.flushPropagations(); assertTrue( pmemr2.isRuleLinked() ); assertTrue( pmemr3.isRuleLinked() ); wm.delete(bFh); // retract B does not unlink the rule wm.delete(cFh); // retract C does not unlink the rule wm.flushPropagations(); assertEquals( 3, pmemr2.getLinkedSegmentMask() ); // b segment never unlinks, as it has no impact on path unlinking anyway assertTrue( pmemr2.isRuleLinked() ); assertEquals( 3, pmemr3.getLinkedSegmentMask() ); // b segment never unlinks, as it has no impact on path unlinking anyway assertTrue( pmemr3.isRuleLinked() ); } private KnowledgeBase buildKnowledgeBase(String... rules) { String str = ""; str += "package org.kie \n"; str += "import " + LinkingTest.A.class.getCanonicalName() + "\n" ; str += "import " + LinkingTest.B.class.getCanonicalName() + "\n" ; str += "import " + LinkingTest.C.class.getCanonicalName() + "\n" ; str += "global java.util.List list \n"; int i = 0; for ( String lhs : rules) { str += "rule rule" + (i++) +" when \n"; str += lhs; str += "then \n"; str += "then[t1] \n"; str += "end \n"; } KnowledgeBuilder kbuilder = KnowledgeBuilderFactory.newKnowledgeBuilder(); kbuilder.add( ResourceFactory.newByteArrayResource(str.getBytes()), ResourceType.DRL ); assertFalse( kbuilder.getErrors().toString(), kbuilder.hasErrors() ); KnowledgeBase kbase = KnowledgeBaseFactory.newKnowledgeBase(); kbase.addKnowledgePackages( kbuilder.getKnowledgePackages() ); return kbase; } public ObjectTypeNode getObjectTypeNode(KnowledgeBase kbase, Class<?> nodeClass) { List<ObjectTypeNode> nodes = ((KnowledgeBaseImpl)kbase).getRete().getObjectTypeNodes(); for ( ObjectTypeNode n : nodes ) { if ( ((ClassObjectType)n.getObjectType()).getClassType() == nodeClass ) { return n; } } return null; } }