package org.deeplearning4j.ui.module.histogram;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.ui.api.FunctionType;
import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.stats.api.StatsInitializationReport;
import org.deeplearning4j.ui.stats.api.StatsReport;
import org.deeplearning4j.ui.stats.api.StatsType;
import org.deeplearning4j.ui.stats.api.SummaryType;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.ui.weights.beans.CompactModelAndGradient;
import play.libs.Json;
import play.mvc.Result;
import play.mvc.Results;
import java.util.*;
import static play.mvc.Results.ok;
/**
* Module for the HistogramIterationListener
*
* @author Alex Black
*/
@Slf4j
public class HistogramModule implements UIModule {
private Map<String, StatsStorage> knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>());
@Override
public List<String> getCallbackTypeIDs() {
return Collections.singletonList(StatsListener.TYPE_ID);
}
@Override
public List<Route> getRoutes() {
Route r = new Route("/weights", HttpMethod.GET, FunctionType.Supplier,
() -> ok(org.deeplearning4j.ui.views.html.histogram.Histogram.apply()));
Route r2 = new Route("/weights/listSessions", HttpMethod.GET, FunctionType.Supplier,
() -> ok(Json.toJson(knownSessionIDs.keySet())));
Route r3 = new Route("/weights/updated/:sid", HttpMethod.GET, FunctionType.Function, this::getLastUpdateTime);
Route r4 = new Route("/weights/data/:sid", HttpMethod.GET, FunctionType.Function, this::processRequest);
return Arrays.asList(r, r2, r3, r4);
}
@Override
public void reportStorageEvents(Collection<StatsStorageEvent> events) {
log.trace("Received events: {}", events);
//We should only be getting relevant session IDs...
for (StatsStorageEvent sse : events) {
if (!knownSessionIDs.containsKey(sse.getSessionID())) {
knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
}
}
}
@Override
public void onAttach(StatsStorage statsStorage) {
for (String sessionID : statsStorage.listSessionIDs()) {
for (String typeID : statsStorage.listTypeIDsForSession(sessionID)) {
if (!StatsListener.TYPE_ID.equals(typeID))
continue;
knownSessionIDs.put(sessionID, statsStorage);
}
}
}
@Override
public void onDetach(StatsStorage statsStorage) {
for (String sessionID : statsStorage.listSessionIDs()) {
knownSessionIDs.remove(sessionID);
}
}
private Result getLastUpdateTime(String sessionID) {
return Results.ok(Json.toJson(System.currentTimeMillis()));
}
private Result processRequest(String sessionId) {
//TODO cache the relevant info and update, rather than querying StatsStorage and building from scratch each time
StatsStorage ss = knownSessionIDs.get(sessionId);
if (ss == null) {
return Results.notFound("Unknown session ID: " + sessionId);
}
List<String> workerIDs = ss.listWorkerIDsForSession(sessionId);
//TODO checks
StatsInitializationReport initReport = (StatsInitializationReport) ss.getStaticInfo(sessionId,
StatsListener.TYPE_ID, workerIDs.get(0));
if (initReport == null)
return Results.ok(Json.toJson(Collections.EMPTY_MAP));
String[] paramNames = initReport.getModelParamNames();
//Infer layer names from param names...
Set<String> layerNameSet = new LinkedHashSet<>();
for (String s : paramNames) {
String[] split = s.split("_");
if (!layerNameSet.contains(split[0])) {
layerNameSet.add(split[0]);
}
}
List<String> layerNameList = new ArrayList<>(layerNameSet);
List<Persistable> list = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, workerIDs.get(0), 0);
Collections.sort(list, (a, b) -> Long.compare(a.getTimeStamp(), b.getTimeStamp()));
List<Double> scoreList = new ArrayList<>(list.size());
List<Map<String, List<Double>>> meanMagHistoryParams = new ArrayList<>(); //List.get(i) -> layer i. Maps: parameter for the given layer
List<Map<String, List<Double>>> meanMagHistoryUpdates = new ArrayList<>(); //List.get(i) -> layer i. Maps: updates for the given layer
for (int i = 0; i < layerNameList.size(); i++) {
meanMagHistoryParams.add(new HashMap<>());
meanMagHistoryUpdates.add(new HashMap<>());
}
StatsReport last = null;
for (Persistable p : list) {
if (!(p instanceof StatsReport)) {
log.debug("Encountered unexpected type: {}", p);
continue;
}
StatsReport sp = (StatsReport) p;
scoreList.add(sp.getScore());
//Mean magnitudes
if (sp.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)) {
updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Parameters), layerNameList,
meanMagHistoryParams);
}
if (sp.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)) {
updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Updates), layerNameList, meanMagHistoryUpdates);
}
last = sp;
}
Map<String, Map> newParams = getHistogram(last.getHistograms(StatsType.Parameters));
Map<String, Map> newGrad = getHistogram(last.getHistograms(StatsType.Updates));
double lastScore = (scoreList.size() == 0 ? 0.0 : scoreList.get(scoreList.size() - 1));
CompactModelAndGradient g = new CompactModelAndGradient();
g.setGradients(newGrad);
g.setParameters(newParams);
g.setScore(lastScore);
g.setScores(scoreList);
// g.setPath(subPath);
g.setUpdateMagnitudes(meanMagHistoryUpdates);
g.setParamMagnitudes(meanMagHistoryParams);
// g.setLayerNames(layerNames);
g.setLastUpdateTime(last.getTimeStamp());
return Results.ok(Json.toJson(g));
}
private void updateMeanMagnitudeMaps(Map<String, Double> current, List<String> layerNames,
List<Map<String, List<Double>>> history) {
for (Map.Entry<String, Double> entry : current.entrySet()) {
String key = entry.getKey();
String[] split = key.split("_");
int idx = layerNames.indexOf(split[0]);
Map<String, List<Double>> map = history.get(idx);
List<Double> l = map.get(key);
if (l == null) {
l = new ArrayList<>();
map.put(key, l);
}
l.add(entry.getValue());
}
}
private Map<String, Map> getHistogram(Map<String, org.deeplearning4j.ui.stats.api.Histogram> histograms) {
Map<String, Map> ret = new LinkedHashMap<>();
for (String s : histograms.keySet()) {
org.deeplearning4j.ui.stats.api.Histogram h = histograms.get(s);
String newName;
if (Character.isDigit(s.charAt(0)))
newName = "param_" + s;
else
newName = s;
Map<Number, Number> temp = new LinkedHashMap<>();
double min = h.getMin();
double max = h.getMax();
int n = h.getNBins();
double step = (max - min) / n;
int[] counts = h.getBinCounts();
for (int i = 0; i < n; i++) {
double binLoc = min + i * step + step / 2.0;
temp.put(binLoc, counts[i]);
}
ret.put(newName, temp);
}
return ret;
}
}