/*
* Copyright 2015 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package org.drools.core.reteoo;
import org.drools.core.common.BaseNode;
import org.drools.core.common.NetworkNode;
import org.drools.core.impl.InternalKnowledgeBase;
import org.kie.api.KieBase;
import org.kie.api.runtime.KieSession;
import org.kie.internal.KnowledgeBase;
import org.kie.internal.runtime.KnowledgeRuntime;
import java.util.Arrays;
import java.util.Comparator;
/**
 * Utility for checking that two Rete networks are structurally equal.
 *
 * <p>Starting from every entry point node, the networks are walked in parallel. Each pair of
 * corresponding nodes must be {@code equals}. Each pair must also have the same number of sinks.
 * Sinks are paired up after sorting both sides by node id, so corresponding nodes are expected to
 * have matching id order in both networks.
 *
 * <p>The {@code compare} methods signal a difference by throwing a {@link RuntimeException}
 * whose message describes the first mismatch found. {@link #areEqual(KieBase, KieBase)} turns
 * that into a boolean result.
 */
public class ReteComparator {

    private ReteComparator() { }

    /**
     * Returns whether the Rete networks of two knowledge bases are structurally equal.
     *
     * <p>On a mismatch, or on any other exception raised during the comparison, the exception
     * message is printed to {@code System.out} and {@code false} is returned.
     *
     * @param kbase1 first knowledge base; expected to be an {@link InternalKnowledgeBase}
     * @param kbase2 second knowledge base; expected to be an {@link InternalKnowledgeBase}
     * @return {@code true} if no difference was found, {@code false} otherwise
     */
    public static boolean areEqual(KieBase kbase1, KieBase kbase2) {
        try {
            compare( kbase1, kbase2 );
            return true;
        } catch (Exception e) {
            System.out.println(e.getMessage());
            return false;
        }
    }

    /**
     * Compares the Rete networks of two knowledge bases.
     *
     * @throws ClassCastException if either argument is not an {@link InternalKnowledgeBase}
     * @throws RuntimeException if the networks differ
     */
    public static void compare(KieBase kbase1, KieBase kbase2) {
        compare( (InternalKnowledgeBase) kbase1, (InternalKnowledgeBase) kbase2 );
    }

    /**
     * Compares the Rete networks of two knowledge bases.
     *
     * @throws ClassCastException if either argument is not an {@link InternalKnowledgeBase}
     * @throws RuntimeException if the networks differ
     */
    public static void compare(KnowledgeBase kbase1, KnowledgeBase kbase2) {
        compare( (InternalKnowledgeBase) kbase1, (InternalKnowledgeBase) kbase2 );
    }

    /**
     * Compares the Rete networks of the knowledge bases backing two runtimes.
     *
     * @throws ClassCastException if a runtime's knowledge base is not an {@link InternalKnowledgeBase}
     * @throws RuntimeException if the networks differ
     */
    public static void compare(KnowledgeRuntime session1, KnowledgeRuntime session2) {
        compare( (InternalKnowledgeBase) session1.getKieBase(), (InternalKnowledgeBase) session2.getKieBase() );
    }

    /**
     * Compares the Rete networks of the knowledge bases backing two sessions.
     *
     * @throws ClassCastException if a session's knowledge base is not an {@link InternalKnowledgeBase}
     * @throws RuntimeException if the networks differ
     */
    public static void compare(KieSession session1, KieSession session2) {
        compare( (InternalKnowledgeBase) session1.getKieBase(), (InternalKnowledgeBase) session2.getKieBase() );
    }

    /**
     * Compares the Rete networks of two internal knowledge bases.
     *
     * @throws RuntimeException if the networks differ
     */
    public static void compare(InternalKnowledgeBase kBase1, InternalKnowledgeBase kBase2) {
        compare( kBase1.getRete(), kBase2.getRete() );
    }

    /**
     * Compares two Rete networks entry point by entry point.
     *
     * <p>Both networks must declare the same number of entry points. Every entry point of
     * {@code rete1} must exist in {@code rete2`, and the node graphs reachable from each pair of
     * entry point nodes must be equal.
     *
     * @throws RuntimeException describing the first difference found
     */
    public static void compare(Rete rete1, Rete rete2) {
        int entryPoints1 = rete1.getEntryPointNodes().size();
        int entryPoints2 = rete2.getEntryPointNodes().size();
        // Without this check, entry points that exist only in rete2 would never be visited.
        if (entryPoints1 != entryPoints2) {
            throw new RuntimeException( "First rete has " + entryPoints1 + " entry points while second rete has " + entryPoints2 );
        }
        for (EntryPointNode epn1 : rete1.getEntryPointNodes().values()) {
            EntryPointNode epn2 = rete2.getEntryPointNode( epn1.getEntryPoint() );
            if (epn2 == null) {
                throw new RuntimeException( "Entry point " + epn1.getEntryPoint() + " of " + epn1 + " is missing from the second rete" );
            }
            compareNodes( epn1, epn2 );
        }
    }

    /**
     * Recursively compares two nodes and the sub-networks below them.
     *
     * @throws RuntimeException if the nodes, their sink counts, or any descendant pair differ
     */
    private static void compareNodes(BaseNode node1, BaseNode node2) {
        if (!node1.equals( node2 )) {
            throw new RuntimeException( node1 + " and " + node2 + " are not equal as expected" );
        }
        Sink[] sinks1 = node1.getSinks();
        Sink[] sinks2 = node2.getSinks();
        if (sinks1 == null) {
            if (sinks2 == null) {
                return;
            }
            throw new RuntimeException( node1 + " has no sinks while " + node2 + " has " + sinks2.length + " sinks" );
        }
        if (sinks2 == null) {
            throw new RuntimeException( node1 + " has " + sinks1.length + " sinks while " + node2 + " has 0 sinks" );
        }
        if (sinks1.length != sinks2.length) {
            throw new RuntimeException( node1 + " has " + sinks1.length + " sinks while " + node2 + " has " + sinks2.length + " sinks" );
        }
        // getSinks() may hand back a node's internal array; sort copies so the comparison never
        // reorders the networks being inspected.
        Sink[] sorted1 = sinks1.clone();
        Sink[] sorted2 = sinks2.clone();
        Arrays.sort( sorted1, NODE_SORTER );
        Arrays.sort( sorted2, NODE_SORTER );
        for (int i = 0; i < sorted1.length; i++) {
            boolean isNode1 = sorted1[i] instanceof BaseNode;
            boolean isNode2 = sorted2[i] instanceof BaseNode;
            // Report a kind mismatch explicitly instead of failing with a bare ClassCastException
            // (or silently skipping it when only the second sink is a BaseNode).
            if (isNode1 != isNode2) {
                throw new RuntimeException( sorted1[i] + " and " + sorted2[i] + " are not the same kind of sink" );
            }
            if (isNode1) {
                compareNodes( (BaseNode) sorted1[i], (BaseNode) sorted2[i] );
            }
        }
    }

    /** Shared comparator used to order sinks by node id before pairing them up. */
    public static final NetworkNodeComparator NODE_SORTER = new NetworkNodeComparator();

    /** Orders network nodes by ascending {@link NetworkNode#getId()}. */
    public static class NetworkNodeComparator implements Comparator<NetworkNode> {
        @Override
        public int compare( NetworkNode n1, NetworkNode n2 ) {
            // Integer.compare instead of subtraction: a difference of two ints can overflow and
            // invert the sign, which violates the Comparator contract.
            return Integer.compare( n1.getId(), n2.getId() );
        }
    }
}