/*
* Copyright 2010 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.core.common;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.drools.core.impl.InternalKnowledgeBase;
import org.drools.core.reteoo.SegmentMemory;
import org.kie.internal.runtime.StatefulKnowledgeSession;
/**
 * A concurrent implementation of the {@link NodeMemories} interface.
 * <p>
 * Node memories are stored in an {@link AtomicReferenceArray} indexed by
 * memory id. Lookups of already-created memories do not take the lock: they
 * read the current backing array once (published through a {@code volatile}
 * field) and index into it. Creating a missing memory, growing the backing
 * array, clearing a single slot and replacing the whole array in
 * {@link #clear()} are serialized through a single {@link ReentrantLock}.
 */
public class ConcurrentNodeMemories implements NodeMemories {

    /**
     * Backing array indexed by memory id. The reference itself is replaced by
     * {@link #resize(MemoryFactory)} and {@link #clear()}, so it is volatile to
     * guarantee that unlocked readers observe the most recently published array.
     */
    private volatile AtomicReferenceArray<Memory> memories;

    /** Serializes memory creation, resizing and clearing. */
    private final Lock lock = new ReentrantLock();

    private final InternalKnowledgeBase kBase;

    private final String unitName;

    /**
     * Creates the container, sized to the memory count the knowledge base
     * reports for {@code unitName}.
     *
     * @param kBase    knowledge base supplying the memory count and configuration
     * @param unitName unit whose memory count sizes the backing array
     */
    public ConcurrentNodeMemories( InternalKnowledgeBase kBase, String unitName ) {
        this.kBase = kBase;
        this.unitName = unitName;
        this.memories = new AtomicReferenceArray<Memory>( this.kBase.getMemoryCount(unitName) );
    }

    /**
     * Removes the memory of the given node, if one has been created.
     * Runs under the lock so the cleared slot cannot be resurrected by a
     * concurrent {@link #resize(MemoryFactory)} copying the old array.
     *
     * @param node the node whose memory slot is cleared
     */
    public void clearNodeMemory( MemoryFactory node ) {
        this.lock.lock();
        try {
            int memoryId = node.getMemoryId();
            AtomicReferenceArray<Memory> current = this.memories;
            if ( memoryId < current.length() && current.get( memoryId ) != null ) {
                current.set( memoryId, null );
            }
        } finally {
            this.lock.unlock();
        }
    }

    /**
     * Drops all node memories by replacing the backing array with a fresh one
     * sized to the knowledge base's current memory count for this unit.
     * Runs under the lock so it cannot interleave with a resize.
     */
    public void clear() {
        this.lock.lock();
        try {
            this.memories = new AtomicReferenceArray<Memory>( this.kBase.getMemoryCount(unitName) );
        } finally {
            this.lock.unlock();
        }
    }

    /**
     * Resets every node memory created so far. Then, for each distinct segment
     * memory referenced by those node memories, resets it using the segment
     * prototype from the session's knowledge base and, if the segment is
     * linked afterwards, notifies the session through
     * {@code notifyRuleLinkSegment}.
     *
     * @param session the session owning these memories; its knowledge base
     *                supplies the segment prototypes. NOTE(review): it is cast
     *                to {@link InternalWorkingMemory}, so it is assumed to
     *                implement that interface.
     */
    public void resetAllMemories(StatefulKnowledgeSession session) {
        InternalKnowledgeBase sessionKBase = (InternalKnowledgeBase) session.getKieBase();
        // read the array once so the whole iteration sees a single consistent array
        AtomicReferenceArray<Memory> current = this.memories;
        Set<SegmentMemory> smems = new HashSet<SegmentMemory>();
        for ( int i = 0; i < current.length(); i++ ) {
            Memory memory = current.get( i );
            if ( memory != null ) {
                SegmentMemory smem = memory.getSegmentMemory();
                if ( smem != null ) {
                    smems.add( smem );
                }
                memory.reset();
            }
        }
        for ( SegmentMemory smem : smems ) {
            smem.reset( sessionKBase.getSegmentPrototype( smem ) );
            if ( smem.isSegmentLinked() ) {
                smem.notifyRuleLinkSegment( (InternalWorkingMemory) session );
            }
        }
    }

