/* * Copyright 2015 Red Hat, Inc. and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.drools.core.rule; import org.drools.core.WorkingMemory; import org.drools.core.base.accumulators.MVELAccumulatorFunctionExecutor; import org.drools.core.common.InternalFactHandle; import org.drools.core.spi.Accumulator; import org.drools.core.spi.CompiledInvoker; import org.drools.core.spi.Tuple; import org.drools.core.spi.Wireable; import org.kie.internal.security.KiePolicyHelper; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.io.Serializable; import java.util.Arrays; public class MultiAccumulate extends Accumulate { private Accumulator[] accumulators; public MultiAccumulate() { } public MultiAccumulate(final RuleConditionElement source, final Declaration[] requiredDeclarations, final Accumulator[] accumulators ) { super(source, requiredDeclarations); this.accumulators = accumulators; } public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { super.readExternal(in); this.accumulators = new Accumulator[in.readInt()]; for ( int i = 0; i < this.accumulators.length; i++ ) { this.accumulators[i] = (Accumulator) in.readObject(); } } public void writeExternal(ObjectOutput out) throws IOException { super.writeExternal(out); out.writeInt( accumulators.length ); for ( Accumulator acc : accumulators ) { if ( acc instanceof CompiledInvoker) { out.writeObject( null ); } else { out.writeObject( acc ); } } } public boolean isMultiFunction() { return true; } public Accumulator[] getAccumulators() { return this.accumulators; } public Serializable[] createContext() { Serializable[] ctxs = new Serializable[this.accumulators.length]; for ( int i = 0; i < ctxs.length; i++ ) { ctxs[i] = this.accumulators[i].createContext(); } return ctxs; } public void init(final Object workingMemoryContext, final Object context, final Tuple leftTuple, final WorkingMemory workingMemory) { try { for ( int i = 0; i < this.accumulators.length; i++ ) { this.accumulators[i].init( ((Object[])workingMemoryContext)[i], ((Object[])context)[i], leftTuple, this.requiredDeclarations, workingMemory ); } } catch ( final Exception e ) { throw new RuntimeException( e ); } } public void accumulate(final Object workingMemoryContext, final Object context, final Tuple leftTuple, final InternalFactHandle handle, final WorkingMemory workingMemory) { try { for ( int i = 0; i < this.accumulators.length; i++ ) { this.accumulators[i].accumulate( ((Object[])workingMemoryContext)[i], ((Object[])context)[i], leftTuple, handle, this.requiredDeclarations, getInnerDeclarationCache(), workingMemory ); } } catch ( final Exception e ) { throw new RuntimeException( e ); } } public void reverse(final Object workingMemoryContext, final Object context, final Tuple leftTuple, final InternalFactHandle handle, final WorkingMemory workingMemory) { try { for ( int i = 0; i < this.accumulators.length; i++ ) { this.accumulators[i].reverse( ((Object[])workingMemoryContext)[i], ((Object[])context)[i], leftTuple, handle, this.requiredDeclarations, getInnerDeclarationCache(), workingMemory ); } } catch ( final Exception e ) { throw new RuntimeException( e ); } } public boolean supportsReverse() { boolean supports = true; for( Accumulator acc : this.accumulators ) { if( ! acc.supportsReverse() ) { supports = false; break; } } return supports; } public Object[] getResult(final Object workingMemoryContext, final Object context, final Tuple leftTuple, final WorkingMemory workingMemory) { try { Object[] results = new Object[this.accumulators.length]; for ( int i = 0; i < this.accumulators.length; i++ ) { results[i] = this.accumulators[i].getResult( ((Object[])workingMemoryContext)[i], ((Object[])context)[i], leftTuple, this.requiredDeclarations, workingMemory ); } return results; } catch ( final Exception e ) { throw new RuntimeException( e ); } } protected void replaceAccumulatorDeclaration(Declaration declaration, Declaration resolved) { for (Accumulator accumulator : accumulators) { if ( accumulator instanceof MVELAccumulatorFunctionExecutor ) { ( (MVELAccumulatorFunctionExecutor) accumulator ).replaceDeclaration( declaration, resolved ); } } } public MultiAccumulate clone() { RuleConditionElement clonedSource = source instanceof GroupElement ? ((GroupElement) source).cloneOnlyGroup() : source.clone(); MultiAccumulate clone = new MultiAccumulate( clonedSource, this.requiredDeclarations, this.accumulators ); registerClone(clone); return clone; } public Object[] createWorkingMemoryContext() { Object[] ctx = new Object[ this.accumulators.length ]; for( int i = 0; i < this.accumulators.length; i++ ) { ctx[i] = this.accumulators[i].createWorkingMemoryContext(); } return ctx; } public final class Wirer implements Wireable, Serializable { private static final long serialVersionUID = -9072646735174734614L; private final int index; public Wirer( int index ) { this.index = index; } public void wire( Object object ) { Accumulator accumulator = KiePolicyHelper.isPolicyEnabled() ? new Accumulator.SafeAccumulator((Accumulator) object) : (Accumulator) object; accumulators[index] = accumulator; for ( Accumulate clone : cloned ) { ((MultiAccumulate)clone).accumulators[index] = accumulator; } } } public int hashCode() { final int prime = 31; int result = 1; result = prime * result + Arrays.hashCode(accumulators); result = prime * result + Arrays.hashCode( requiredDeclarations ); result = prime * result + ((source == null) ? 0 : source.hashCode()); return result; } public boolean equals(Object obj) { if ( this == obj ) return true; if ( obj == null ) return false; if ( getClass() != obj.getClass() ) return false; MultiAccumulate other = (MultiAccumulate) obj; if ( !Arrays.equals( accumulators, other.accumulators ) ) return false; if ( !Arrays.equals( requiredDeclarations, other.requiredDeclarations ) ) return false; if ( source == null ) { if ( other.source != null ) return false; } else if ( !source.equals( other.source ) ) return false; return true; } }