package org.nd4j.aeron.ipc;
import org.agrona.DirectBuffer;
import org.agrona.concurrent.UnsafeBuffer;
import org.apache.commons.lang3.tuple.Pair;
import org.bytedeco.javacpp.BytePointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
/**
* NDArray Serialization and
* de serialization class for
* aeron.
*
* This is a low level class
* specifically meant for speed.
*
* @author Adam Gibson
*/
public class AeronNDArraySerde {
/**
* Returns the byte buffer size for the given
* ndarray. This is an auxillary method
* for determining the size of the buffer
* size to allocate for sending an ndarray via
* the aeron media driver.
*
* The math break down for uncompressed is:
* 2 ints for rank of the array and an ordinal representing the data type of the data buffer
* The rest is in order:
* shape information
* data buffer
*
* The math break down for compressed is:
* 2 ints for rank and an ordinal representing the data type for the data buffer
*
* The rest is in order:
* shape information
* codec information
* data buffer
*
* @param arr the array to compute the size for
* @return the size of the byte buffer that was allocated
*/
public static int byteBufferSizeFor(INDArray arr) {
if (!arr.isCompressed()) {
ByteBuffer buffer = arr.data().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
ByteBuffer shapeBuffer = arr.shapeInfoDataBuffer().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
//2 four byte ints at the beginning
int twoInts = 8;
return twoInts + buffer.limit() + shapeBuffer.limit();
} else {
CompressedDataBuffer compressedDataBuffer = (CompressedDataBuffer) arr.data();
CompressionDescriptor descriptor = compressedDataBuffer.getCompressionDescriptor();
ByteBuffer codecByteBuffer = descriptor.toByteBuffer();
ByteBuffer buffer = arr.data().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
ByteBuffer shapeBuffer = arr.shapeInfoDataBuffer().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
int twoInts = 2 * 4;
return twoInts + buffer.limit() + shapeBuffer.limit() + codecByteBuffer.limit();
}
}
/**
* Convert an ndarray to an unsafe buffer
* for use by aeron
* @param arr the array to convert
* @return the unsafebuffer representation of this array
*/
public static UnsafeBuffer toBuffer(INDArray arr) {
//subset and get rid of 1 off non 1 element wise stride cases
if (arr.isView())
arr = arr.dup();
if (!arr.isCompressed()) {
ByteBuffer b3 = ByteBuffer.allocateDirect(byteBufferSizeFor(arr)).order(ByteOrder.nativeOrder());
doByteBufferPutUnCompressed(arr, b3, true);
return new UnsafeBuffer(b3);
}
//compressed array
else {
ByteBuffer b3 = ByteBuffer.allocateDirect(byteBufferSizeFor(arr)).order(ByteOrder.nativeOrder());
doByteBufferPutCompressed(arr, b3, true);
return new UnsafeBuffer(b3);
}
}
/**
* Setup the given byte buffer
* for serialization (note that this is for uncompressed INDArrays)
* 4 bytes int for rank
* 4 bytes for data type
* shape buffer
* data buffer
*
* @param arr the array to setup
* @param allocated the byte buffer to setup
* @param rewind whether to rewind the byte buffer or nt
*/
public static void doByteBufferPutUnCompressed(INDArray arr, ByteBuffer allocated, boolean rewind) {
ByteBuffer buffer = arr.data().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
ByteBuffer shapeBuffer = arr.shapeInfoDataBuffer().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
//2 four byte ints at the beginning
allocated.putInt(arr.rank());
//put data type next so its self describing
allocated.putInt(arr.data().dataType().ordinal());
allocated.put(shapeBuffer);
allocated.put(buffer);
if (rewind)
allocated.rewind();
}
/**
* Setup the given byte buffer
* for serialization (note that this is for compressed INDArrays)
* 4 bytes for rank
* 4 bytes for data type
* shape information
* codec information
* data type
*
* @param arr the array to setup
* @param allocated the byte buffer to setup
* @param rewind whether to rewind the byte buffer or not
*/
public static void doByteBufferPutCompressed(INDArray arr, ByteBuffer allocated, boolean rewind) {
CompressedDataBuffer compressedDataBuffer = (CompressedDataBuffer) arr.data();
CompressionDescriptor descriptor = compressedDataBuffer.getCompressionDescriptor();
ByteBuffer codecByteBuffer = descriptor.toByteBuffer();
ByteBuffer buffer = arr.data().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
ByteBuffer shapeBuffer = arr.shapeInfoDataBuffer().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
allocated.putInt(arr.rank());
//put data type next so its self describing
allocated.putInt(arr.data().dataType().ordinal());
//put shape next
allocated.put(shapeBuffer);
//put codec information next
allocated.put(codecByteBuffer);
//finally put the data
allocated.put(buffer);
if (rewind)
allocated.rewind();
}
/**
* Create an ndarray
* from the unsafe buffer.
* Note that if you are interacting with a buffer that specifies
* an {@link org.nd4j.aeron.ipc.NDArrayMessage.MessageType}
* then you must pass in an offset + 4.
* Adding 4 to the offset will cause the inter
* @param buffer the buffer to create the array from
* @return the ndarray derived from this buffer
*/
public static Pair<INDArray, ByteBuffer> toArrayAndByteBuffer(DirectBuffer buffer, int offset) {
ByteBuffer byteBuffer =
buffer.byteBuffer() == null
? ByteBuffer.allocateDirect(buffer.byteArray().length).put(buffer.byteArray())
.order(ByteOrder.nativeOrder())
: buffer.byteBuffer().order(ByteOrder.nativeOrder());
//bump the byte buffer to the proper position
byteBuffer.position(offset);
int rank = byteBuffer.getInt();
if (rank < 0)
throw new IllegalStateException("Found negative integer. Corrupt serialization?");
//get the shape buffer length to create the shape information buffer
int shapeBufferLength = Shape.shapeInfoLength(rank);
//create the ndarray shape information
DataBuffer shapeBuff = Nd4j.createBuffer(new int[shapeBufferLength]);
//compute the databuffer type from the index
DataBuffer.Type type = DataBuffer.Type.values()[byteBuffer.getInt()];
for (int i = 0; i < shapeBufferLength; i++) {
shapeBuff.put(i, byteBuffer.getInt());
}
//after the rank,data type, shape buffer (of length shape buffer length) * sizeof(int)
if (type != DataBuffer.Type.COMPRESSED) {
ByteBuffer slice = byteBuffer.slice();
//wrap the data buffer for the last bit
DataBuffer buff = Nd4j.createBuffer(slice, type, Shape.length(shapeBuff));
//advance past the data
int position = byteBuffer.position() + (buff.getElementSize() * (int) buff.length());
byteBuffer.position(position);
//create the final array
//TODO: see how to avoid dup here
INDArray arr = Nd4j.createArrayFromShapeBuffer(buff.dup(), shapeBuff.dup());
return Pair.of(arr, byteBuffer);
} else {
CompressionDescriptor compressionDescriptor = CompressionDescriptor.fromByteBuffer(byteBuffer);
ByteBuffer slice = byteBuffer.slice();
//ensure that we only deal with the slice of the buffer that is actually the data
BytePointer byteBufferPointer = new BytePointer(slice);
//create a compressed array based on the rest of the data left in the buffer
CompressedDataBuffer compressedDataBuffer =
new CompressedDataBuffer(byteBufferPointer, compressionDescriptor);
//TODO: see how to avoid dup()
INDArray arr = Nd4j.createArrayFromShapeBuffer(compressedDataBuffer.dup(), shapeBuff.dup());
//advance past the data
int compressLength = (int) compressionDescriptor.getCompressedLength();
byteBuffer.position(byteBuffer.position() + compressLength);
return Pair.of(arr, byteBuffer);
}
}
/**
* Create an ndarray
* from the unsafe buffer
* @param buffer the buffer to create the array from
* @return the ndarray derived from this buffer
*/
public static INDArray toArray(DirectBuffer buffer, int offset) {
return toArrayAndByteBuffer(buffer, offset).getLeft();
}
/**
* Create an ndarray
* from the unsafe buffer
* @param buffer the buffer to create the array from
* @return the ndarray derived from this buffer
*/
public static INDArray toArray(DirectBuffer buffer) {
return toArray(buffer, 0);
}
}