package org.drools.examples.performance;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import org.kie.api.KieServices;
import org.kie.api.builder.KieBuilder;
import org.kie.api.builder.KieFileSystem;
import org.kie.api.builder.KieRepository;
import org.kie.api.builder.Message;
import org.kie.api.definition.type.FactType;
import org.kie.api.io.ResourceType;
import org.kie.api.runtime.KieContainer;
import org.kie.api.runtime.StatelessKieSession;
import org.kie.internal.KnowledgeBase;
import org.kie.internal.KnowledgeBaseFactory;
import org.kie.internal.builder.KnowledgeBuilder;
import org.kie.internal.builder.KnowledgeBuilderFactory;
import org.kie.internal.io.ResourceFactory;
import java.util.ArrayList;
public class PerformanceExample {
public static void main(final String[] args) throws Exception {
final long numberOfRulesToBuild = 10;
boolean useAccumulate = true;
String dialect = "mvel"; //noticed performance difference between java and mvel dialects
boolean usekjars = false;
boolean collectionBasedRules = true;
System.out.println("********* Numbers of rules: " + numberOfRulesToBuild + " kjars: " + usekjars + " accumulate: " + useAccumulate + " dialect: " + dialect + " *********");
String rules = getRules(numberOfRulesToBuild, useAccumulate, dialect, collectionBasedRules);
//System.out.println(rules);
long startTime = System.currentTimeMillis();
StatelessKieSession kSession;
FactType ft;
if (usekjars) {
KieContainer kContainer = loadContainerFromString(rules);
kSession = kContainer.newStatelessKieSession();
ft = kContainer.getKieBase().getFactType("org.drools.examples.performance", "TransactionC");
} else {
/* Alternative way to load knowledge base without using kjars.
Found slowness issue with internalInvalidateSegmentPrototype() when number of rules are increased.*/
KnowledgeBase kbase = loadKnowledgeBaseFromString( rules );
kSession = kbase.newStatelessKieSession();
ft = kbase.getFactType("org.drools.examples.performance", "TransactionC");
}
long endTime = System.currentTimeMillis();
System.out.println("Total time to build and load knowledgebase: " + (endTime - startTime) + " ms" );
ArrayList output = new ArrayList();
kSession.setGlobal("mo", output);
Object o = ft.newInstance();
Gson gConverter = new GsonBuilder().setDateFormat("yyyy-MM-dd'T'HH:mm:ss").create();
Object fo = gConverter.fromJson(getFact(), o.getClass());
kSession.execute(fo); //initial execute
startTime = System.currentTimeMillis();
kSession.execute(fo);
endTime = System.currentTimeMillis();
System.out.println("Execution time: " + (endTime - startTime) + " ms" );
String rulesOutput = gConverter.toJson(output);
System.out.println(rulesOutput);
}
private static KieContainer loadContainerFromString(String rules) {
long startTime = System.currentTimeMillis();
KieServices ks = KieServices.Factory.get();
KieRepository kr = ks.getRepository();
KieFileSystem kfs = ks.newKieFileSystem();
kfs.write("src/main/resources/examples/pertest.drl", rules);
KieBuilder kb = ks.newKieBuilder(kfs);
kb.buildAll();
if (kb.getResults().hasMessages(Message.Level.ERROR)) {
throw new RuntimeException("Build Errors:\n" + kb.getResults().toString());
}
long endTime = System.currentTimeMillis();
System.out.println("Time to build rules : " + (endTime - startTime) + " ms" );
startTime = System.currentTimeMillis();
KieContainer kContainer = ks.newKieContainer(kr.getDefaultReleaseId());
endTime = System.currentTimeMillis();
System.out.println("Time to load container: " + (endTime - startTime) + " ms" );
return kContainer;
}
protected static KnowledgeBase loadKnowledgeBaseFromString(String... drlContentStrings) {
long startTime = System.currentTimeMillis();
KnowledgeBuilder kbuilder = KnowledgeBuilderFactory.newKnowledgeBuilder();
for (String drlContentString : drlContentStrings) {
kbuilder.add(ResourceFactory.newByteArrayResource(drlContentString
.getBytes()), ResourceType.DRL);
}
if (kbuilder.hasErrors()) {
throw new RuntimeException("Build Errors:\n" + kbuilder.getErrors());
}
long endTime = System.currentTimeMillis();
System.out.println("Time to build rules: " + (endTime - startTime) + " ms" );
startTime = System.currentTimeMillis();
KnowledgeBase kbase = KnowledgeBaseFactory.newKnowledgeBase();
kbase.addKnowledgePackages(kbuilder.getKnowledgePackages());
endTime = System.currentTimeMillis();
System.out.println("Time to create knowledgebase: " + (endTime - startTime) + " ms" );
return kbase;
}
private static String getFact() {
return "{\n" +
"\"TransactionNumber\": \"88882\",\n" +
"\"TrackingID\": \"T001\",\n" +
"\"CurrencyCode\": \"USD\",\n" +
"\"TransactionNetTotal\" : 100.0,\n" +
"\"StoreCode\": \"D001\",\n" +
"\"CardNumber\": \"3614838386\",\n" +
"\"TransactionDetails\": [\n" +
"{\n" +
"\"Quantity\": 25,\n" +
"\"ItemNumber\": \"SKU1_0\",\n" +
"\"BrandID\": \"Nike\",\n" +
"\"SKU\": \"SKU1\",\n" +
"\"ProductCategoryCode\" : \"Clothing\"\n" +
"}]\n" +
"}";
}
private static String getRules(long numberofRules, boolean useAccumulate, String dialect, boolean collectionBasedRules) {
final long startTime = System.currentTimeMillis();
StringBuilder sb = new StringBuilder("package org.drools.examples.performance;\n");
sb.append(getImportStatements());
sb.append("global ArrayList<Outcome> mo;");
sb.append(getDeclareStatements());
for (long l =0; l <numberofRules; l++) {
sb.append(createRule(l, useAccumulate, dialect, collectionBasedRules));
}
//sb.append(createRules2("mvel"));
final long endTime = System.currentTimeMillis();
System.out.println("Time to generate: " + (endTime - startTime) + " ms");
return sb.toString();
}
private static String createRule(long number, boolean useAccumulate, String dialect, boolean collectionBasedRules) {
if (collectionBasedRules) {
return createCollectionRule( number, useAccumulate, dialect );
} else {
return createRule( number, useAccumulate, dialect );
}
}
private static String createRule(long number, boolean useAccumulate, String dialect) {
String s = "" +
"rule \"rule" + number + "\" \n";
if (!dialect.isEmpty()) {
s = s + "dialect \"" + dialect + "\"\n";
}
s = s + "when t : TransactionC(CurrencyCode == \"USD" + number + "\") \n";
if (useAccumulate) {
s = s + "accumulate($item:TransactionDetailsC() from t.TransactionDetails, $totQty: collectList($item.getQuantity()))\n";
}
s = s + "then \n" +
"mo.add(new Outcome(\"rule" + number + "\", t.getTransactionNumber()));\n" +
"end \n" ;
return s;
}
private static String createCollectionRule(long number, boolean useAccumulate, String dialect) {
long NumOfSKU = 10;
String sku = "";
String prefix = "";
for (long l =0; l <NumOfSKU; l++) {
sku += prefix + "\"SKU" + number + "_" + l + "\"";
prefix = ",";
}
String s = "" +
"rule \"rule" + number + "\" \n";
if (!dialect.isEmpty()) {
s = s + "dialect \"" + dialect + "\"\n";
}
s = s + "when t : TransactionC() \n" +
//"d: TransactionDetailsC(ItemNumber == \"SKU" + number + "\") from t.TransactionDetails \n";
"d: TransactionDetailsC(ItemNumber in (" + sku + ")) from t.TransactionDetails \n";
if (useAccumulate) {
s = s + "accumulate($item:TransactionDetailsC(ItemNumber in (" + sku + ")) from t.TransactionDetails, $totQty: collectList($item.getQuantity()))\n";
}
s = s + "then \n" +
"mo.add(new Outcome(\"rule" + number + "\", d.getBrandID()));\n" +
"end \n" ;
return s;
}
private static String createRules2(String dialect) {
return "" +
"rule \"r1\"\n" +
"dialect \"" + dialect + "\"\n" +
"when t : TransactionC(CurrencyCode == \"USD\") \n" +
"then \n" +
"mo.add(new Outcome(\"r1\" , t.getTransactionNumber()));\n" +
"end \n" +
"rule \"r2\"\n" +
"dialect \"" + dialect + "\"\n" +
"when t : TransactionC(CurrencyCode == \"USD\") \n" +
"then \n" +
"mo.add(new Outcome(\"r2\" , t.getTransactionNumber()));\n" +
"end \n" +
"rule \"r3\"\n" +
"dialect \"" + dialect + "\"\n" +
"when t : TransactionC(CurrencyCode == \"CAD\") \n" +
"then \n" +
"mo.add(new Outcome(\"r3\", t.getTransactionNumber()));\n" +
"end \n" +
"rule \"r4\"\n" +
"dialect \"" + dialect + "\"\n" +
"when t : TransactionC(CurrencyCode == \"USD\") \n" +
"then \n" +
"mo.add(new Outcome(\"r4\", t.getTransactionNumber()));\n" +
"end \n";
}
private static String getDeclareStatements() {
return "" +
"declare TransactionC \n" +
"CardNumber : String \n" +
"StoreCode : String \n" +
"TrackingID : String \n" +
"CurrencyCode : String \n" +
"TransactionNetTotal : Double \n" +
"TransactionNumber : String \n" +
"TransactionDetails : TransactionDetailsC[] \n" +
"end \n" +
"declare TransactionDetailsC \n" +
"ItemNumber : String \n" +
"BrandID : String \n" +
"SKU : String \n" +
"ProductCategoryCode : String \n" +
"Quantity : Double \n" +
"end\n" +
"declare Outcome \n" +
"RuleId : String \n" +
"OutcomeValue : String \n" +
"end \n";
}
private static String getImportStatements() {
return "import java.util.ArrayList \n" +
"import java.util.List \n";
}
}