/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.compiler.integrationtests;
import org.drools.core.InitialFact;
import org.drools.core.base.ClassObjectType;
import org.drools.core.common.BaseNode;
import org.drools.core.common.NetworkNode;
import org.drools.core.common.RuleBasePartitionId;
import org.drools.core.impl.InternalKnowledgeBase;
import org.drools.core.reteoo.BetaNode;
import org.drools.core.reteoo.CompositePartitionAwareObjectSinkAdapter;
import org.drools.core.reteoo.EntryPointNode;
import org.drools.core.reteoo.LeftTupleSource;
import org.drools.core.reteoo.ObjectSink;
import org.drools.core.reteoo.ObjectSinkPropagator;
import org.drools.core.reteoo.ObjectSource;
import org.drools.core.reteoo.ObjectTypeNode;
import org.drools.core.reteoo.Rete;
import org.drools.core.reteoo.Sink;
import org.drools.core.reteoo.TerminalNode;
import org.junit.Test;
import org.kie.api.io.ResourceType;
import org.kie.internal.conf.MultithreadEvaluationOption;
import org.kie.internal.utils.KieHelper;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
/**
 * Builds knowledge bases with {@link MultithreadEvaluationOption#YES} and walks the
 * resulting Rete network, asserting that each node's partition id is consistent with
 * the partition ids of the nodes feeding it.
 */
public class NodesPartitioningTest {

    /**
     * Compiles two instances of each of the four rule shapes ({@code Ra}, {@code Rb},
     * {@code Rc}, {@code Rd}), deliberately interleaved, and validates the network.
     */
    @Test
    public void test2Partitions() {
        String drl = ruleA( 1 ) + ruleB( 2 ) + ruleC( 2 ) + ruleD( 1 ) +
                     ruleD( 2 ) + ruleC( 1 ) + ruleA( 2 ) + ruleB( 1 );
        checkDrl( drl );
    }

    /**
     * Compiles three join rules followed by three {@code not} rules that use the same
     * {@code Integer( this == i )} pattern, and validates the network.
     */
    @Test
    public void testPartitioningWithSharedNodes() {
        StringBuilder sb = new StringBuilder( 400 );
        for ( int i = 1; i < 4; i++ ) {
            sb.append( getRule( i ) );
        }
        for ( int i = 1; i < 4; i++ ) {
            sb.append( getNotRule( i ) );
        }
        checkDrl( sb.toString() );
    }

    /**
     * Builds a knowledge base from the given DRL with multithread evaluation enabled
     * and checks every node reachable from each entry point.
     *
     * @param drl the rule source to compile
     */
    private void checkDrl( String drl ) {
        InternalKnowledgeBase kbase = (InternalKnowledgeBase) new KieHelper().addContent( drl, ResourceType.DRL )
                                                                             .build( MultithreadEvaluationOption.YES );
        Rete rete = kbase.getRete();
        for ( EntryPointNode entryPointNode : rete.getEntryPointNodes().values() ) {
            traverse( entryPointNode );
        }
    }

    /**
     * Checks the given node and then, depth first, every sink below it.
     * A node reachable through several parents is checked once per path.
     *
     * @param node the node to start from
     */
    private void traverse( BaseNode node ) {
        checkNode( node );
        Sink[] sinks = node.getSinks();
        if ( sinks != null ) {
            for ( Sink sink : sinks ) {
                // Only sinks that are themselves BaseNodes can be traversed further.
                if ( sink instanceof BaseNode ) {
                    traverse( (BaseNode) sink );
                }
            }
        }
    }

    /**
     * Asserts the partitioning rule that applies to the concrete type of the node:
     * <ul>
     * <li>entry point and object type nodes must be in the main partition;</li>
     * <li>any other object source must share the partition of its parent object source,
     *     unless that parent is an object type node;</li>
     * <li>a beta node must share the partition of its left input, and of its right input
     *     unless that right input is an object type node;</li>
     * <li>a terminal node must share the partition of its left input.</li>
     * </ul>
     * Nodes of any other type are not checked.
     *
     * @param node the node to verify
     */
    private void checkNode( NetworkNode node ) {
        if ( node instanceof EntryPointNode ) {
            assertSame( node + " must be in the main partition",
                        RuleBasePartitionId.MAIN_PARTITION, node.getPartitionId() );
        } else if ( node instanceof ObjectTypeNode ) {
            assertSame( node + " must be in the main partition",
                        RuleBasePartitionId.MAIN_PARTITION, node.getPartitionId() );
            checkPartitionedSinks( (ObjectTypeNode) node );
        } else if ( node instanceof ObjectSource ) {
            ObjectSource source = ( (ObjectSource) node ).getParentObjectSource();
            // Direct children of an object type node are exempt from matching their parent.
            if ( !( source instanceof ObjectTypeNode ) ) {
                assertSame( node + " must share the partition of its parent " + source,
                            source.getPartitionId(), node.getPartitionId() );
            }
        } else if ( node instanceof BetaNode ) {
            ObjectSource rightInput = ( (BetaNode) node ).getRightInput();
            // A right input coming straight from an object type node is exempt.
            if ( !( rightInput instanceof ObjectTypeNode ) ) {
                assertSame( node + " must share the partition of its right input " + rightInput,
                            rightInput.getPartitionId(), node.getPartitionId() );
            }
            LeftTupleSource leftInput = ( (BetaNode) node ).getLeftTupleSource();
            assertSame( node + " must share the partition of its left input " + leftInput,
                        leftInput.getPartitionId(), node.getPartitionId() );
        } else if ( node instanceof TerminalNode ) {
            LeftTupleSource leftInput = ( (TerminalNode) node ).getLeftTupleSource();
            assertSame( node + " must share the partition of its left input " + leftInput,
                        leftInput.getPartitionId(), node.getPartitionId() );
        }
    }

