package org.deeplearning4j.ui.module.remote;
import com.fasterxml.jackson.databind.JsonNode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.*;
import org.deeplearning4j.ui.api.FunctionType;
import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.nd4j.shade.jackson.core.type.TypeReference;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import play.mvc.Http;
import play.mvc.Result;
import play.mvc.Results;
import javax.xml.bind.DatatypeConverter;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import static play.mvc.Http.Context.Implicit.request;
/**
*
* Used to receive UI updates remotely.
* Used in conjunction with {@link org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter}, which posts to the UI.
* UI information is then deserialized and routed to the specified StatsStorageRouter, which may (or may not)
* be attached to the UI
*
* @author Alex Black
*/
@Slf4j
public class RemoteReceiverModule implements UIModule {
private AtomicBoolean enabled = new AtomicBoolean(false);
private StatsStorageRouter statsStorage;
public void setEnabled(boolean enabled) {
this.enabled.set(enabled);
if (!enabled) {
this.statsStorage = null;
}
}
public boolean isEnabled() {
return enabled.get() && this.statsStorage != null;
}
public void setStatsStorage(StatsStorageRouter statsStorage) {
this.statsStorage = statsStorage;
}
@Override
public List<String> getCallbackTypeIDs() {
return Collections.emptyList();
}
@Override
public List<Route> getRoutes() {
Route r = new Route("/remoteReceive", HttpMethod.POST, FunctionType.Supplier, this::receiveData);
return Collections.singletonList(r);
}
@Override
public void reportStorageEvents(Collection<StatsStorageEvent> events) {
//No op
}
@Override
public void onAttach(StatsStorage statsStorage) {
//No op
}
@Override
public void onDetach(StatsStorage statsStorage) {
//No op
}
private Result receiveData() {
if (!enabled.get()) {
return Results.forbidden(
"UI server remote listening is currently disabled. Use UIServer.getInstance().enableRemoteListener()");
}
if (statsStorage == null) {
return Results.internalServerError(
"UI Server remote listener: no StatsStorage instance is set/available to store results");
}
JsonNode jn = request().body().asJson();
JsonNode type = jn.get("type");
JsonNode dataClass = jn.get("class");
JsonNode data = jn.get("data");
if (type == null || dataClass == null || data == null) {
log.warn("Received incorrectly formatted data from remote listener (has type = " + (type != null)
+ ", has data class = " + (dataClass != null) + ", has data = " + (data != null) + ")");
return Results.badRequest("Received incorrectly formatted data");
}
String dc = dataClass.asText();
String content = data.asText();
switch (type.asText().toLowerCase()) {
case "metadata":
StorageMetaData meta = getMetaData(dc, content);
if (meta != null) {
statsStorage.putStorageMetaData(meta);
}
break;
case "staticinfo":
Persistable staticInfo = getPersistable(dc, content);
if (staticInfo != null) {
statsStorage.putStaticInfo(staticInfo);
}
break;
case "update":
Persistable update = getPersistable(dc, content);
if (update != null) {
statsStorage.putUpdate(update);
}
break;
default:
}
return Results.ok("Receiver got data: ");
}
private StorageMetaData getMetaData(String dataClass, String content) {
StorageMetaData meta;
try {
Class<?> c = Class.forName(dataClass);
if (StorageMetaData.class.isAssignableFrom(c)) {
meta = (StorageMetaData) c.newInstance();
} else {
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
StorageMetaData.class.getName());
return null;
}
} catch (Exception e) {
log.warn("Skipping invalid remote data: exception encountered for class {}", dataClass, e);
return null;
}
try {
byte[] bytes = DatatypeConverter.parseBase64Binary(content);
meta.decode(bytes);
} catch (Exception e) {
log.warn("Skipping invalid remote UI data: exception encountered when deserializing data", e);
return null;
}
return meta;
}
private Persistable getPersistable(String dataClass, String content) {
Persistable p;
try {
Class<?> c = Class.forName(dataClass);
if (Persistable.class.isAssignableFrom(c)) {
p = (Persistable) c.newInstance();
} else {
log.warn("Skipping invalid remote data: class {} in not an instance of {}", dataClass,
Persistable.class.getName());
return null;
}
} catch (Exception e) {
log.warn("Skipping invalid remote UI data: exception encountered for class {}", dataClass, e);
return null;
}
try {
byte[] bytes = DatatypeConverter.parseBase64Binary(content);
p.decode(bytes);
} catch (Exception e) {
log.warn("Skipping invalid remote data: exception encountered when deserializing data", e);
return null;
}
return p;
}
}