/*
* Copyright 2015 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.compiler.integrationtests.incrementalcompilation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.drools.core.reteoo.ReteDumper;
import org.junit.Assert;
import org.kie.api.KieBase;
import org.kie.api.io.ResourceType;
import org.kie.api.runtime.KieSession;
import org.kie.api.runtime.rule.FactHandle;
import org.kie.internal.KnowledgeBase;
import org.kie.internal.KnowledgeBaseFactory;
import org.kie.internal.builder.KnowledgeBuilder;
import org.kie.internal.builder.KnowledgeBuilderFactory;
import org.kie.internal.io.ResourceFactory;
import org.kie.internal.runtime.StatefulKnowledgeSession;
public class TestContext {
private final Map<FactHandle, Object> actualSessionFacts = new HashMap<FactHandle, Object>();
private final List<TestOperation> executedOperations = new ArrayList<TestOperation>();
private final Map<String, Object> sessionGlobals;
private final String rulesPackageName;
private final List resultsList;
private StatefulKnowledgeSession session;
private boolean failFast = true;
private final List<String> errorMessages = new ArrayList<String>();
public TestContext(final String rulesPackageName, final Map<String, Object> sessionGlobals,
final List resultsList) {
this.rulesPackageName = rulesPackageName;
this.sessionGlobals = sessionGlobals;
this.resultsList = resultsList;
}
public TestContext(final String rulesPackageName, final Map<String, Object> sessionGlobals,
final List resultsList, final boolean failFast) {
this.rulesPackageName = rulesPackageName;
this.sessionGlobals = sessionGlobals;
this.resultsList = resultsList;
this.failFast = failFast;
}
public void executeTestOperations(final List<TestOperation> testOperations) {
for (TestOperation testOperation : testOperations) {
try {
executeTestOperation(testOperation);
} catch (Exception e) {
throw new RuntimeException(createTestFailMessage(testOperations, null, null), e);
}
}
}
public void executeTestOperation(final TestOperation testOperation) {
final TestOperationType testOperationType = testOperation.getType();
final Object testOperationParameter = testOperation.getParameter();
if (testOperationType != TestOperationType.CREATE_SESSION) {
checkSessionInitialized();
}
switch (testOperationType) {
case CREATE_SESSION:
createSession((String[]) testOperationParameter, false);
break;
case ADD_RULES:
addRules((String[]) testOperationParameter, false);
break;
case ADD_RULES_REINSERT_OLD:
addRules((String[]) testOperationParameter, true);
break;
case REMOVE_RULES:
removeRules((String[]) testOperationParameter);
break;
case FIRE_RULES:
session.fireAllRules();
break;
case INSERT_FACTS:
insertFacts((Object[]) testOperationParameter);
break;
case REMOVE_FACTS:
removeFacts((FactHandle[]) testOperationParameter);
break;
case CHECK_RESULTS:
checkResults((String[]) testOperationParameter);
break;
case DUMP_RETE:
ReteDumper.dumpRete((KieSession) session);
break;
default:
throw new IllegalArgumentException("Unsupported test operation: " + testOperationType + "!");
}
executedOperations.add(testOperation);
}
public void dumpRete() {
checkSessionInitialized();
ReteDumper.dumpRete((KieSession) session);
}
public Map<FactHandle, Object> getActualSessionFacts() {
return actualSessionFacts;
}
public Set<FactHandle> getActualSessionFactHandles() {
return actualSessionFacts.keySet();
}
public List<TestOperation> getExecutedOperations() {
return executedOperations;
}
public void clearExecutedOperations() {
executedOperations.clear();
}
public StatefulKnowledgeSession getSession() {
return session;
}
public Object getSessionGlobal(final String sessionGlobalName) {
return sessionGlobals.get(sessionGlobalName);
}
public boolean isFailFast() {
return failFast;
}
public void setFailFast(final boolean failFast) {
this.failFast = failFast;
}
public List<String> getErrorMessages() {
return errorMessages;
}
public void clearErrorMessages() {
errorMessages.clear();
}
private void checkSessionInitialized() {
if (session == null) {
throw new IllegalStateException("Session is not initialized! Please, initialize session first.");
}
}
private void addRules(final String[] drls, final boolean reuseKieBaseWhenAddingRules) {
for (String drl : drls) {
final KnowledgeBuilder kBuilder;
if (reuseKieBaseWhenAddingRules) {
kBuilder = createKnowledgeBuilder(session.getKieBase(), drl);
} else {
kBuilder = createKnowledgeBuilder(null, drl);
}
session.getKieBase().addKnowledgePackages(kBuilder.getKnowledgePackages());
}
}
private void removeRules(final String[] ruleNames) {
final KieBase kieBase = session.getKieBase();
for (String ruleName : ruleNames) {
kieBase.removeRule(rulesPackageName, ruleName);
}
}
private void insertFacts(final Object[] facts) {
for (Object fact: facts) {
actualSessionFacts.put(session.insert(fact), fact);
}
}
private void removeFacts(final FactHandle[] factHandles) {
for (FactHandle factHandle : factHandles) {
session.delete(factHandle);
actualSessionFacts.remove(factHandle);
}
}
private void checkResults(final String[] expectedResults) {
final Set<String> expectedResultsSet = new HashSet<String>();
expectedResultsSet.addAll(Arrays.asList(expectedResults));
if (((expectedResultsSet.size() > 0) && (resultsList.size() == 0))
|| !expectedResultsSet.containsAll(resultsList)
|| !resultsList.containsAll(expectedResultsSet)) {
if (failFast) {
Assert.fail(createTestFailMessage(executedOperations, expectedResultsSet, resultsList));
} else {
errorMessages.add(createTestFailMessage(executedOperations, expectedResultsSet, resultsList));
}
}
resultsList.clear();
}
private void createSession(final String[] drls, final boolean reuseKieBaseWhenAddingRules) {
if (session != null) {
actualSessionFacts.clear();
session.dispose();
}
session = buildSessionInSteps(drls, reuseKieBaseWhenAddingRules);
if (sessionGlobals != null) {
insertGlobals(session, sessionGlobals);
}
}
private StatefulKnowledgeSession buildSessionInSteps(final String[] drls,
final boolean reuseKieBaseWhenAddingRules) {
if (drls == null || drls.length == 0) {
return KnowledgeBaseFactory.newKnowledgeBase().newStatefulKnowledgeSession();
} else {
String drl = drls[0];
final KnowledgeBuilder kbuilder = createKnowledgeBuilder(null, drl);
final KnowledgeBase kbase = KnowledgeBaseFactory.newKnowledgeBase();
kbase.addKnowledgePackages(kbuilder.getKnowledgePackages());
final StatefulKnowledgeSession kSession = kbase.newStatefulKnowledgeSession();
for (int i = 1; i < drls.length; i++) {
drl = drls[i];
final KnowledgeBuilder kbuilder2;
if (reuseKieBaseWhenAddingRules) {
kbuilder2 = createKnowledgeBuilder(kSession.getKieBase(), drl);
} else {
kbuilder2 = createKnowledgeBuilder(null, drl);
}
kSession.getKieBase().addKnowledgePackages(kbuilder2.getKnowledgePackages());
}
return kSession;
}
}
private KnowledgeBuilder createKnowledgeBuilder(final KnowledgeBase kbase, final String drl) {
final KnowledgeBuilder kbuilder;
if (kbase == null) {
kbuilder = KnowledgeBuilderFactory.newKnowledgeBuilder();
} else {
kbuilder = KnowledgeBuilderFactory.newKnowledgeBuilder(kbase);
}
kbuilder.add(ResourceFactory.newByteArrayResource(drl.getBytes()), ResourceType.DRL);
if (kbuilder.hasErrors()) {
throw new RuntimeException("Knowledge contains errors: " + kbuilder.getErrors().toString());
}
return kbuilder;
}
private void insertGlobals(final StatefulKnowledgeSession session, final Map<String, Object> globals) {
for (Map.Entry<String, Object> globalEntry : globals.entrySet()) {
session.setGlobal(globalEntry.getKey(), globalEntry.getValue());
}
}
private String createTestFailMessage(final List<TestOperation> testOperations,
final Collection<String> expectedResults, final Collection<String> actualResults) {
final StringBuilder messageBuilder = new StringBuilder();
final String lineSeparator = System.getProperty("line.separator");
messageBuilder.append("Expected results are different than actual after operations:" + lineSeparator);
int index = 1;
for (TestOperation testOperation : testOperations) {
messageBuilder.append(index + ". " + testOperation.toString());
messageBuilder.append(lineSeparator);
index++;
}
if (expectedResults != null && actualResults != null) {
messageBuilder.append( "Expected results: " + lineSeparator + "[" );
for ( String expectedResult : expectedResults ) {
messageBuilder.append( expectedResult + " " );
}
messageBuilder.append( "]" + lineSeparator );
messageBuilder.append( "Actual results: " + lineSeparator + "[" );
for ( String actualResult : actualResults ) {
messageBuilder.append( actualResult + " " );
}
}
messageBuilder.append("]" + lineSeparator);
return messageBuilder.toString();
}
}