package org.nd4j.parameterserver.distributed.messages;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.agrona.concurrent.UnsafeBuffer;
import org.apache.commons.lang3.SerializationUtils;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.enums.NodeRole;
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
import org.nd4j.parameterserver.distributed.logic.Storage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.transport.Transport;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
/**
 * Simple wrapper for multiple request messages OF THE SAME TYPE being stacked into a single message.
 *
 * <p>The stacking methods ({@link #stackMessage}, {@code stackMessages}) are synchronized on this
 * instance. The read accessors ({@link #getMessages()}, {@link #size()}, {@link #iterator()}) and
 * {@link #processMessage()} are not synchronized.
 *
 * @param <T> type of the training messages held by this frame
 * @author raver119@gmail.com
 */
@Slf4j
public class Frame<T extends TrainingMessage> implements Serializable, Iterable<T>, VoidMessage {

    /** Messages stacked into this frame, in insertion order. */
    @Getter(AccessLevel.PROTECTED)
    @Setter(AccessLevel.PROTECTED)
    protected List<T> list = new ArrayList<>();

    /** Originator id; propagated to every stacked message by {@link #setOriginatorId(long)}. */
    @Getter
    protected long originatorId;

    @Getter
    @Setter
    protected short targetId;

    /** Task id of this frame; written as the frame id of every message stacked into it. */
    @Getter
    @Setter
    protected long taskId;

    // Execution context. These fields are transient, so they are not part of the serialized form
    // and have to be supplied through attachContext(...) or extractContext(...).
    protected transient VoidConfiguration voidConfiguration;
    protected transient Clipboard clipboard;
    protected transient Transport transport;
    protected transient Storage storage;
    protected transient NodeRole role;
    protected transient short shardIndex;
    protected transient TrainingDriver<? extends TrainingMessage> trainer;

    /** Counter advanced by {@link #incrementRetransmitCount()}; not serialized. */
    @Getter
    @Setter(AccessLevel.PRIVATE)
    protected transient int retransmitCount = 0;

    /**
     * Creates an empty frame with default field values.
     */
    protected Frame() {
    }

    /**
     * Creates an empty frame with the given task id.
     *
     * @param taskId task id of this frame
     */
    public Frame(long taskId) {
        this.taskId = taskId;
    }

    /**
     * Creates a frame that initially holds the given message.
     *
     * <p>The message is added as-is: unlike {@link #stackMessage}, this constructor does not
     * touch the frame id of the message.
     *
     * @param message first message of this frame, must not be null
     */
    public Frame(@NonNull T message) {
        this();
        list.add(message);
    }

    /**
     * Sets the originator id of this frame and of every message currently stacked in it.
     *
     * @param id new originator id
     */
    @Override
    public void setOriginatorId(long id) {
        this.originatorId = id;
        if (list == null) {
            return;
        }
        for (T msg : list) {
            msg.setOriginatorId(this.getOriginatorId());
        }
    }

    /**
     * This method adds single TrainingMessage to this Frame
     *
     * PLEASE NOTE: This method is synchronized
     *
     * @param message message to stack, must not be null
     */
    public synchronized void stackMessage(@NonNull T message) {
        stackMessageUnlocked(message);
    }

    /**
     * Stacks a single message without taking the lock; callers are expected to hold it.
     *
     * <p>If the message reports join support and an equal message (per {@link List#indexOf})
     * is already present, the new message is joined into the existing one. Otherwise the
     * message receives this frame's task id as its frame id and is appended.
     *
     * @param message message to stack, must not be null
     */
    private void stackMessageUnlocked(@NonNull T message) {
        if (message.isJoinSupported()) {
            int index = list.indexOf(message);
            if (index >= 0) {
                list.get(index).joinMessage(message);
                return;
            }
        }

        message.setFrameId(this.getTaskId());
        list.add(message);
    }

    /**
     * This method adds multiple messages to this frame
     *
     * PLEASE NOTE: This method is synchronized
     *
     * <p>Null elements are skipped, matching the varargs overload, so a stray null cannot abort
     * the call after part of the batch has already been stacked.
     *
     * @param messages messages to stack, the collection itself must not be null
     */
    public synchronized void stackMessages(@NonNull Collection<T> messages) {
        for (T message : messages) {
            if (message != null) {
                stackMessageUnlocked(message);
            }
        }
    }

