/*
* Copyright 2017 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.core.ruleunit;
import org.drools.core.impl.RuleUnitExecutorSession;
import org.drools.core.spi.Activation;
import org.kie.api.runtime.rule.RuleUnit;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
/**
 * Bookkeeping for rule unit guards.
 *
 * <p>Each {@link Guard} is identified by the guarded unit and the rule of the activation that
 * registered it. Guards are indexed both by the activations that keep them alive and by the
 * rule unit that was current in the session when they were registered (or an internal root
 * placeholder when the session has no current unit). The latter index drives
 * {@link #fireActiveUnits(RuleUnit)}.
 */
public class RuleUnitGuardSystem {

    /** Stand-in key used whenever the session reports no current rule unit. */
    private static final RuleUnit ROOT_UNIT = new RuleUnit() { };

    private final RuleUnitExecutorSession session;

    /** Canonical instance for every registered guard (key and value are the same object). */
    private Map<Guard, Guard> guardMap = new HashMap<>();

    /** Guards that a given activation contributes to. */
    private Map<Activation, List<Guard>> guardsByActivation = new HashMap<>();

    /** Guards grouped by the unit that was current when they were registered. */
    private Map<RuleUnit, Set<Guard>> guardsByActivatingUnit = new HashMap<>();

    /**
     * @param session the session queried for the current rule unit and used to execute
     *                guarded units
     */
    public RuleUnitGuardSystem( RuleUnitExecutorSession session ) {
        this.session = session;
    }

    /**
     * Records that {@code activation} guards {@code ruleUnit}.
     *
     * <p>An equal guard that is already registered is reused; otherwise a new one is stored.
     * The guard is then linked to the activation and filed under the session's current unit.
     *
     * @param ruleUnit   the unit being guarded
     * @param activation the activation whose rule requested the guard
     */
    public void registerGuard( RuleUnit ruleUnit, Activation activation ) {
        Guard candidate = new Guard( ruleUnit, activation.getRule() );
        Guard alreadyKnown = guardMap.putIfAbsent( candidate, candidate );
        Guard canonical = alreadyKnown != null ? alreadyKnown : candidate;

        canonical.addActivation( activation );

        List<Guard> forActivation = guardsByActivation.computeIfAbsent( activation, key -> new ArrayList<>() );
        forActivation.add( canonical );

        Set<Guard> forUnit = guardsByActivatingUnit.computeIfAbsent( getCurrentRuleUnit(), key -> new HashSet<>() );
        forUnit.add( canonical );
    }

    /**
     * Detaches {@code activation} from every guard it was registered for.
     *
     * <p>Guards that are no longer active afterwards are dropped from the indexes; when the
     * activation is left without guards its own entry is removed too. Does nothing when the
     * activation is unknown.
     *
     * <p>NOTE(review): an inactive guard is unfiled from the unit that is current <em>now</em>,
     * while it was filed under the unit that was current at registration time — confirm these
     * always coincide.
     *
     * @param activation the activation being removed
     */
    public void removeActivation( Activation activation ) {
        List<Guard> affected = guardsByActivation.get( activation );
        if ( affected == null ) {
            return;
        }

        affected.removeIf( guard -> releaseFromGuard( guard, activation ) );

        if ( affected.isEmpty() ) {
            guardsByActivation.remove( activation );
        }
    }

    /**
     * Removes {@code activation} from {@code guard} and, if the guard is no longer active,
     * unregisters the guard from the canonical map and from the current unit's guard set.
     *
     * @return {@code true} when the guard became inactive and was unregistered
     */
    private boolean releaseFromGuard( Guard guard, Activation activation ) {
        guard.removeActivation( activation );
        if ( guard.isActive() ) {
            return false;
        }

        guardMap.remove( guard );

        RuleUnit currentUnit = getCurrentRuleUnit();
        Set<Guard> unitGuards = guardsByActivatingUnit.get( currentUnit );
        if ( unitGuards != null ) {
            unitGuards.remove( guard );
            if ( unitGuards.isEmpty() ) {
                // Do not keep empty sets around in the index.
                guardsByActivatingUnit.remove( currentUnit );
            }
        }
        return true;
    }

    /** Returns the session's current rule unit, or the root placeholder when there is none. */
    private RuleUnit getCurrentRuleUnit() {
        RuleUnit current = session.getCurrentRuleUnit();
        if ( current == null ) {
            return ROOT_UNIT;
        }
        return current;
    }

    /**
     * Executes the units guarded from the root placeholder unit.
     *
     * @return the sum of the values returned by the session for every executed unit
     */
    public int fireActiveUnits() {
        return fireActiveUnits( ROOT_UNIT );
    }

    /**
     * Executes the units guarded by guards filed under {@code ruleUnit}, then, recursively,
     * the units guarded under each executed unit. A unit is executed at most once per call.
     *
     * @param ruleUnit the activating unit to start from
     * @return the sum of the values returned by the session for every executed unit
     */
    public int fireActiveUnits( RuleUnit ruleUnit ) {
        Set<RuleUnit> alreadyFired = new HashSet<>();
        return fireActiveUnits( ruleUnit, alreadyFired );
    }

    /**
     * Recursive worker behind {@link #fireActiveUnits(RuleUnit)}.
     *
     * @param ruleUnit   the activating unit whose guards are inspected
     * @param firedUnits units executed so far in this traversal; updated in place
     * @return the sum of the session's results for the units executed by this call
     */
    private int fireActiveUnits( RuleUnit ruleUnit, Set<RuleUnit> firedUnits ) {
        Set<Guard> guards = guardsByActivatingUnit.get( ruleUnit );
        if ( guards == null ) {
            return 0;
        }

        int total = 0;
        // The guard set is scanned afresh on every pass rather than iterated once.
        for ( Optional<RuleUnit> next = firstUnfiredUnit( guards, firedUnits );
              next.isPresent();
              next = firstUnfiredUnit( guards, firedUnits ) ) {
            RuleUnit firingUnit = next.get();
            total += session.internalExecuteUnit( firingUnit );
            firedUnits.add( firingUnit );
            total += fireActiveUnits( firingUnit, firedUnits );
        }
        return total;
    }

    /**
     * Finds, in the iteration order of {@code guards}, the first guarded unit that is not in
     * {@code firedUnits}.
     */
    private Optional<RuleUnit> firstUnfiredUnit( Set<Guard> guards, Set<RuleUnit> firedUnits ) {
        for ( Guard guard : guards ) {
            RuleUnit guarded = guard.getGuardedUnit();
            if ( !firedUnits.contains( guarded ) ) {
                return Optional.of( guarded );
            }
        }
        return Optional.empty();
    }
}