package org.nd4j.aeron.ipc; import io.aeron.logbuffer.FragmentHandler; import io.aeron.logbuffer.Header; import lombok.extern.slf4j.Slf4j; import org.agrona.DirectBuffer; import org.nd4j.aeron.ipc.chunk.ChunkAccumulator; import org.nd4j.aeron.ipc.chunk.InMemoryChunkAccumulator; import org.nd4j.aeron.ipc.chunk.NDArrayMessageChunk; import java.nio.ByteBuffer; import java.nio.ByteOrder; /** * NDArray fragment handler * for listening to an aeron queue * * @author Adam Gibson */ @Slf4j public class NDArrayFragmentHandler implements FragmentHandler { private NDArrayCallback ndArrayCallback; private ChunkAccumulator chunkAccumulator = new InMemoryChunkAccumulator(); public NDArrayFragmentHandler(NDArrayCallback ndArrayCallback) { this.ndArrayCallback = ndArrayCallback; } /** * Callback for handling * fragments of data being read from a log. * * @param buffer containing the data. * @param offset at which the data begins. * @param length of the data in bytes. * @param header representing the meta data for the data. */ @Override public void onFragment(DirectBuffer buffer, int offset, int length, Header header) { ByteBuffer byteBuffer = buffer.byteBuffer(); boolean byteArrayInput = false; if (byteBuffer == null) { byteArrayInput = true; byte[] destination = new byte[length]; ByteBuffer wrap = ByteBuffer.wrap(buffer.byteArray()); wrap.get(destination, offset, length); byteBuffer = ByteBuffer.wrap(destination).order(ByteOrder.nativeOrder()); } //only applicable for direct buffers where we don't wrap the array if (!byteArrayInput) { byteBuffer.position(offset); byteBuffer.order(ByteOrder.nativeOrder()); } int messageTypeIndex = byteBuffer.getInt(); if (messageTypeIndex >= NDArrayMessage.MessageType.values().length) throw new IllegalStateException( "Illegal index on message type. Likely corrupt message. Please check the serialization of the bytebuffer. Input was bytebuffer: " + byteArrayInput); NDArrayMessage.MessageType messageType = NDArrayMessage.MessageType.values()[messageTypeIndex]; if (messageType == NDArrayMessage.MessageType.CHUNKED) { NDArrayMessageChunk chunk = NDArrayMessageChunk.fromBuffer(byteBuffer, messageType); if (chunk.getNumChunks() < 1) throw new IllegalStateException("Found invalid number of chunks " + chunk.getNumChunks() + " on chunk index " + chunk.getChunkIndex()); chunkAccumulator.accumulateChunk(chunk); log.info("Number of chunks " + chunk.getNumChunks() + " and number of chunks " + chunk.getNumChunks() + " for id " + chunk.getId() + " is " + chunkAccumulator.numChunksSoFar(chunk.getId())); if (chunkAccumulator.allPresent(chunk.getId())) { NDArrayMessage message = chunkAccumulator.reassemble(chunk.getId()); ndArrayCallback.onNDArrayMessage(message); } } else { NDArrayMessage message = NDArrayMessage.fromBuffer(buffer, offset); ndArrayCallback.onNDArrayMessage(message); } } }