/*******************************************************************************
* Copyright 2015 Klaus Pfeiffer <klaus@allpiper.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package com.jfastnet.messages;
import com.jfastnet.State;
import com.jfastnet.exceptions.DeserialiseException;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.InflaterInputStream;
/** Used for bigger messages to be transferred in parts.
* @author Klaus Pfeiffer - klaus@allpiper.com */
@Slf4j
public class MessagePart extends Message implements IDontFrame {
private static final int ZLIB_HEADER = 0x78;
/** Message size in bytes without bytes payload. */
public static final int MESSAGE_HEADER_SIZE = 120;
private static final List<MessagePart> EMPTY_MESSAGE_PARTS = Collections.unmodifiableList(new ArrayList<>());
/** Whether this is the last part to construct the message. */
boolean last;
/** The id servers the purpose to allow receiving of multiple different
* big messages. */
long id;
/** The number tells us which part of the message we received. */
int partNumber;
/** The payload of the message part. */
byte[] bytes;
private MessagePart(long id, int partNumber, byte[] bytes) {
this.id = id;
this.partNumber = partNumber;
this.bytes = bytes;
}
public static List<MessagePart> createFromMessage(State state, long id, Message message, int chunkSize) {
return createFromMessage(state, id, message, chunkSize, message.getReliableMode());
}
public static List<MessagePart> createFromMessage(@NonNull State state, long id, @NonNull Message message, int chunkSize, @NonNull ReliableMode reliableMode) {
state.getUdpPeer().createPayload(message);
// createPayload has to create a byte array
// Depends on the UDP peer if this is possible.
if (message.payload instanceof byte[]) {
byte[] bytes = (byte[]) message.payload;
if (state.getConfig().compressBigMessages) {
byte[] compressedBytes = compress(bytes);
if (compressedBytes == null) {
log.warn("Compression failed for message: {}", message);
log.info("Proceeding without compression.");
} else {
bytes = compressedBytes;
}
}
return createFromByteArray(id, bytes, chunkSize, reliableMode);
}
log.error("Message could not be created, because of missing byte array payload.");
return EMPTY_MESSAGE_PARTS;
}
public static List<MessagePart> createFromByteArray(long id, byte[] bytes, int chunkSize, @NonNull ReliableMode reliableMode) {
if (bytes == null) {
log.error("Byte array was null!");
return EMPTY_MESSAGE_PARTS;
}
if (bytes.length == 0) {
log.error("Byte array was empty!");
return EMPTY_MESSAGE_PARTS;
}
if (ReliableMode.UNRELIABLE.equals(reliableMode)) {
log.warn("Splitting of unreliable messages not supported!");
return EMPTY_MESSAGE_PARTS;
}
log.info("Create message with {} bytes and chunk size {}", bytes.length, chunkSize);
int from = 0;
int to = chunkSize;
int partNumber = 0;
List<MessagePart> messages = new ArrayList<>();
while (from < bytes.length) {
byte[] chunk = Arrays.copyOfRange(bytes, from, to);
if (ReliableMode.SEQUENCE_NUMBER.equals(reliableMode)) {
messages.add(new MessagePart(id, partNumber, chunk));
} else if (ReliableMode.ACK_PACKET.equals(reliableMode)) {
messages.add(new AckMessagePart(id, partNumber, chunk));
} else {
throw new UnsupportedOperationException("Reliable mode '" + reliableMode + "' not supported for message splitting!");
}
partNumber++;
from += chunkSize;
to += chunkSize;
}
messages.get(messages.size() - 1).last = true;
log.info("Created {} messages.", messages.size());
return messages;
}
@Override
public void process(Object context) {
log.trace("Part number {} of id {} received.", partNumber, id);
SortedMap<Long, SortedMap<Integer, MessagePart>> arrayBufferMap = getState().getByteArrayBufferMap();
SortedMap<Integer, MessagePart> byteArrayBuffer = arrayBufferMap.get(id);
if (byteArrayBuffer == null) {
byteArrayBuffer = new TreeMap<>();
arrayBufferMap.put(id, byteArrayBuffer);
}
byteArrayBuffer.put(partNumber, this);
if (allPartsReceived()) {
Collection<byte[]> values = byteArrayBuffer.values().stream().collect(Collectors.mapping(messagePart -> messagePart.bytes, Collectors.toList()));
log.info("Last of {} parts for splitted message received.", values.size());
ByteArrayOutputStream bos = new ByteArrayOutputStream();
try {
for (byte[] value : values) {
bos.write(value);
}
bos.flush();
byte[] byteArray = bos.toByteArray();
if (getConfig().compressBigMessages) {
byteArray = decompress(byteArray);
}
Message messageFromByteArray = getConfig().serialiser.deserialise(byteArray, 0, byteArray.length);
if (messageFromByteArray == null) {
log.error("Deserialised message was null! See previous errors.");
throw new DeserialiseException("Deserialised message was null! See previous errors.");
} else {
log.info("Message created: {}", messageFromByteArray);
messageFromByteArray.copyAttributesFrom(this);
getConfig().externalReceiver.receive(messageFromByteArray);
}
} catch (IOException e) {
log.error("Error writing byte array.", e);
} finally {
byteArrayBuffer.clear();
}
}
}
private boolean allPartsReceived() {
if (ReliableMode.SEQUENCE_NUMBER.equals(getReliableMode())) {
return last;
} else if (ReliableMode.ACK_PACKET.equals(getReliableMode())) {
SortedMap<Integer, MessagePart> byteArrayBuffer = getState().getByteArrayBufferMap().get(id);
if (byteArrayBuffer == null) {
log.trace("byteArrayBuffer == null");
return false;
}
Collection<MessagePart> messageParts = byteArrayBuffer.values();
// Check if the last part was already received
boolean hasLastPart = messageParts.stream().filter(messagePart -> messagePart.last).count() > 0L;
if (!hasLastPart) {
log.trace("!hasLastPart");
return false;
}
// Check if all required messages are received
int expectedPartNumber = 0;
for (MessagePart messagePart : messageParts) {
if (messagePart.partNumber != expectedPartNumber) {
log.trace("messagePart.partNumber != expectedPartNumber: {} != {}", messagePart.partNumber, expectedPartNumber);
return false;
}
expectedPartNumber++;
}
return true;
} else {
throw new UnsupportedOperationException("Unsupported reliable mode.");
}
}
public static byte[] compress(byte[] bytes) {
log.info("Compress byte array of size {}", bytes.length);
try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(bytes.length);
DeflaterOutputStream deflaterOutputStream = new DeflaterOutputStream(byteArrayOutputStream)) {
deflaterOutputStream.write(bytes);
deflaterOutputStream.close();
byteArrayOutputStream.close();
return byteArrayOutputStream.toByteArray();
} catch (IOException e) {
log.error("Couldn't compress byte array.", e);
}
return null;
}
public static byte[] decompress(byte[] bytes) {
if (!isCompressed(bytes)) {
log.warn("Tried to decompress uncompressed data. No decompressing will be done.");
return bytes;
}
try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes);
InflaterInputStream inflaterInputStream = new InflaterInputStream(byteArrayInputStream)) {
ByteArrayOutputStream bout = new ByteArrayOutputStream(2048);
int b;
while ((b = inflaterInputStream.read()) != -1) {
bout.write(b);
}
inflaterInputStream.close();
bout.close();
return bout.toByteArray();
} catch (IOException e) {
log.error("Couldn't decompress byte array.", e);
}
return bytes;
}
/**
* ZLib magic headers:
* <pre>
* 78 01 - No Compression/low
* 78 9C - Default Compression
* 78 DA - Best Compression
* </pre>
* @return true, if data was compressed with the java default inflater (zlib)
*/
public static boolean isCompressed(byte[] data) {
return data.length > 0 && (data[0] & 0xff) == ZLIB_HEADER && (
(data[1] & 0xff) == 0x9c
|| (data[1] & 0xff) == 0x01
|| (data[1] & 0xff) == 0xda
);
}
/** MessagePart with ACK reliable mode. */
public static class AckMessagePart extends MessagePart {
private AckMessagePart(long id, int partNumber, byte[] bytes) {
super(id, partNumber, bytes);
}
@Override
public ReliableMode getReliableMode() {
return ReliableMode.ACK_PACKET;
}
}
}