/* * Copyright 2015 Red Hat, Inc. and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.drools.compiler.oopath; import org.kie.api.KieBase; import org.kie.api.io.ResourceType; import org.kie.api.runtime.KieSession; import org.kie.internal.utils.KieHelper; import java.util.ArrayList; import java.util.List; import java.util.Random; public class RecursiveQueryBenchmark { private static final String RELATIONAL_DRL = "import " + Node.class.getCanonicalName() + ";\n" + "import " + Edge.class.getCanonicalName() + ";\n" + "import " + List.class.getCanonicalName() + ";\n" + "query findNodesWithValue( int $id, int $value, List list )\n" + " $n: Node( id == $id, $v : value ) " + " eval( $v != $value || ( $v == $value && list.add( $n ) ) )\n" + " Edge( fromId == $id, $toId : toId ) " + " findNodesWithValue( $toId, $value, list; )\n" + "end\n"; private static final String RELATIONAL_DRL_OLD = "import " + Node.class.getCanonicalName() + ";\n" + "import " + Edge.class.getCanonicalName() + ";\n" + "query findNodesWithValue( int $fromId, int $toId, int $value )\n" + " ( Edge( fromId == $fromId, toId == $toId ) and Node( id == $toId, value == $value ) )\n" + " or\n" + " ( Edge( fromId == $fromId, $childId : toId ) and findNodesWithValue( $childId, $toId, $value; ) )\n" + "end\n" + "\n" + "rule R when\n" + " Node( root == true, $rootId : id )\n" + " accumulate( findNodeWithValue($rootId, $nodeId, 0;) ; $result : count($nodeId) )\n" + "then\n" + " System.out.println( $result );\n" + "end\n"; private static final String FROM_DRL = "import " + Node.class.getCanonicalName() + ";\n" + "import " + Edge.class.getCanonicalName() + ";\n" + "import " + List.class.getCanonicalName() + ";\n" + "query findNodesWithValue( Node $from, int $value, List list )\n" + " Edge( $n : to, $v : to.value ) from $from.outEdges\n" + " eval( $v != $value || ( $v == $value && list.add( $n ) ) )\n" + " findNodesWithValue( $n, $value, list; )\n" + "end\n"; private static final String FROM_DRL_OLD = "import " + Node.class.getCanonicalName() + ";\n" + "import " + Edge.class.getCanonicalName() + ";\n" + "query findNodesWithValue( Node $from, Node $to, int $value )\n" + " Edge( to.value == $value, $to := to ) from $from.outEdges\n" + " or\n" + " ( Edge( $child : to ) from $from.outEdges and findNodesWithValue( $child, $to, $value; ) )\n" + "end\n" + "\n" + "rule R when\n" + " $root: Node( root == true )\n" + " accumulate( findNodeWithValue($root, $node, 0;) ; $result : count($node) )\n" + "then\n" + " System.out.println( $result );\n" + "end\n"; private static final String XPATH_DRL = "import " + Node.class.getCanonicalName() + ";\n" + "import " + Edge.class.getCanonicalName() + ";\n" + "import " + List.class.getCanonicalName() + ";\n" + "query findNodesWithValue( Node $from, int $value, List list )\n" + " Node( id == $from.id, $n: /outEdges/to )\n" + " eval( $n.getValue() != $value || ( $n.getValue() == $value && list.add( $n ) ) )\n" + " findNodesWithValue( $n, $value, list; )\n" + "end\n"; private static final String XPATH_DRL_OLD = "import " + Node.class.getCanonicalName() + ";\n" + "import " + Edge.class.getCanonicalName() + ";\n" + "query findNodesWithValue( Node $from, Node $to, int $value )\n" + " Node( this == $from, $to := /outEdges/to[value == $value] )\n" + " or\n" + " ( Node( this == $from, $child : /outEdges/to ) and findNodesWithValue( $child, $to, $value; ) )\n" + "end\n" + "\n" + "rule R when\n" + " $root: Node( root == true )\n" + " accumulate( findNodeWithValue($root, $node, 0;) ; $result : count($node) )\n" + "then\n" + " System.out.println( $result );\n" + "end\n"; public static void main( String[] args ) { int n = 1000; for (int i = 0; i < 5; i++) { System.out.println( "-------------------------------------" ); System.out.println( "Running with " + n + " nodes" ); System.out.println( "Relational version" ); runTest( new RelationalTest(), n ); System.out.println( "From version" ); runTest( new FromTest(), n ); n *= 2; System.gc(); try { Thread.sleep( 5000L ); } catch (InterruptedException e) { throw new RuntimeException( e ); } System.gc(); } } private static void runTest(Test test, int n) { KieBase kbase = getKieBase(test.getDrl()); // warmup for (int i = 0; i < 3; i++) { test.runTest(kbase, n); System.gc(); } BenchmarkResult batch = new BenchmarkResult("Batch"); for (int i = 0; i < 10; i++) { long[] result = test.runTest(kbase, n); batch.accumulate(result[0]); System.gc(); } System.out.println(batch); } private static KieBase getKieBase(String drl) { return new KieHelper().addContent(drl, ResourceType.DRL).build(); } interface Test { long[] runTest(KieBase kbase, int n); String getDrl(); } private static class RelationalTest implements Test { @Override public long[] runTest(KieBase kbase, int n) { return execTest(kbase, n, true); } @Override public String getDrl() { return RELATIONAL_DRL; } } private static class FromTest implements Test { @Override public long[] runTest(KieBase kbase, int n) { return execTest( kbase, n, false ); } @Override public String getDrl() { return FROM_DRL; } } public static long[] execTest(KieBase kbase, int n, boolean isRelational) { KieSession ksession = kbase.newKieSession(); Node root = generateTree( ksession, n, isRelational ); List list = new ArrayList(); long start = System.nanoTime(); ksession.getQueryResults( "findNodesWithValue", isRelational ? root.getId() : root, 0, list ); ksession.fireAllRules(); long[] result = new long[]{ (System.nanoTime() - start) }; //System.out.println( list.size() ); ksession.dispose(); return result; } private static Node generateTree( KieSession ksession, int n, boolean insertAll ) { final Random RANDOM = new Random(0); Node root = new Node(1); root.setRoot( true ); ksession.insert( root ); List<Node> nodes = new ArrayList<Node>(n); for (int i = 0; i < n; i++) { Node node = new Node(i / 10); nodes.add( node ); } List<Node> nodesInTree = new ArrayList<Node>(n); nodesInTree.add( root ); while (!nodes.isEmpty()) { Node parent = nodesInTree.get( RANDOM.nextInt( nodesInTree.size() ) ); Node node = nodes.remove( RANDOM.nextInt( nodes.size() ) ); Edge edge = new Edge( parent, node ); parent.addOutEdge( edge ); nodesInTree.add( node ); if ( insertAll ) { ksession.insert( edge ); ksession.insert( node ); } } return root; } public static class Node { private static int ID_GENERATOR = 0; private final int id = ID_GENERATOR++; private final List<Edge> outEdges = new ArrayList<Edge>(); private final int value; private boolean root; public Node( int value ) { this.value = value; } public List<Edge> getOutEdges() { return outEdges; } public void addOutEdge( Edge edge) { outEdges.add(edge); } public int getValue() { return value; } public int getId() { return id; } @Override public boolean equals( Object o ) { if ( this == o ) return true; if ( o == null || getClass() != o.getClass() ) return false; Node node = (Node) o; return id == node.id; } @Override public int hashCode() { return id; } public boolean isRoot() { return root; } public void setRoot( boolean root ) { this.root = root; } @Override public String toString() { return "Node: " + id; } } public static class Edge { public final Node from; public final Node to; public Edge( Node from, Node to ) { this.from = from; this.to = to; } public Node getFrom() { return from; } public int getFromId() { return from.getId(); } public Node getTo() { return to; } public int getToId() { return to.getId(); } @Override public String toString() { return "Edge[" + getFromId() + ", " + getToId() + "]"; } } public static class BenchmarkResult { private final String name; private long min = Long.MAX_VALUE; private long max = 0; private long sum = 0; private int counter = 0; public BenchmarkResult(String name) { this.name = name; } public void accumulate(long result) { if (result < min) { min = result; } if (result > max) { max = result; } sum += result; counter++; } private long getAverage() { return (sum - min - max) / (counter - 2); } @Override public String toString() { return name + " results: min = " + min + "; max = " + max + "; avg = " + getAverage(); } } }