    /**
     * Returns the memory for the given node, creating it on first access.
     * <p>
     * The fast path runs outside the critical section. It reads the backing
     * array once, checks bounds and returns an existing memory. If the id is
     * out of range or the slot is empty, it falls back to
     * {@link #createNodeMemory}, which re-checks everything under the lock
     * before changing any data structure.
     *
     * @param node the node whose memory is requested
     * @param wm   working memory passed to the node's memory factory on creation
     * @return the (possibly newly created) memory for {@code node}
     */
    public Memory getNodeMemory(MemoryFactory node, InternalWorkingMemory wm) {
        int memoryId = node.getMemoryId();
        // single read of the volatile field: bounds check and get use the same array
        AtomicReferenceArray<Memory> current = this.memories;
        if ( memoryId < current.length() ) {
            Memory memory = current.get( memoryId );
            if ( memory != null ) {
                return memory;
            }
        }
        return createNodeMemory( node, wm );
    }

    /**
     * Under the lock, grows the backing array if needed and creates the
     * node's memory unless another thread already created it.
     *
     * @param node the node whose memory is created
     * @param wm   working memory passed to {@code node.createMemory}
     * @return the memory stored for {@code node}
     */
    private Memory createNodeMemory( MemoryFactory node,
                                     InternalWorkingMemory wm ) {
        this.lock.lock();
        try {
            int memoryId = node.getMemoryId();
            // re-check the bounds under the lock: the array may have been replaced
            // (e.g. by clear()) since the unlocked read in getNodeMemory
            if ( memoryId >= this.memories.length() ) {
                resize( node );
            }
            AtomicReferenceArray<Memory> current = this.memories;
            // re-check the slot: another thread may have created it meanwhile
            Memory memory = current.get( memoryId );
            if ( memory == null ) {
                memory = node.createMemory( this.kBase.getConfiguration(), wm );
                if ( !current.compareAndSet( memoryId, null, memory ) ) {
                    memory = current.get( memoryId );
                }
            }
            return memory;
        } finally {
            this.lock.unlock();
        }
    }

    /**
     * Grows the backing array so that it can hold {@code node}'s memory id.
     * The new size is the larger of the knowledge base's memory count and the
     * id plus a buffer of 32 slots, which reduces repeated copies. Existing
     * entries are copied over before the new array is published.
     *
     * @param node the node whose memory id must fit in the array
     */
    private void resize( MemoryFactory node ) {
        // reentrant: safe to call while createNodeMemory already holds the lock
        this.lock.lock();
        try {
            int memoryId = node.getMemoryId();
            AtomicReferenceArray<Memory> current = this.memories;
            if ( memoryId >= current.length() ) {
                int size = Math.max( this.kBase.getMemoryCount(unitName), memoryId + 32 );
                AtomicReferenceArray<Memory> newMem = new AtomicReferenceArray<Memory>( size );
                for ( int i = 0; i < current.length(); i++ ) {
                    newMem.set( i, current.get( i ) );
                }
                this.memories = newMem;
            }
        } finally {
            this.lock.unlock();
        }
    }

    /**
     * Returns the memory stored for {@code memoryId} without creating it.
     *
     * @param memoryId the memory id to look up
     * @return the memory, or {@code null} if none exists or the id is beyond
     *         the current array length
     */
    public Memory peekNodeMemory(int memoryId ) {
        AtomicReferenceArray<Memory> current = this.memories;
        return memoryId < current.length() ? current.get( memoryId ) : null;
    }

    /**
     * Returns the current capacity of the backing array. This is not the
     * number of created memories.
     *
     * @return the length of the current backing array
     */
    public int length() {
        return this.memories.length();
    }
}