package org.nd4j.aeron.ndarrayholder; import lombok.NoArgsConstructor; import org.nd4j.aeron.ipc.NDArrayHolder; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; /** * An in meory ndarray holder * * @author Adam Gibson */ @NoArgsConstructor public class InMemoryNDArrayHolder implements NDArrayHolder { private AtomicReference<INDArray> arr = new AtomicReference<>(); private AtomicInteger totalUpdates = new AtomicInteger(0); public InMemoryNDArrayHolder(int[] shape) { setArray(Nd4j.zeros(shape)); } public InMemoryNDArrayHolder(INDArray arr) { setArray(arr); } /** * Set the ndarray * * @param arr the ndarray for this holder * to use */ @Override public void setArray(INDArray arr) { if (this.arr.get() == null) this.arr.set(arr); } /** * The number of updates * that have been sent to this older. * * @return */ @Override public int totalUpdates() { return totalUpdates.get(); } /** * Retrieve an ndarray * * @return */ @Override public INDArray get() { return arr.get(); } /** * Retrieve a partial view of the ndarray. * This method uses tensor along dimension internally * Note this will call dup() * * @param idx the index of the tad to get * @param dimensions the dimensions to use * @return the tensor along dimension based on the index and dimensions * from the master array. */ @Override public INDArray getTad(int idx, int... dimensions) { return arr.get().tensorAlongDimension(idx, dimensions); } }