    /**
     * This method adds multiple messages to this frame
     *
     * PLEASE NOTE: This method is synchronized
     *
     * <p>Null elements are skipped.
     *
     * @param messages messages to stack
     */
    public synchronized void stackMessages(T... messages) {
        for (T message : messages) {
            if (message != null) {
                stackMessageUnlocked(message);
            }
        }
    }

    /**
     * Returns the messages stacked in this frame.
     *
     * <p>NOTE: this is the live backing list, not a copy.
     *
     * @return messages of this frame
     */
    public Collection<T> getMessages() {
        return list;
    }

    /**
     * @return number of messages currently held by this frame
     */
    public int size() {
        return list.size();
    }

    /**
     * @return iterator over the backing message list
     */
    @Override
    public Iterator<T> iterator() {
        return list.iterator();
    }

    /**
     * @return the constant message type code used by Frame, {@code 3}
     */
    @Override
    public int getMessageType() {
        return 3;
    }

    /**
     * Serializes this frame using Java serialization.
     *
     * @return serialized form of this frame
     */
    @Override
    public byte[] asBytes() {
        return SerializationUtils.serialize(this);
    }

    /**
     * @return a new {@link UnsafeBuffer} wrapping the result of {@link #asBytes()}
     */
    @Override
    public UnsafeBuffer asUnsafeBuffer() {
        return new UnsafeBuffer(asBytes());
    }

    /**
     * Stores the given execution context in this frame. The context is handed over to the
     * stacked messages by {@link #processMessage()}.
     *
     * <p>All reference arguments are marked {@code @NonNull}.
     */
    @Override
    public void attachContext(@NonNull VoidConfiguration voidConfiguration,
                    @NonNull TrainingDriver<? extends TrainingMessage> trainer, @NonNull Clipboard clipboard,
                    @NonNull Transport transport, @NonNull Storage storage, @NonNull NodeRole role,
                    short shardIndex) {
        this.voidConfiguration = voidConfiguration;
        this.clipboard = clipboard;
        this.transport = transport;
        this.storage = storage;
        this.role = role;
        this.shardIndex = shardIndex;
        this.trainer = trainer;
    }

    /**
     * Copies the execution context and the originator id from the given message into this frame.
     *
     * <p>The originator id is assigned directly, so it is NOT propagated to the stacked
     * messages the way {@link #setOriginatorId(long)} does.
     *
     * @param message message to take the context from, must not be null
     */
    @Override
    public void extractContext(@NonNull BaseVoidMessage message) {
        this.voidConfiguration = message.voidConfiguration;
        this.clipboard = message.clipboard;
        this.transport = message.transport;
        this.storage = message.storage;
        this.role = message.role;
        this.shardIndex = message.shardIndex;
        this.trainer = message.trainer;
        this.originatorId = message.originatorId;
    }

    /**
     * Processes every stacked message.
     *
     * <p>When both trainer and transport are available, a completion hook is registered for
     * every message first, and the frame's context is attached to each message right before it
     * is processed. Without them the messages are still processed, just without hooks and
     * without the context being attached. Each message is processed as many times as its
     * counter says.
     */
    @Override
    public void processMessage() {
        final boolean contextAvailable = trainer != null && transport != null;

        // all messages are registered first, before any of them is fired
        if (contextAvailable) {
            for (T message : list) {
                trainer.addCompletionHook(getOriginatorId(), getTaskId(), message.getTaskId());
            }
        }

        for (T message : list) {
            if (contextAvailable) {
                message.attachContext(voidConfiguration, trainer, clipboard, transport, storage, role, shardIndex);
            }

            // if there's more than 1 round, it should be applied that many times
            for (int i = 0; i < message.getCounter(); i++) {
                message.processMessage();
            }
        }
    }

    /**
     * @return always {@code false}: a Frame does not support joining
     */
    @Override
    public boolean isJoinSupported() {
        return false;
    }

    /**
     * No-op, see {@link #isJoinSupported()}.
     *
     * @param message ignored
     */
    @Override
    public void joinMessage(VoidMessage message) {
        // no-op
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isBlockingMessage() {
        return true;
    }

    /**
     * Increments the retransmit counter by one.
     */
    @Override
    public void incrementRetransmitCount() {
        retransmitCount++;
    }
}