    /**
     * Asserts that each sink of the object type node is registered in the partitioned
     * propagator whose index equals the sink's partition id modulo the number of
     * propagators. The object type node for {@link InitialFact} is skipped.
     *
     * @param otn the object type node whose sink propagator is inspected
     */
    private void checkPartitionedSinks( ObjectTypeNode otn ) {
        if ( InitialFact.class.isAssignableFrom( ( (ClassObjectType) otn.getObjectType() ).getClassType() ) ) {
            return;
        }

        // The cast fails with a ClassCastException if the node uses any other propagator
        // type, so it also acts as a check (presumably intentional).
        CompositePartitionAwareObjectSinkAdapter sinkPropagator =
                (CompositePartitionAwareObjectSinkAdapter) otn.getObjectSinkPropagator();
        ObjectSinkPropagator[] propagators = sinkPropagator.getPartitionedPropagators();
        for ( int i = 0; i < propagators.length; i++ ) {
            for ( ObjectSink sink : propagators[i].getSinks() ) {
                assertEquals( sink + " on " + sink.getPartitionId() + " is expected to be on propagator " + i,
                              i, sink.getPartitionId().getId() % propagators.length );
            }
        }
    }

    /**
     * @param i value used in the rule name and in the first Integer constraint
     * @return DRL for rule {@code Ra<i>}: Integer, String joined on length, Integer joined on the String's length
     */
    private String ruleA( int i ) {
        return "rule Ra" + i + " when\n" +
               " $i : Integer( this == " + i + " )\n" +
               " $s : String( length == $i )\n" +
               " Integer( this == $s.length )\n" +
               "then end\n";
    }

    /**
     * @param i value used in the rule name and in the first Integer constraint
     * @return DRL for rule {@code Rb<i>}: like {@code Ra} but the String is joined on {@code $i.toString}
     */
    private String ruleB( int i ) {
        return "rule Rb" + i + " when\n" +
               " $i : Integer( this == " + i + " )\n" +
               " $s : String( this == $i.toString )\n" +
               " Integer( this == $s.length )\n" +
               "then end\n";
    }

    /**
     * @param i value used in the rule name and in the first Integer constraint
     * @return DRL for rule {@code Rc<i>}: like {@code Ra} but the last Integer is joined on {@code $i+1}
     */
    private String ruleC( int i ) {
        return "rule Rc" + i + " when\n" +
               " $i : Integer( this == " + i + " )\n" +
               " $s : String( length == $i )\n" +
               " Integer( this == $i+1 )\n" +
               "then end\n";
    }

    /**
     * @param i value used in the rule name and in the Integer constraint
     * @return DRL for rule {@code Rd<i>}: the first two patterns of {@code Ra} only
     */
    private String ruleD( int i ) {
        return "rule Rd" + i + " when\n" +
               " $i : Integer( this == " + i + " )\n" +
               " $s : String( length == $i )\n" +
               "then end\n";
    }

    /**
     * @param i value used in the rule name and in the Integer constraint
     * @return DRL for rule {@code R<i>}: an Integer joined with a String equal to {@code $i.toString}
     */
    private String getRule( int i ) {
        // Each pattern goes on its own line, as in the other rule generators.
        return "rule R" + i + " when\n" +
               " $i : Integer( this == " + i + " )\n" +
               " String( this == $i.toString )\n" +
               "then end\n";
    }

    /**
     * @param i value used in the rule name, the String literal and the negated Integer constraint
     * @return DRL for rule {@code Rnot<i>}: a String equal to {@code "<i>"} with no matching Integer
     */
    private String getNotRule( int i ) {
        // The line break before "then" keeps the keyword from being glued to the closing parenthesis.
        return "rule Rnot" + i + " when\n" +
               " String( this == \"" + i + "\" )\n" +
               " not Integer( this == " + i + " )\n" +
               "then end\n";
    }
}