/* * Copyright 2005 Red Hat, Inc. and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.drools.core.reteoo; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Queue; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import org.drools.core.common.BaseNode; import org.drools.core.common.DroolsObjectInputStream; import org.drools.core.common.DroolsObjectOutputStream; import org.drools.core.common.InternalWorkingMemory; import org.drools.core.common.MemoryFactory; import org.drools.core.common.NetworkNode; import org.drools.core.definitions.rule.impl.RuleImpl; import org.drools.core.impl.InternalKnowledgeBase; import org.drools.core.phreak.AddRemoveRule; import org.drools.core.rule.InvalidPatternException; import org.drools.core.rule.WindowDeclaration; import org.kie.api.definition.rule.Rule; import static org.drools.core.impl.StatefulKnowledgeSessionImpl.DEFAULT_RULE_UNIT; /** * Builds the Rete-OO network for a <code>Package</code>. * */ public class ReteooBuilder implements Externalizable { // ------------------------------------------------------------ // Instance members // ------------------------------------------------------------ private static final long serialVersionUID = 510l; /** The RuleBase */ private transient InternalKnowledgeBase kBase; private Map<String, BaseNode[]> rules; private Map<String, BaseNode[]> queries; private Map<String, WindowNode> namedWindows; private transient RuleBuilder ruleBuilder; private IdGenerator idGenerator; // ------------------------------------------------------------ // Constructors // ------------------------------------------------------------ public ReteooBuilder() { } /** * Construct a <code>Builder</code> against an existing <code>Rete</code> * network. */ public ReteooBuilder( final InternalKnowledgeBase kBase ) { this.kBase = kBase; this.rules = new HashMap<String, BaseNode[]>(); this.queries = new HashMap<String, BaseNode[]>(); this.namedWindows = new HashMap<String, WindowNode>(); //Set to 1 as Rete node is set to 0 this.idGenerator = new IdGenerator(); this.ruleBuilder = kBase.getConfiguration().getComponentFactory().getRuleBuilderFactory().newRuleBuilder(); } // ------------------------------------------------------------ // Instance methods // ------------------------------------------------------------ /** * Add a <code>Rule</code> to the network. * * @param rule * The rule to add. * @throws InvalidPatternException */ public synchronized void addRule(final RuleImpl rule) throws InvalidPatternException { final List<TerminalNode> terminals = this.ruleBuilder.addRule( rule, this.kBase ); BaseNode[] nodes = terminals.toArray( new BaseNode[terminals.size()] ); this.rules.put( rule.getFullyQualifiedName(), nodes ); if (rule.isQuery()) { this.queries.put( rule.getName(), nodes ); } } public void addEntryPoint( String id ) { this.ruleBuilder.addEntryPoint( id, this.kBase ); } public synchronized void addNamedWindow( WindowDeclaration window ) { final WindowNode wnode = this.ruleBuilder.addWindowNode( window, this.kBase ); this.namedWindows.put( window.getName(), wnode ); } public WindowNode getWindowNode( String name ) { return this.namedWindows.get( name ); } public IdGenerator getIdGenerator() { return this.idGenerator; } public synchronized BaseNode[] getTerminalNodes(final RuleImpl rule) { return getTerminalNodes( rule.getFullyQualifiedName() ); } public synchronized BaseNode[] getTerminalNodes(final String ruleName) { return this.rules.get( ruleName ); } public synchronized BaseNode[] getTerminalNodesForQuery(final String ruleName) { BaseNode[] nodes = this.queries.get( ruleName ); return nodes != null ? nodes : getTerminalNodes(ruleName); } public synchronized Map<String, BaseNode[]> getTerminalNodes() { return this.rules; } public synchronized void removeRules(List<RuleImpl> rulesToBeRemoved) { // reset working memories for potential propagation InternalWorkingMemory[] workingMemories = this.kBase.getWorkingMemories(); for (RuleImpl rule : rulesToBeRemoved) { if (rule.hasChildren() && !rulesToBeRemoved.containsAll( rule.getChildren() )) { throw new RuntimeException("Cannot remove parent rule " + rule + " without having removed all its chikdren"); } final RuleRemovalContext context = new RuleRemovalContext( rule ); context.setKnowledgeBase( kBase ); for ( BaseNode node : rules.remove( rule.getFullyQualifiedName() ) ) { removeTerminalNode( context, (TerminalNode) node, workingMemories ); } if ( rule.isQuery() ) { this.queries.remove( rule.getName() ); } if (rule.getParent() != null && !rulesToBeRemoved.contains( rule.getParent() )) { rule.getParent().removeChild( rule ); } } } public void removeTerminalNode(RuleRemovalContext context, TerminalNode tn, InternalWorkingMemory[] workingMemories) { AddRemoveRule.removeRule( tn, workingMemories, kBase ); BaseNode node = (BaseNode) tn; removeNodeAssociation(node, context.getRule()); resetMasks(removeNodes((AbstractTerminalNode)tn, workingMemories, context)); } private Collection<BaseNode> removeNodes(AbstractTerminalNode terminalNode, InternalWorkingMemory[] wms, RuleRemovalContext context) { Map<Integer, BaseNode> stillInUse = new HashMap<Integer, BaseNode>(); Collection<ObjectSource> alphas = new HashSet<ObjectSource>(); removePath(wms, context, stillInUse, alphas, terminalNode); Set<Integer> removedNodes = new HashSet<Integer>(); for (ObjectSource alpha : alphas) { removeObjectSource( wms, stillInUse, removedNodes, alpha, context ); } return stillInUse.values(); } /** * Path's must be removed starting from the outer most path, iterating towards the inner most path. * Each time it reaches a subnetwork beta node, the current path evaluation ends, and instead the subnetwork * path continues. */ private void removePath( InternalWorkingMemory[] wms, RuleRemovalContext context, Map<Integer, BaseNode> stillInUse, Collection<ObjectSource> alphas, PathEndNode endNode ) { LeftTupleNode[] nodes = endNode.getPathNodes(); for (int i = endNode.getPositionInPath(); i >= 0; i--) { BaseNode node = (BaseNode) nodes[i]; boolean removed = false; if ( NodeTypeEnums.isLeftTupleNode( node ) ) { removed = removeLeftTupleNode(wms, context, stillInUse, node); } if ( removed ) { // reteoo requires to call remove on the OTN for tuples cleanup if (NodeTypeEnums.isBetaNode(node) && !((BetaNode) node).isRightInputIsRiaNode()) { alphas.add(((BetaNode) node).getRightInput()); } else if (node.getType() == NodeTypeEnums.LeftInputAdapterNode) { alphas.add(((LeftInputAdapterNode) node).getObjectSource()); } } if (NodeTypeEnums.isBetaNode(node) && ((BetaNode) node).isRightInputIsRiaNode()) { endNode = (PathEndNode) ((BetaNode) node).getRightInput(); removePath(wms, context, stillInUse, alphas, endNode); return; } } } private boolean removeLeftTupleNode(InternalWorkingMemory[] wms, RuleRemovalContext context, Map<Integer, BaseNode> stillInUse, BaseNode node) { boolean removed; removed = node.remove(context, this, wms); if (removed) { stillInUse.remove( node.getId() ); // phreak must clear node memories, although this should ideally be pushed into AddRemoveRule for (InternalWorkingMemory workingMemory : wms) { workingMemory.clearNodeMemory((MemoryFactory) node); } } else { stillInUse.put( node.getId(), node ); } return removed; } private void removeObjectSource(InternalWorkingMemory[] wms, Map<Integer, BaseNode> stillInUse, Set<Integer> removedNodes, ObjectSource node, RuleRemovalContext context ) { if (removedNodes.contains( node.getId() )) { return; } ObjectSource parent = node.getParentObjectSource(); boolean removed = node.remove( context, this, wms ); if ( !removed ) { stillInUse.put( node.getId(), node ); } else { stillInUse.remove(node.getId()); removedNodes.add(node.getId()); if ( node.getType() != NodeTypeEnums.ObjectTypeNode && node.getType() != NodeTypeEnums.AlphaNode ) { // phreak must clear node memories, although this should ideally be pushed into AddRemoveRule for (InternalWorkingMemory workingMemory : wms) { workingMemory.clearNodeMemory( (MemoryFactory) node); } } if (parent != null && parent.getType() != NodeTypeEnums.EntryPointNode) { removeObjectSource(wms, stillInUse, removedNodes, parent, context); } } } private void removeNodeAssociation(BaseNode node, Rule rule) { if (node == null || !node.removeAssociation( rule )) { return; } if (node instanceof LeftTupleNode) { removeNodeAssociation( ((LeftTupleNode)node).getLeftTupleSource(), rule ); } if ( NodeTypeEnums.isBetaNode( node ) ) { removeNodeAssociation( ((BetaNode) node).getRightInput(), rule ); } else if ( node.getType() == NodeTypeEnums.LeftInputAdapterNode ) { removeNodeAssociation( ((LeftInputAdapterNode) node).getObjectSource(), rule ); } else if ( node.getType() == NodeTypeEnums.AlphaNode ) { removeNodeAssociation( ((AlphaNode) node).getParentObjectSource(), rule ); } } private void resetMasks(Collection<BaseNode> nodes) { NodeSet leafSet = new NodeSet(); for ( BaseNode node : nodes ) { if ( node.getType() == NodeTypeEnums.AlphaNode ) { ObjectSource source = (AlphaNode) node; while ( true ) { source.resetInferredMask(); BaseNode parent = source.getParentObjectSource(); if (parent.getType() != NodeTypeEnums.AlphaNode) { break; } source = (ObjectSource)parent; } updateLeafSet(source, leafSet ); } else if( NodeTypeEnums.isBetaNode( node ) ) { BetaNode betaNode = ( BetaNode ) node; if ( betaNode.isInUse() ) { leafSet.add( betaNode ); } } else if ( NodeTypeEnums.isTerminalNode( node ) ) { RuleTerminalNode rtNode = ( RuleTerminalNode ) node; if ( rtNode.isInUse() ) { leafSet.add( rtNode ); } } } for ( BaseNode node : leafSet ) { if ( NodeTypeEnums.isTerminalNode( node ) ) { ((TerminalNode)node).initInferredMask(); } else { // else node instanceof BetaNode ((BetaNode)node).initInferredMask(); } } } private void updateLeafSet(BaseNode baseNode, NodeSet leafSet) { if ( baseNode.getType() == NodeTypeEnums.AlphaNode ) { for ( ObjectSink sink : ((AlphaNode) baseNode).getObjectSinkPropagator().getSinks() ) { if ( ((BaseNode)sink).isInUse() ) { updateLeafSet( ( BaseNode ) sink, leafSet ); } } } else if ( baseNode.getType() == NodeTypeEnums.LeftInputAdapterNode ) { for ( LeftTupleSink sink : ((LeftInputAdapterNode) baseNode).getSinkPropagator().getSinks() ) { if ( sink.getType() == NodeTypeEnums.RuleTerminalNode ) { leafSet.add( (BaseNode) sink ); } else if ( ((BaseNode)sink).isInUse() ) { updateLeafSet( ( BaseNode ) sink, leafSet ); } } } else if ( baseNode.getType() == NodeTypeEnums.EvalConditionNode ) { for ( LeftTupleSink sink : ((EvalConditionNode) baseNode).getSinkPropagator().getSinks() ) { if ( ((BaseNode)sink).isInUse() ) { updateLeafSet( ( BaseNode ) sink, leafSet ); } } } else if ( NodeTypeEnums.isBetaNode( baseNode ) ) { if ( baseNode.isInUse() ) { leafSet.add( baseNode ); } } } public static class IdGenerator implements Externalizable { private static final String DEFAULT_TOPIC = "DEFAULT_TOPIC"; private Map<String, InternalIdGenerator> generators = new ConcurrentHashMap<>(); public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { generators = (Map<String, InternalIdGenerator>) in.readObject(); } public void writeExternal(ObjectOutput out) throws IOException { out.writeObject( generators ); } public int getNextId() { return getNextId( DEFAULT_TOPIC ); } public int getNextId(String topic) { return generators.computeIfAbsent( topic, key -> new InternalIdGenerator( 1 ) ).getNextId(); } public synchronized void releaseId( RuleImpl rule, NetworkNode node ) { generators.get( DEFAULT_TOPIC ).releaseId( node.getId() ); if (node instanceof MemoryFactory) { String unit = rule != null && rule.getRuleUnitClassName() != null ? rule.getRuleUnitClassName() : DEFAULT_RULE_UNIT; generators.get( unit ).releaseId( ( (MemoryFactory) node ).getMemoryId() ); } } public int getLastId() { return getLastId( DEFAULT_TOPIC ); } public int getLastId(String topic) { InternalIdGenerator gen = generators.get( topic ); return gen != null ? gen.getLastId() : 0; } } private static class InternalIdGenerator implements Externalizable { private static final long serialVersionUID = 510l; private Queue<Integer> recycledIds; private int nextId; public InternalIdGenerator() { } public InternalIdGenerator(final int firstId) { this.nextId = firstId; this.recycledIds = new LinkedList<Integer>(); } @SuppressWarnings("unchecked") public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { recycledIds = (Queue<Integer>) in.readObject(); nextId = in.readInt(); } public void writeExternal(ObjectOutput out) throws IOException { out.writeObject( recycledIds ); out.writeInt( nextId ); } public synchronized int getNextId() { Integer id = this.recycledIds.poll(); return ( id == null ) ? this.nextId++ : id; } public synchronized void releaseId(int id) { this.recycledIds.add( id ); } public int getLastId() { return this.nextId - 1; } } public void writeExternal(ObjectOutput out) throws IOException { boolean isDrools = out instanceof DroolsObjectOutputStream; DroolsObjectOutputStream droolsStream; ByteArrayOutputStream bytes; if ( isDrools ) { bytes = null; droolsStream = (DroolsObjectOutputStream) out; } else { bytes = new ByteArrayOutputStream(); droolsStream = new DroolsObjectOutputStream( bytes ); } droolsStream.writeObject( rules ); droolsStream.writeObject( queries ); droolsStream.writeObject( namedWindows ); droolsStream.writeObject( idGenerator ); if ( !isDrools ) { droolsStream.flush(); droolsStream.close(); bytes.close(); out.writeInt( bytes.size() ); out.writeObject( bytes.toByteArray() ); } } public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { boolean isDrools = in instanceof DroolsObjectInputStream; DroolsObjectInputStream droolsStream; ByteArrayInputStream bytes; if ( isDrools ) { bytes = null; droolsStream = (DroolsObjectInputStream) in; } else { bytes = new ByteArrayInputStream( (byte[]) in.readObject() ); droolsStream = new DroolsObjectInputStream( bytes ); } this.rules = (Map<String, BaseNode[]>) droolsStream.readObject(); this.queries = (Map<String, BaseNode[]>) droolsStream.readObject(); this.namedWindows = (Map<String, WindowNode>) droolsStream.readObject(); this.idGenerator = (IdGenerator) droolsStream.readObject(); if ( !isDrools ) { droolsStream.close(); bytes.close(); } } public void setRuleBase( InternalKnowledgeBase kBase ) { this.kBase = kBase; this.ruleBuilder = kBase.getConfiguration().getComponentFactory().getRuleBuilderFactory().newRuleBuilder(); } }