package org.nd4j.aeron.ipc; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import org.agrona.DirectBuffer; import org.agrona.concurrent.UnsafeBuffer; import org.apache.commons.lang3.tuple.Pair; import org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.io.Serializable; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.time.Instant; import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.util.UUID; /** * A message sent over the wire for ndarrays * includ ing the timestamp sent (in nanoseconds), * index (for tensor along dimension view based updates) * and dimensions. * * Fields: * arr: Using {@link AeronNDArraySerde#toArray(DirectBuffer, int)} we extract the array from a buffer * sent: the timestamp in milliseconds of when the message was sent (UTC timezone) - use {@link NDArrayMessage#getCurrentTimeUtc()} * when sending a message * index: the index of the tensor along dimension for update (use -1 if there is no index, eg: when you are going to use the whole array) * dimensions: the dimensions to do for a tensoralongdimension update, if you intend on updating the whole array send: new int[]{ -1} which * will indicate to use the whole array for an update. * * * @author Adam Gibson */ @Data @Builder @AllArgsConstructor @NoArgsConstructor public class NDArrayMessage implements Serializable { private INDArray arr; private long sent; private long index; private int[] dimensions; private byte[] chunk; private int numChunks = 0; //default dimensions: a 1 length array of -1 means use the whole array for an update. private static int[] WHOLE_ARRAY_UPDATE = {-1}; //represents the constant for indicating using the whole array for an update (-1) private static int WHOLE_ARRAY_INDEX = -1; public enum MessageValidity { VALID, NULL_VALUE, INCONSISTENT_DIMENSIONS } public enum MessageType { CHUNKED, WHOLE } /** * Determine the number of chunks * @param message * @param chunkSize * @return */ public static int numChunksForMessage(NDArrayMessage message, int chunkSize) { int sizeOfMessage = NDArrayMessage.byteBufferSizeForMessage(message); int numMessages = sizeOfMessage / chunkSize; //increase by 1 for padding if (numMessages * chunkSize < sizeOfMessage) numMessages++; return numMessages; } /** * Create an array of messages to send * based on a specified chunk size * @param arrayMessage * @param chunkSize * @return */ public static NDArrayMessage[] chunkedMessages(NDArrayMessage arrayMessage, int chunkSize) { int sizeOfMessage = NDArrayMessage.byteBufferSizeForMessage(arrayMessage) - 4; int numMessages = sizeOfMessage / chunkSize; ByteBuffer direct = NDArrayMessage.toBuffer(arrayMessage).byteBuffer(); NDArrayMessage[] ret = new NDArrayMessage[numMessages]; for (int i = 0; i < numMessages; i++) { byte[] chunk = new byte[chunkSize]; direct.get(chunk, i * chunkSize, chunkSize); ret[i] = NDArrayMessage.builder().chunk(chunk).numChunks(numMessages).build(); } return ret; } /** * Prepare a whole array update * which includes the default dimensions * for indicating updating * the whole array (a 1 length int array with -1 as its only element) * -1 representing the dimension * @param arr * @return */ public static NDArrayMessage wholeArrayUpdate(INDArray arr) { return NDArrayMessage.builder().arr(arr).dimensions(WHOLE_ARRAY_UPDATE).index(WHOLE_ARRAY_INDEX) .sent(getCurrentTimeUtc()).build(); } /** * Factory method for creating an array * to send now (uses now in utc for the timestamp). * Note that this method will throw an * {@link IllegalArgumentException} if an invalid message is passed in. * An invalid message is as follows: * An index of -1 and dimensions that are of greater length than 1 with an element that isn't -1 * * @param arr the array to send * @param dimensions the dimensions to use * @param index the index to use * @return the created */ public static NDArrayMessage of(INDArray arr, int[] dimensions, long index) { //allow null dimensions as long as index is -1 if (dimensions == null) { dimensions = WHOLE_ARRAY_UPDATE; } //validate index being built if (index > 0) { if (dimensions.length > 1 || dimensions.length == 1 && dimensions[0] != -1) throw new IllegalArgumentException( "Inconsistent message. Your index is > 0 indicating you want to send a whole ndarray message but your dimensions indicate you are trying to send a partial update. Please ensure you use a 1 length int array with negative 1 as an element or use NDArrayMesage.wholeArrayUpdate(ndarray) for creation instead"); } return NDArrayMessage.builder().index(index).dimensions(dimensions).sent(getCurrentTimeUtc()).arr(arr).build(); } /** * Returns if a message is valid or not based on a few simple conditions: * no null values * both index and the dimensions array must be -1 and of length 1 with an element of -1 in it * otherwise it is a valid message. * @param message the message to validate * @return 1 of: NULL_VALUE,INCONSISTENT_DIMENSIONS,VALID see {@link MessageValidity} */ public static MessageValidity validMessage(NDArrayMessage message) { if (message.getDimensions() == null || message.getArr() == null) return MessageValidity.NULL_VALUE; if (message.getIndex() != -1 && message.getDimensions().length == 1 && message.getDimensions()[0] != -1) return MessageValidity.INCONSISTENT_DIMENSIONS; return MessageValidity.VALID; } /** * Get the current time in utc in milliseconds * @return the current time in utc in * milliseconds */ public static long getCurrentTimeUtc() { Instant instant = Instant.now(); ZonedDateTime dateTime = instant.atZone(ZoneOffset.UTC); return dateTime.toInstant().toEpochMilli(); } /** * Returns the size needed in bytes * for a bytebuffer for a given ndarray message. * The formula is: * {@link AeronNDArraySerde#byteBufferSizeFor(INDArray)} * + size of dimension length (4) * + time stamp size (8) * + index size (8) * + 4 * message.getDimensions.length * @param message the message to get the length for * @return the size of the byte buffer for a message */ public static int byteBufferSizeForMessage(NDArrayMessage message) { int enumSize = 4; int nInts = 4 * message.getDimensions().length; int sizeofDimensionLength = 4; int timeStampSize = 8; int indexSize = 8; return enumSize + nInts + sizeofDimensionLength + timeStampSize + indexSize + AeronNDArraySerde.byteBufferSizeFor(message.getArr()); } /** * * Create an ndarray message from an array of buffers. * This array of buffers would be assembled by an * {@link io.aeron.logbuffer.FragmentHandler} * capable of merging these messages together. * Typically what happens is an {@link AeronNDArraySubscriber} * will track chunks being sent. * * Anytime a subscriber received an {@link MessageType#CHUNKED} * as a type it will store the buffer temporarily. * * @param chunks * @return */ public static NDArrayMessage fromChunks(NDArrayMessageChunk[] chunks) { int overAllCapacity = chunks[0].getChunkSize() * chunks.length; ByteBuffer all = ByteBuffer.allocateDirect(overAllCapacity).order(ByteOrder.nativeOrder()); for (int i = 0; i < chunks.length; i++) { ByteBuffer curr = chunks[i].getData(); if (curr.capacity() > chunks[0].getChunkSize()) { curr.position(0).limit(chunks[0].getChunkSize()); curr = curr.slice(); } all.put(curr); } //create an ndarray message from the given buffer UnsafeBuffer unsafeBuffer = new UnsafeBuffer(all); //rewind the buffer all.rewind(); return NDArrayMessage.fromBuffer(unsafeBuffer, 0); } /** * Returns an array of * message chunks meant to be sent * in parallel. * Each message chunk has the layout: * messageType * number of chunks * chunkSize * length of uuid * uuid * buffer index * actual raw data * @param message the message to turn into chunks * @param chunkSize the chunk size * @return an array of buffers */ public static NDArrayMessageChunk[] chunks(NDArrayMessage message, int chunkSize) { int numChunks = numChunksForMessage(message, chunkSize); NDArrayMessageChunk[] ret = new NDArrayMessageChunk[numChunks]; DirectBuffer wholeBuffer = NDArrayMessage.toBuffer(message); String messageId = UUID.randomUUID().toString(); for (int i = 0; i < ret.length; i++) { //data: only grab a chunk of the data ByteBuffer view = (ByteBuffer) wholeBuffer.byteBuffer().asReadOnlyBuffer().position(i * chunkSize); view.limit(Math.min(i * chunkSize + chunkSize, wholeBuffer.capacity())); view.order(ByteOrder.nativeOrder()); view = view.slice(); NDArrayMessageChunk chunk = NDArrayMessageChunk.builder().id(messageId).chunkSize(chunkSize) .numChunks(numChunks).messageType(MessageType.CHUNKED).chunkIndex(i).data(view).build(); //insert in to the array itself ret[i] = chunk; } return ret; } /** * Convert a message to a direct buffer. * See {@link NDArrayMessage#fromBuffer(DirectBuffer, int)} * for a description of the format for the buffer * @param message the message to convert * @return a direct byte buffer representing this message. */ public static DirectBuffer toBuffer(NDArrayMessage message) { ByteBuffer byteBuffer = ByteBuffer.allocateDirect(byteBufferSizeForMessage(message)).order(ByteOrder.nativeOrder()); //declare message type byteBuffer.putInt(MessageType.WHOLE.ordinal()); //perform the ndarray put on the if (message.getArr().isCompressed()) { AeronNDArraySerde.doByteBufferPutCompressed(message.getArr(), byteBuffer, false); } else { AeronNDArraySerde.doByteBufferPutUnCompressed(message.getArr(), byteBuffer, false); } long sent = message.getSent(); long index = message.getIndex(); byteBuffer.putLong(sent); byteBuffer.putLong(index); byteBuffer.putInt(message.getDimensions().length); for (int i = 0; i < message.getDimensions().length; i++) { byteBuffer.putInt(message.getDimensions()[i]); } //rewind the buffer before putting it in to the unsafe buffer //note that we set rewind to false in the do byte buffer put methods byteBuffer.rewind(); return new UnsafeBuffer(byteBuffer); } /** * Convert a direct buffer to an ndarray * message. * The format of the byte buffer is: * ndarray * time * index * dimension length * dimensions * * We use {@link AeronNDArraySerde#toArrayAndByteBuffer(DirectBuffer, int)} * to read in the ndarray and just use normal {@link ByteBuffer#getInt()} and * {@link ByteBuffer#getLong()} to get the things like dimensions and index * and time stamp. * * * * @param buffer the buffer to convert * @param offset the offset to start at with the buffer - note that this * method call assumes that the message type is specified at the beginning of the buffer. * This means whatever offset you pass in will be increased by 4 (the size of an int) * @return the ndarray message based on this direct buffer. */ public static NDArrayMessage fromBuffer(DirectBuffer buffer, int offset) { //skip the message type Pair<INDArray, ByteBuffer> pair = AeronNDArraySerde.toArrayAndByteBuffer(buffer, offset + 4); INDArray arr = pair.getKey(); Nd4j.getCompressor().decompressi(arr); //use the rest of the buffer, of note here the offset is already set, we should only need to use ByteBuffer rest = pair.getRight(); long time = rest.getLong(); long index = rest.getLong(); //get the array next for dimensions int dimensionLength = rest.getInt(); if (dimensionLength <= 0) throw new IllegalArgumentException("Invalid dimension length " + dimensionLength); int[] dimensions = new int[dimensionLength]; for (int i = 0; i < dimensionLength; i++) dimensions[i] = rest.getInt(); return NDArrayMessage.builder().sent(time).arr(arr).index(index).dimensions(dimensions).build(); } }