/* * Copyright 2010 Red Hat, Inc. and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.drools.core.reteoo; import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicBoolean; import org.drools.core.RuleBaseConfiguration; import org.drools.core.common.InternalFactHandle; import org.drools.core.common.InternalWorkingMemory; import org.drools.core.common.Memory; import org.drools.core.common.MemoryFactory; import org.drools.core.common.WorkingMemoryAction; import org.drools.core.marshalling.impl.MarshallerReaderContext; import org.drools.core.marshalling.impl.MarshallerWriteContext; import org.drools.core.marshalling.impl.ProtobufMessages; import org.drools.core.phreak.PropagationEntry; import org.drools.core.reteoo.builder.BuildContext; import org.drools.core.spi.PropagationContext; import org.drools.core.util.bitmask.BitMask; import org.drools.core.util.bitmask.EmptyBitMask; /** * A node that will add the propagation to the working memory actions queue, * in order to allow multiple threads to concurrently assert objects to multiple * entry points. */ public class PropagationQueuingNode extends ObjectSource implements ObjectSinkNode, MemoryFactory<PropagationQueuingNode.PropagationQueueingNodeMemory> { private static final long serialVersionUID = 510l; // should we make this one configurable? private static final int PROPAGATION_SLICE_LIMIT = 1000; private ObjectSinkNode previousObjectSinkNode; private ObjectSinkNode nextObjectSinkNode; private PropagateAction action; public PropagationQueuingNode() { } /** * Construct a <code>PropagationQueuingNode</code> that will queue up * propagations until it the engine reaches a safe propagation point, * when all the queued facts are propagated. */ public PropagationQueuingNode(final int id, final ObjectSource objectSource, final BuildContext context) { super( id, context.getPartitionId(), context.getKnowledgeBase().getConfiguration().isMultithreadEvaluation(), objectSource, context.getKnowledgeBase().getConfiguration().getAlphaNodeHashingThreshold() ); this.action = new PropagateAction( this ); initDeclaredMask(context); hashcode = calculateHashCode(); } @Override public BitMask calculateDeclaredMask(List<String> settableProperties) { return EmptyBitMask.get(); } public void readExternal( ObjectInput in ) throws IOException, ClassNotFoundException { super.readExternal( in ); action = (PropagateAction) in.readObject(); } public void writeExternal( ObjectOutput out ) throws IOException { super.writeExternal( out ); out.writeObject( action ); } public short getType() { return NodeTypeEnums.PropagationQueuingNode; } public void updateSink( ObjectSink sink, PropagationContext context, InternalWorkingMemory workingMemory ) { final PropagationQueueingNodeMemory memory = workingMemory.getNodeMemory( this ); // this is just sanity code. We may remove it in the future, but keeping it for now. if ( !memory.isEmpty() ) { throw new RuntimeException( "Error updating sink. Not safe to update sink as the PropagatingQueueingNode memory is not empty at node: " + this.toString() ); } // as this node is simply a queue, ask object source to update the child sink directly this.source.updateSink( sink, context, workingMemory ); } public void attach( BuildContext context ) { this.source.addObjectSink( this ); // this node does not require update, so nothing else to do. } public ObjectSinkNode getNextObjectSinkNode() { return this.nextObjectSinkNode; } public ObjectSinkNode getPreviousObjectSinkNode() { return this.previousObjectSinkNode; } public void setNextObjectSinkNode( ObjectSinkNode next ) { this.nextObjectSinkNode = next; } public void setPreviousObjectSinkNode( ObjectSinkNode previous ) { this.previousObjectSinkNode = previous; } public boolean isObjectMemoryEnabled() { return true; } public void assertObject( InternalFactHandle factHandle, PropagationContext context, InternalWorkingMemory workingMemory ) { final PropagationQueueingNodeMemory memory = workingMemory.getNodeMemory( this ); memory.addAction( new AssertAction( factHandle, context ) ); // if not queued yet, we need to queue it up if ( memory.isQueued().compareAndSet( false, true ) ) { workingMemory.queueWorkingMemoryAction( this.action ); } } public void retractObject( InternalFactHandle handle, PropagationContext context, InternalWorkingMemory workingMemory ) { final PropagationQueueingNodeMemory memory = workingMemory.getNodeMemory( this ); memory.addAction( new RetractAction( handle, context ) ); // if not queued yet, we need to queue it up if ( memory.isQueued().compareAndSet( false, true ) ) { workingMemory.queueWorkingMemoryAction( this.action ); } } public void modifyObject(InternalFactHandle factHandle, ModifyPreviousTuples modifyPreviousTuples, PropagationContext context, InternalWorkingMemory workingMemory) { final PropagationQueueingNodeMemory memory = workingMemory.getNodeMemory( this ); for ( ObjectSink s : this.sink.getSinks() ) { BetaNode betaNode = (BetaNode) s; RightTuple rightTuple = modifyPreviousTuples.peekRightTuple(partitionId); while ( rightTuple != null && rightTuple.getInputOtnId().before( betaNode.getRightInputOtnId() ) ) { modifyPreviousTuples.removeRightTuple(partitionId); // we skipped this node, due to alpha hashing, so retract now rightTuple.retractTuple( context, workingMemory ); rightTuple = modifyPreviousTuples.peekRightTuple(partitionId); } if ( rightTuple != null && rightTuple.getInputOtnId().equals( betaNode.getRightInputOtnId() ) ) { modifyPreviousTuples.removeRightTuple(partitionId); rightTuple.reAdd(); if ( context.getModificationMask().intersects( betaNode.getRightInferredMask() ) ) { // RightTuple previously existed, so continue as modify memory.addAction( new ModifyToSinkAction( rightTuple, context, betaNode ) ); } } else { if ( context.getModificationMask().intersects( betaNode.getRightInferredMask() ) ) { // RightTuple does not exist for this node, so create and continue as assert memory.addAction( new AssertToSinkAction( factHandle, context, betaNode ) ); } } } // if not queued yet, we need to queue it up if ( memory.isQueued().compareAndSet( false, true ) ) { workingMemory.queueWorkingMemoryAction( this.action ); } } public void byPassModifyToBetaNode (final InternalFactHandle factHandle, final ModifyPreviousTuples modifyPreviousTuples, final PropagationContext context, final InternalWorkingMemory workingMemory) { modifyObject( factHandle, modifyPreviousTuples, context, workingMemory ); } /** * Propagate all queued actions (asserts and retracts). * <p/> * This method implementation is based on optimistic behavior to avoid the * use of locks. There may eventually be a minimum wasted effort, but overall * it will be better than paying for the lock's cost. */ public void propagateActions( InternalWorkingMemory workingMemory ) { final PropagationQueueingNodeMemory memory = workingMemory.getNodeMemory( this ); // first we clear up the action queued flag memory.isQueued().compareAndSet( true, false ); // we limit the propagation to avoid a hang when this queue is never empty Action next; for ( int counter = 0; counter < PROPAGATION_SLICE_LIMIT; counter++ ) { next = memory.getNextAction(); if ( next != null ) { next.execute( this.sink, workingMemory ); } else { break; } } if ( memory.hasNextAction() && memory.isQueued().compareAndSet( false, true ) ) { // add action to the queue again. workingMemory.queueWorkingMemoryAction( this.action ); } } public void setObjectMemoryEnabled( boolean objectMemoryOn ) { throw new UnsupportedOperationException( "PropagationQueueingNode must have its node memory enabled." ); } public PropagationQueueingNodeMemory createMemory(RuleBaseConfiguration config, InternalWorkingMemory wm) { return new PropagationQueueingNodeMemory(); } public int calculateHashCode() { return this.source.hashCode(); } @Override public boolean equals(final Object object) { return this == object || ( internalEquals( object ) && this.source.thisNodeEquals( ((PropagationQueuingNode)object).source ) ); } @Override protected boolean internalEquals( Object object ) { if ( object == null || !(object instanceof PropagationQueuingNode) || this.hashCode() != object.hashCode() ) { return false; } return true; } /** * Memory implementation for the node */ public static class PropagationQueueingNodeMemory implements Memory { private static final long serialVersionUID = 7372028632974484023L; private ConcurrentLinkedQueue<Action> queue; // "singleton" action - there is one of this for each node in each working memory private AtomicBoolean isQueued; public PropagationQueueingNodeMemory() { super(); this.queue = new ConcurrentLinkedQueue<Action>(); this.isQueued = new AtomicBoolean( false ); } public boolean isEmpty() { return this.queue.isEmpty(); } public void addAction( Action action ) { this.queue.add( action ); } public Action getNextAction() { return this.queue.poll(); } public boolean hasNextAction() { return this.queue.peek() != null; } public AtomicBoolean isQueued() { return isQueued; } public long getSize() { return this.queue.size(); } public short getNodeType() { return NodeTypeEnums.PropagationQueueingNode; } public Memory getPrevious() { throw new UnsupportedOperationException(); } public void setPrevious(Memory previous) { throw new UnsupportedOperationException(); } public void setNext(Memory next) { throw new UnsupportedOperationException(); } public Memory getNext() { throw new UnsupportedOperationException(); } public SegmentMemory getSegmentMemory() { return null; } public void setSegmentMemory(SegmentMemory segmentMemory) { throw new UnsupportedOperationException(); } public void nullPrevNext() { throw new UnsupportedOperationException(); } public void reset() { queue.clear(); isQueued.set(false); } } private static abstract class Action implements Externalizable { protected InternalFactHandle handle; protected PropagationContext context; public Action(InternalFactHandle handle, PropagationContext context) { super(); this.handle = handle; this.context = context; } public void readExternal( ObjectInput in ) throws IOException, ClassNotFoundException { handle = (InternalFactHandle) in.readObject(); context = (PropagationContext) in.readObject(); } public void writeExternal( ObjectOutput out ) throws IOException { out.writeObject( handle ); out.writeObject( context ); } public abstract void execute( final ObjectSinkPropagator sink, final InternalWorkingMemory workingMemory ); } private static class AssertAction extends Action { private static final long serialVersionUID = -8478488926430845209L; public AssertAction(final InternalFactHandle handle, final PropagationContext context) { super( handle, context ); } public void execute( final ObjectSinkPropagator sink, final InternalWorkingMemory workingMemory ) { sink.propagateAssertObject( this.handle, this.context, workingMemory ); context.evaluateActionQueue( workingMemory ); } } private static class AssertToSinkAction extends Action { private static final long serialVersionUID = -8478488926430845209L; private ObjectSink nodeSink; public AssertToSinkAction(final InternalFactHandle handle, final PropagationContext context, final ObjectSink sink) { super( handle, context ); nodeSink = sink; } public void execute( final ObjectSinkPropagator sink, final InternalWorkingMemory workingMemory ) { nodeSink.assertObject( this.handle, this.context, workingMemory ); context.evaluateActionQueue( workingMemory ); } @Override public void readExternal( ObjectInput in ) throws IOException, ClassNotFoundException { super.readExternal( in ); nodeSink = (ObjectSink) in.readObject(); } @Override public void writeExternal( ObjectOutput out ) throws IOException { super.writeExternal( out ); out.writeObject( nodeSink ); } } private static class RetractAction extends Action { private static final long serialVersionUID = -84784886430845209L; public RetractAction(final InternalFactHandle handle, final PropagationContext context) { super( handle, context ); } public void execute( final ObjectSinkPropagator sink, final InternalWorkingMemory workingMemory ) { this.handle.forEachRightTuple( rt -> rt.retractTuple( context, workingMemory ) ); this.handle.clearRightTuples(); this.handle.forEachLeftTuple( lt -> lt.retractTuple( context, workingMemory ) ); this.handle.clearLeftTuples(); context.evaluateActionQueue( workingMemory ); } } private static class ModifyToSinkAction extends Action { private static final long serialVersionUID = -8478488926430845209L; private RightTupleSink nodeSink; private RightTuple rightTuple; public ModifyToSinkAction(final RightTuple rightTuple, final PropagationContext context, final RightTupleSink nodeSink) { super( rightTuple.getFactHandle(), context ); this.nodeSink = nodeSink; this.rightTuple = rightTuple; } public void execute( final ObjectSinkPropagator sink, final InternalWorkingMemory workingMemory ) { nodeSink.modifyRightTuple( rightTuple, context, workingMemory ); context.evaluateActionQueue( workingMemory ); } @Override public void readExternal( ObjectInput in ) throws IOException, ClassNotFoundException { super.readExternal( in ); nodeSink = (RightTupleSink) in.readObject(); rightTuple = (RightTuple) in.readObject(); } @Override public void writeExternal( ObjectOutput out ) throws IOException { super.writeExternal( out ); out.writeObject( nodeSink ); out.writeObject( rightTuple ); } } /** * This is the action that is added to the working memory actions queue, so that * this node propagation can be triggered at a safe point */ public static class PropagateAction extends PropagationEntry.AbstractPropagationEntry implements WorkingMemoryAction { private static final long serialVersionUID = 6765029029501617115L; private PropagationQueuingNode node; public PropagateAction() { } public PropagateAction(PropagationQueuingNode node) { this.node = node; } public PropagateAction(MarshallerReaderContext context) throws IOException { this.node = (PropagationQueuingNode) context.sinks.get( context.readInt() ); } public PropagateAction(MarshallerReaderContext context, ProtobufMessages.ActionQueue.Action _action) { this.node = (PropagationQueuingNode) context.sinks.get( _action.getPropagate().getNodeId() ); } public ProtobufMessages.ActionQueue.Action serialize( MarshallerWriteContext context ) { return ProtobufMessages.ActionQueue.Action.newBuilder() .setType( ProtobufMessages.ActionQueue.ActionType.PROPAGATE ) .setPropagate( ProtobufMessages.ActionQueue.Propagate.newBuilder() .setNodeId( node.getId() ) .build() ) .build(); } public void execute( InternalWorkingMemory workingMemory ) { this.node.propagateActions( workingMemory ); } } }