package org.deeplearning4j.ui.module.train;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.api.storage.StatsStorageListener;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.ui.api.*;
import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.stats.api.Histogram;
import org.deeplearning4j.ui.stats.api.StatsInitializationReport;
import org.deeplearning4j.ui.stats.api.StatsReport;
import org.deeplearning4j.ui.stats.api.StatsType;
import org.deeplearning4j.ui.views.html.training.TrainingHelp;
import org.deeplearning4j.ui.views.html.training.TrainingModel;
import org.deeplearning4j.ui.views.html.training.TrainingOverview;
import org.deeplearning4j.ui.views.html.training.TrainingSystem;
import org.nd4j.linalg.learning.config.IUpdater;
import play.libs.Json;
import play.mvc.Result;
import play.mvc.Results;
import java.text.DateFormat;
import java.text.DecimalFormat;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import static play.mvc.Results.ok;
import static play.mvc.Results.redirect;
/**
* Main DL4J Training UI
*
* @author Alex Black
*/
@Slf4j
public class TrainModule implements UIModule {
//Value substituted for NaN/infinite chart values: UI front-end chokes on NaN in JSON
public static final double NAN_REPLACEMENT_VALUE = 0.0; //UI front-end chokes on NaN in JSON
//Default cap on the number of points returned per chart series
public static final int DEFAULT_MAX_CHART_POINTS = 512;
//System property name that overrides DEFAULT_MAX_CHART_POINTS
public static final String CHART_MAX_POINTS_PROPERTY = "org.deeplearning4j.ui.maxChartPoints";
//Two-decimal formatter for the rates shown in the performance table
private static final DecimalFormat df2 = new DecimalFormat("#.00");
//NOTE(review): SimpleDateFormat is not thread-safe and this instance is shared - confirm single-threaded access
private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
//Kind of model being displayed: MultiLayerNetwork, ComputationGraph, or a single Layer
private enum ModelType {
    MLN, CG, Layer
};
private final int maxChartPoints; //Technically, the way it's set up: won't exceed 2*maxChartPoints
//Session ID -> the StatsStorage instance holding that session's data
private Map<String, StatsStorage> knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>());
//Session currently shown in the UI; null until one is known/selected
private String currentSessionID;
//UI-side index of the currently selected worker
private int currentWorkerIdx;
//Counter used to assign worker indexes as new workers are discovered
private Map<String, AtomicInteger> workerIdxCount = Collections.synchronizedMap(new HashMap<>()); //Key: session ID
//Worker index -> worker ID mapping, per session
private Map<String, Map<Integer, String>> workerIdxToName = Collections.synchronizedMap(new HashMap<>()); //Key: session ID
//Session ID -> timestamp of the most recent update event seen for that session
private Map<String, Long> lastUpdateForSession = Collections.synchronizedMap(new HashMap<>());
/**
 * Creates the module, resolving the maximum chart point count from the
 * {@value #CHART_MAX_POINTS_PROPERTY} system property. Falls back to
 * {@link #DEFAULT_MAX_CHART_POINTS} when the property is unset, unparseable,
 * or below 10.
 */
public TrainModule() {
    int parsed = DEFAULT_MAX_CHART_POINTS;
    String prop = System.getProperty(CHART_MAX_POINTS_PROPERTY);
    if (prop != null) {
        try {
            parsed = Integer.parseInt(prop);
        } catch (NumberFormatException e) {
            log.warn("Invalid system property: {} = {}", CHART_MAX_POINTS_PROPERTY, prop);
        }
    }
    //Reject implausibly small values; keep the default instead
    this.maxChartPoints = (parsed >= 10 ? parsed : DEFAULT_MAX_CHART_POINTS);
}
/** Returns the single stats type ID this module consumes from storage. */
@Override
public List<String> getCallbackTypeIDs() {
    String typeId = StatsListener.TYPE_ID;
    return Collections.singletonList(typeId);
}
/**
 * Defines every HTTP route served by the training UI: page templates
 * (overview / model / system / help), JSON data endpoints backing the charts,
 * and session/worker selection endpoints. Route order is preserved from the
 * original definition.
 */
@Override
public List<Route> getRoutes() {
    List<Route> routes = new ArrayList<>();
    routes.add(new Route("/train", HttpMethod.GET, FunctionType.Supplier, () -> redirect("/train/overview")));
    //Overview page + data
    routes.add(new Route("/train/overview", HttpMethod.GET, FunctionType.Supplier,
            () -> ok(TrainingOverview.apply(I18NProvider.getInstance()))));
    routes.add(new Route("/train/overview/data", HttpMethod.GET, FunctionType.Supplier, this::getOverviewData));
    //Model page + graph/data
    routes.add(new Route("/train/model", HttpMethod.GET, FunctionType.Supplier,
            () -> ok(TrainingModel.apply(I18NProvider.getInstance()))));
    routes.add(new Route("/train/model/graph", HttpMethod.GET, FunctionType.Supplier, this::getModelGraph));
    routes.add(new Route("/train/model/data/:layerId", HttpMethod.GET, FunctionType.Function, this::getModelData));
    //System page + data
    routes.add(new Route("/train/system", HttpMethod.GET, FunctionType.Supplier,
            () -> ok(TrainingSystem.apply(I18NProvider.getInstance()))));
    routes.add(new Route("/train/system/data", HttpMethod.GET, FunctionType.Supplier, this::getSystemData));
    //Help page
    routes.add(new Route("/train/help", HttpMethod.GET, FunctionType.Supplier,
            () -> ok(TrainingHelp.apply(I18NProvider.getInstance()))));
    //Session management
    routes.add(new Route("/train/sessions/current", HttpMethod.GET, FunctionType.Supplier,
            () -> ok(currentSessionID == null ? "" : currentSessionID)));
    routes.add(new Route("/train/sessions/all", HttpMethod.GET, FunctionType.Supplier, this::listSessions));
    routes.add(new Route("/train/sessions/info", HttpMethod.GET, FunctionType.Supplier, this::sessionInfo));
    routes.add(new Route("/train/sessions/set/:to", HttpMethod.GET, FunctionType.Function, this::setSession));
    routes.add(new Route("/train/sessions/lastUpdate/:sessionId", HttpMethod.GET, FunctionType.Function,
            this::getLastUpdateForSession));
    //Worker selection
    routes.add(new Route("/train/workers/currentByIdx", HttpMethod.GET, FunctionType.Supplier,
            () -> ok(String.valueOf(currentWorkerIdx))));
    routes.add(new Route("/train/workers/setByIdx/:to", HttpMethod.GET, FunctionType.Function,
            this::setWorkerByIdx));
    return routes;
}
/**
 * Processes queued storage events: registers the session ID of any static-info
 * post, and tracks the most recent update timestamp per session. Selects a
 * default session if none is currently set.
 *
 * Cleanup: the original re-checked {@code StatsListener.TYPE_ID.equals(sse.getTypeID())}
 * inside a branch that was only reachable when that check had already passed;
 * the redundant check is removed. Behavior is unchanged.
 *
 * @param events storage events to process; events with other type IDs are ignored
 */
@Override
public synchronized void reportStorageEvents(Collection<StatsStorageEvent> events) {
    for (StatsStorageEvent sse : events) {
        if (!StatsListener.TYPE_ID.equals(sse.getTypeID())) {
            continue;
        }
        if (sse.getEventType() == StatsStorageListener.EventType.PostStaticInfo) {
            knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
        }
        //Track the most recent update time per session. Should be thread safe - read only elsewhere
        Long lastUpdate = lastUpdateForSession.get(sse.getSessionID());
        if (lastUpdate == null || sse.getTimestamp() > lastUpdate) {
            lastUpdateForSession.put(sse.getSessionID(), sse.getTimestamp());
        }
    }
    if (currentSessionID == null)
        getDefaultSession();
}
/**
 * Called when a StatsStorage instance is attached to the UI: registers every
 * session in that storage containing StatsListener-typed data, then selects a
 * default session if none is chosen yet.
 */
@Override
public synchronized void onAttach(StatsStorage statsStorage) {
    for (String sid : statsStorage.listSessionIDs()) {
        for (String tid : statsStorage.listTypeIDsForSession(sid)) {
            if (StatsListener.TYPE_ID.equals(tid)) {
                knownSessionIDs.put(sid, statsStorage);
            }
        }
    }
    if (currentSessionID == null) {
        getDefaultSession();
    }
}
/**
 * Called when a StatsStorage instance is detached: forgets every session that
 * was backed by that storage instance (identity comparison).
 *
 * Bug fix: the original removed entries from {@code knownSessionIDs} while
 * iterating its key set directly, which throws ConcurrentModificationException
 * as soon as a matching entry is found. Removal now goes through
 * {@code removeIf} (iterator-based), and iteration over the synchronized map's
 * view is guarded by synchronizing on the map, as required by
 * {@link Collections#synchronizedMap}.
 */
@Override
public void onDetach(StatsStorage statsStorage) {
    synchronized (knownSessionIDs) {
        knownSessionIDs.values().removeIf(ss -> ss == statsStorage);
    }
}
/**
 * Selects the session whose first static-info record has the most recent
 * timestamp as the current session. No-op when a session is already selected
 * or when no session has static info available.
 */
private void getDefaultSession() {
    if (currentSessionID != null) {
        return;
    }
    String best = null;
    long bestTime = Long.MIN_VALUE;
    for (Map.Entry<String, StatsStorage> e : knownSessionIDs.entrySet()) {
        List<Persistable> staticInfos = e.getValue().getAllStaticInfos(e.getKey(), StatsListener.TYPE_ID);
        if (staticInfos == null || staticInfos.isEmpty()) {
            continue;
        }
        long t = staticInfos.get(0).getTimeStamp();
        if (t > bestTime) {
            bestTime = t;
            best = e.getKey();
        }
    }
    if (best != null) {
        currentSessionID = best;
    }
}
/**
 * Resolves a UI worker index to the underlying worker ID for the current
 * session, assigning indexes (in sorted worker-ID order) to any workers not
 * yet seen.
 *
 * Fix: guards against {@code knownSessionIDs.get(sid)} returning null (the
 * storage may have been detached since the session was selected), which
 * previously caused a NullPointerException.
 *
 * @param workerIdx index selected in the UI
 * @return the worker ID, or null if no session is selected, the session's
 *         storage is no longer attached, or the index is out of range
 */
private synchronized String getWorkerIdForIndex(int workerIdx) {
    String sid = currentSessionID;
    if (sid == null)
        return null;
    Map<Integer, String> idxToId = workerIdxToName.get(sid);
    if (idxToId == null) {
        idxToId = Collections.synchronizedMap(new HashMap<>());
        workerIdxToName.put(sid, idxToId);
    }
    if (idxToId.containsKey(workerIdx)) {
        return idxToId.get(workerIdx);
    }
    //Need to record new worker(s)...
    //Get the per-session counter used to assign new indexes
    AtomicInteger counter = workerIdxCount.get(sid);
    if (counter == null) {
        counter = new AtomicInteger(0);
        workerIdxCount.put(sid, counter);
    }
    //Get all worker IDs for the session
    StatsStorage ss = knownSessionIDs.get(sid);
    if (ss == null) {
        //Storage detached since session selection: nothing to query
        return null;
    }
    List<String> allWorkerIds = new ArrayList<>(ss.listWorkerIDsForSessionAndType(sid, StatsListener.TYPE_ID));
    Collections.sort(allWorkerIds);
    //Ensure all workers have been assigned an index
    for (String s : allWorkerIds) {
        if (idxToId.containsValue(s))
            continue;
        //Unknown worker ID: assign the next free index
        idxToId.put(counter.getAndIncrement(), s);
    }
    //May still return null if index is wrong/too high...
    return idxToId.get(workerIdx);
}
/** @return all known session IDs as a JSON array */
private Result listSessions() {
    Set<String> sessionIds = knownSessionIDs.keySet();
    return Results.ok(Json.toJson(sessionIds));
}
/**
 * Builds a JSON summary for every known session: worker count, init time,
 * last update time, worker IDs, and basic model info (type, layer count,
 * parameter count) from the first static-info record.
 *
 * Fix: {@code getLatestUpdateAllWorkers} is now null-checked before iteration,
 * consistent with the null check the same call receives in getSystemData();
 * previously a null return caused a NullPointerException here.
 */
private Result sessionInfo() {
    //Display, for each session: session ID, start time, number of workers, last update
    Map<String, Object> dataEachSession = new HashMap<>();
    for (Map.Entry<String, StatsStorage> entry : knownSessionIDs.entrySet()) {
        Map<String, Object> dataThisSession = new HashMap<>();
        String sid = entry.getKey();
        StatsStorage ss = entry.getValue();
        List<String> workerIDs = ss.listWorkerIDsForSessionAndType(sid, StatsListener.TYPE_ID);
        int workerCount = (workerIDs == null ? 0 : workerIDs.size());
        List<Persistable> staticInfo = ss.getAllStaticInfos(sid, StatsListener.TYPE_ID);
        //Init time: earliest static-info timestamp for the session
        long initTime = Long.MAX_VALUE;
        if (staticInfo != null) {
            for (Persistable p : staticInfo) {
                initTime = Math.min(p.getTimeStamp(), initTime);
            }
        }
        //Last update: latest update timestamp across all workers
        long lastUpdateTime = Long.MIN_VALUE;
        List<Persistable> lastUpdatesAllWorkers = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
        if (lastUpdatesAllWorkers != null) {
            for (Persistable p : lastUpdatesAllWorkers) {
                lastUpdateTime = Math.max(lastUpdateTime, p.getTimeStamp());
            }
        }
        dataThisSession.put("numWorkers", workerCount);
        dataThisSession.put("initTime", initTime == Long.MAX_VALUE ? "" : initTime);
        dataThisSession.put("lastUpdate", lastUpdateTime == Long.MIN_VALUE ? "" : lastUpdateTime);
        // add hashmap of workers
        if (workerCount > 0) {
            dataThisSession.put("workers", workerIDs);
        }
        //Model info: type, # layers, # params...
        if (staticInfo != null && staticInfo.size() > 0) {
            StatsInitializationReport sr = (StatsInitializationReport) staticInfo.get(0);
            String modelClassName = sr.getModelClassName();
            if (modelClassName.endsWith("MultiLayerNetwork")) {
                modelClassName = "MultiLayerNetwork";
            } else if (modelClassName.endsWith("ComputationGraph")) {
                modelClassName = "ComputationGraph";
            }
            int numLayers = sr.getModelNumLayers();
            long numParams = sr.getModelNumParams();
            dataThisSession.put("modelType", modelClassName);
            dataThisSession.put("numLayers", numLayers);
            dataThisSession.put("numParams", numParams);
        } else {
            dataThisSession.put("modelType", "");
            dataThisSession.put("numLayers", "");
            dataThisSession.put("numParams", "");
        }
        dataEachSession.put(sid, dataThisSession);
    }
    return ok(Json.toJson(dataEachSession));
}
/**
 * Switches the UI to the given session and resets the worker index to 0.
 *
 * @param newSessionID session to switch to
 * @return 200 on success; 400 if the session ID is unknown
 */
private Result setSession(String newSessionID) {
    if (!knownSessionIDs.containsKey(newSessionID)) {
        return Results.badRequest("Unknown session ID: " + newSessionID);
    }
    currentSessionID = newSessionID;
    currentWorkerIdx = 0;
    return ok();
}
/**
 * Returns the last update timestamp recorded for the given session as plain
 * text, or "-1" when the session has no recorded updates.
 */
private Result getLastUpdateForSession(String sessionID) {
    Long timestamp = lastUpdateForSession.get(sessionID);
    return ok(timestamp == null ? "-1" : String.valueOf(timestamp));
}
/**
 * Sets the currently displayed worker by UI index. Non-numeric input is
 * logged at debug level and otherwise ignored; the response is always 200.
 *
 * Fix: corrected typo in the log message ("Invaild" -> "Invalid").
 */
private Result setWorkerByIdx(String newWorkerIdx) {
    try {
        currentWorkerIdx = Integer.parseInt(newWorkerIdx);
    } catch (NumberFormatException e) {
        log.debug("Invalid call to setWorkerByIdx", e);
    }
    return ok();
}
/**
 * Replaces NaN/infinite values with {@link #NAN_REPLACEMENT_VALUE}, since the
 * UI front-end cannot render NaN in JSON.
 */
private static double fixNaN(double d) {
    if (Double.isFinite(d)) {
        return d;
    }
    return NAN_REPLACEMENT_VALUE;
}
/**
 * Rewrites legacy (Spark-era) iteration counts in place. Older Spark training
 * could reset iteration counts mid-run (e.g. 0,1,2,0,1,2 or 4,8,4,8), so the
 * raw values are replaced with the monotonic sequence first + i * step, where
 * step is the largest backwards jump observed between consecutive values
 * (or 1 when all values are equal).
 *
 * @param iterationCounts list modified in place; empty lists are left untouched
 */
private static void cleanLegacyIterationCounts(List<Integer> iterationCounts) {
    int n = iterationCounts.size();
    if (n == 0) {
        return;
    }
    int first = iterationCounts.get(0);
    int step = 1;
    boolean allSame = true;
    int prev = first;
    for (int i = 1; i < n; i++) {
        int curr = iterationCounts.get(i);
        if (curr != first) {
            allSame = false;
        }
        //A positive (prev - curr) means the counter jumped backwards: a legacy reset
        step = Math.max(step, prev - curr);
        prev = curr;
    }
    if (allSame) {
        step = 1;
    }
    for (int i = 0; i < n; i++) {
        iterationCounts.set(i, first + i * step);
    }
}
/**
 * Builds the JSON payload backing the overview page: score vs. iteration,
 * update:parameter mean-magnitude ratio series, stdev series (activations,
 * gradients, updates), and the performance and model summary tables.
 * Chart series are subsampled so roughly maxChartPoints points are returned.
 */
private Result getOverviewData() {
    Long lastUpdate = lastUpdateForSession.get(currentSessionID);
    if (lastUpdate == null)
        lastUpdate = -1L;
    I18N i18N = I18NProvider.getInstance();
    boolean noData = currentSessionID == null;
    //First pass (optimize later): query all data...
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    String wid = getWorkerIdForIndex(currentWorkerIdx);
    if (wid == null) {
        noData = true;
    }
    List<Integer> scoresIterCount = new ArrayList<>();
    List<Double> scores = new ArrayList<>();
    //Result containers are registered up-front and filled in below
    Map<String, Object> result = new HashMap<>();
    result.put("updateTimestamp", lastUpdate);
    result.put("scores", scores);
    result.put("scoresIter", scoresIterCount);
    //Get scores info
    List<Persistable> updates =
            (noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
    if (updates == null || updates.size() == 0) {
        noData = true;
    }
    //Collect update ratios for weights
    //Collect standard deviations: activations, gradients, updates
    Map<String, List<Double>> updateRatios = new HashMap<>(); //Mean magnitude (updates) / mean magnitude (parameters)
    result.put("updateRatios", updateRatios);
    Map<String, List<Double>> stdevActivations = new HashMap<>();
    Map<String, List<Double>> stdevGradients = new HashMap<>();
    Map<String, List<Double>> stdevUpdates = new HashMap<>();
    result.put("stdevActivations", stdevActivations);
    result.put("stdevGradients", stdevGradients);
    result.put("stdevUpdates", stdevUpdates);
    //Seed the per-parameter series keys from the first update, so every series exists before accumulation
    if (!noData) {
        Persistable u = updates.get(0);
        if (u instanceof StatsReport) {
            StatsReport sp = (StatsReport) u;
            Map<String, Double> map = sp.getMeanMagnitudes(StatsType.Parameters);
            if (map != null) {
                for (String s : map.keySet()) {
                    if (!s.toLowerCase().endsWith("w"))
                        continue; //TODO: more robust "weights only" approach...
                    updateRatios.put(s, new ArrayList<>());
                }
            }
            Map<String, Double> stdGrad = sp.getStdev(StatsType.Gradients);
            if (stdGrad != null) {
                for (String s : stdGrad.keySet()) {
                    if (!s.toLowerCase().endsWith("w"))
                        continue; //TODO: more robust "weights only" approach...
                    stdevGradients.put(s, new ArrayList<>());
                }
            }
            Map<String, Double> stdUpdate = sp.getStdev(StatsType.Updates);
            if (stdUpdate != null) {
                for (String s : stdUpdate.keySet()) {
                    if (!s.toLowerCase().endsWith("w"))
                        continue; //TODO: more robust "weights only" approach...
                    stdevUpdates.put(s, new ArrayList<>());
                }
            }
            Map<String, Double> stdAct = sp.getStdev(StatsType.Activations);
            if (stdAct != null) {
                for (String s : stdAct.keySet()) {
                    stdevActivations.put(s, new ArrayList<>());
                }
            }
        }
    }
    StatsReport last = null;
    int lastIterCount = -1;
    //Legacy issue - Spark training - iteration counts are used to be reset... which means: could go 0,1,2,0,1,2, etc...
    //Or, it could equally go 4,8,4,8,... or 5,5,5,5 - depending on the collection and averaging frequencies
    //Now, it should use the proper iteration counts
    boolean needToHandleLegacyIterCounts = false;
    if (!noData) {
        double lastScore;
        int totalUpdates = updates.size();
        //Subsample so roughly maxChartPoints points are emitted per series
        int subsamplingFrequency = 1;
        if (totalUpdates > maxChartPoints) {
            subsamplingFrequency = totalUpdates / maxChartPoints;
        }
        int pCount = -1;
        int lastUpdateIdx = updates.size() - 1;
        for (Persistable u : updates) {
            pCount++;
            if (!(u instanceof StatsReport))
                continue;
            last = (StatsReport) u;
            int iterCount = last.getIterationCount();
            //Non-increasing counts indicate legacy (reset) iteration counting; repaired after the loop
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            lastIterCount = iterCount;
            if (pCount > 0 && subsamplingFrequency > 1 && pCount % subsamplingFrequency != 0) {
                //Skip this - subsample the data
                if (pCount != lastUpdateIdx)
                    continue; //Always keep the most recent value
            }
            scoresIterCount.add(iterCount);
            lastScore = last.getScore();
            if (Double.isFinite(lastScore)) {
                scores.add(lastScore);
            } else {
                scores.add(NAN_REPLACEMENT_VALUE);
            }
            //Update ratios: mean magnitudes(updates) / mean magnitudes (parameters)
            Map<String, Double> updateMM = last.getMeanMagnitudes(StatsType.Updates);
            Map<String, Double> paramMM = last.getMeanMagnitudes(StatsType.Parameters);
            if (updateMM != null && paramMM != null && updateMM.size() > 0 && paramMM.size() > 0) {
                for (String s : updateRatios.keySet()) {
                    List<Double> ratioHistory = updateRatios.get(s);
                    double currUpdate = updateMM.getOrDefault(s, 0.0);
                    double currParam = paramMM.getOrDefault(s, 0.0);
                    double ratio = currUpdate / currParam;
                    if (Double.isFinite(ratio)) {
                        ratioHistory.add(ratio);
                    } else {
                        ratioHistory.add(NAN_REPLACEMENT_VALUE);
                    }
                }
            }
            //Standard deviations: gradients, updates, activations
            Map<String, Double> stdGrad = last.getStdev(StatsType.Gradients);
            Map<String, Double> stdUpd = last.getStdev(StatsType.Updates);
            Map<String, Double> stdAct = last.getStdev(StatsType.Activations);
            if (stdGrad != null) {
                for (String s : stdevGradients.keySet()) {
                    double d = stdGrad.getOrDefault(s, 0.0);
                    stdevGradients.get(s).add(fixNaN(d));
                }
            }
            if (stdUpd != null) {
                for (String s : stdevUpdates.keySet()) {
                    double d = stdUpd.getOrDefault(s, 0.0);
                    stdevUpdates.get(s).add(fixNaN(d));
                }
            }
            if (stdAct != null) {
                for (String s : stdevActivations.keySet()) {
                    double d = stdAct.getOrDefault(s, 0.0);
                    stdevActivations.get(s).add(fixNaN(d));
                }
            }
        }
    }
    if (needToHandleLegacyIterCounts) {
        cleanLegacyIterationCounts(scoresIterCount);
    }
    //----- Performance Info -----
    //NOTE(review): rows 0 (startTime) and 1 (totalRuntime) are never populated below - confirm whether intended
    String[][] perfInfo = new String[][] {{i18N.getMessage("train.overview.perftable.startTime"), ""},
            {i18N.getMessage("train.overview.perftable.totalRuntime"), ""},
            {i18N.getMessage("train.overview.perftable.lastUpdate"), ""},
            {i18N.getMessage("train.overview.perftable.totalParamUpdates"), ""},
            {i18N.getMessage("train.overview.perftable.updatesPerSec"), ""},
            {i18N.getMessage("train.overview.perftable.examplesPerSec"), ""}};
    if (last != null) {
        perfInfo[2][1] = String.valueOf(dateFormat.format(new Date(last.getTimeStamp())));
        perfInfo[3][1] = String.valueOf(last.getTotalMinibatches());
        perfInfo[4][1] = String.valueOf(df2.format(last.getMinibatchesPerSecond()));
        perfInfo[5][1] = String.valueOf(df2.format(last.getExamplesPerSecond()));
    }
    result.put("perf", perfInfo);
    // ----- Model Info -----
    String[][] modelInfo = new String[][] {{i18N.getMessage("train.overview.modeltable.modeltype"), ""},
            {i18N.getMessage("train.overview.modeltable.nLayers"), ""},
            {i18N.getMessage("train.overview.modeltable.nParams"), ""}};
    if (!noData) {
        Persistable p = ss.getStaticInfo(currentSessionID, StatsListener.TYPE_ID, wid);
        if (p != null) {
            StatsInitializationReport initReport = (StatsInitializationReport) p;
            int nLayers = initReport.getModelNumLayers();
            long numParams = initReport.getModelNumParams();
            String className = initReport.getModelClassName();
            //Short display name: well-known classes by simple name, otherwise strip the package prefix
            String modelType;
            if (className.endsWith("MultiLayerNetwork")) {
                modelType = "MultiLayerNetwork";
            } else if (className.endsWith("ComputationGraph")) {
                modelType = "ComputationGraph";
            } else {
                modelType = className;
                if (modelType.lastIndexOf('.') > 0) {
                    modelType = modelType.substring(modelType.lastIndexOf('.') + 1);
                }
            }
            modelInfo[0][1] = modelType;
            modelInfo[1][1] = String.valueOf(nLayers);
            modelInfo[2][1] = String.valueOf(numParams);
        }
    }
    result.put("model", modelInfo);
    return Results.ok(Json.toJson(result));
}
/**
 * Returns the model graph structure (as JSON) for the current session, or an
 * empty 200 response when no session, static info, or parseable graph is
 * available.
 */
private Result getModelGraph() {
    boolean noData = currentSessionID == null;
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    List<Persistable> allStatic = (noData ? Collections.EMPTY_LIST
            : ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID));
    if (allStatic.isEmpty()) {
        return ok();
    }
    TrainModuleUtils.GraphInfo graphInfo = getGraphInfo();
    return (graphInfo == null ? ok() : ok(Json.toJson(graphInfo)));
}
/**
 * Builds graph info from whichever configuration type is present in the
 * current session: MultiLayerConfiguration first, then
 * ComputationGraphConfiguration, then single-layer NeuralNetConfiguration.
 *
 * @return graph info, or null when no configuration is available
 */
private TrainModuleUtils.GraphInfo getGraphInfo() {
    Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> conf = getConfig();
    if (conf == null) {
        return null;
    }
    MultiLayerConfiguration mlc = conf.getFirst();
    if (mlc != null) {
        return TrainModuleUtils.buildGraphInfo(mlc);
    }
    ComputationGraphConfiguration cgc = conf.getSecond();
    if (cgc != null) {
        return TrainModuleUtils.buildGraphInfo(cgc);
    }
    NeuralNetConfiguration nnc = conf.getThird();
    if (nnc != null) {
        return TrainModuleUtils.buildGraphInfo(nnc);
    }
    return null;
}
/**
 * Deserializes the model configuration from the current session's first
 * static-info record. Exactly one element of the returned triple is non-null,
 * indicating the model type: (mlc, null, null) for a MultiLayerNetwork,
 * (null, cgc, null) for a ComputationGraph, (null, null, nnc) for a single
 * layer.
 *
 * Fix: a deserialization failure is now logged via the class slf4j logger
 * (with full stack trace) instead of {@code e.printStackTrace()}.
 *
 * @return the parsed configuration triple, or null when there is no session,
 *         no static info, or the configuration JSON could not be parsed
 */
private Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> getConfig() {
    boolean noData = currentSessionID == null;
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    List<Persistable> allStatic = (noData ? Collections.EMPTY_LIST
            : ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID));
    if (allStatic.size() == 0)
        return null;
    StatsInitializationReport p = (StatsInitializationReport) allStatic.get(0);
    String modelClass = p.getModelClassName();
    String config = p.getModelConfigJson();
    if (modelClass.endsWith("MultiLayerNetwork")) {
        MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(config);
        return new Triple<>(conf, null, null);
    } else if (modelClass.endsWith("ComputationGraph")) {
        ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(config);
        return new Triple<>(null, conf, null);
    } else {
        try {
            NeuralNetConfiguration layer =
                    NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class);
            return new Triple<>(null, null, layer);
        } catch (Exception e) {
            log.error("Error deserializing NeuralNetConfiguration from stored model config JSON", e);
        }
    }
    return null;
}
/**
 * Builds the JSON payload for a single layer on the model page: the static
 * layer-info table, mean-magnitude ratio chart, activations chart, learning
 * rates, and parameter/update histograms. The :layerId path segment is the
 * vertex index into the graph info.
 *
 * @param str layer/vertex index as a decimal string (not yet validated)
 */
private Result getModelData(String str) {
    Long lastUpdateTime = lastUpdateForSession.get(currentSessionID);
    if (lastUpdateTime == null)
        lastUpdateTime = -1L;
    int layerIdx = Integer.parseInt(str); //TODO validation
    I18N i18N = I18NProvider.getInstance();
    //Model info for layer
    boolean noData = currentSessionID == null;
    //First pass (optimize later): query all data...
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    String wid = getWorkerIdForIndex(currentWorkerIdx);
    if (wid == null) {
        noData = true;
    }
    Map<String, Object> result = new HashMap<>();
    result.put("updateTimestamp", lastUpdateTime);
    Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> conf = getConfig();
    if (conf == null) {
        //No configuration available: return just the timestamp
        return ok(Json.toJson(result));
    }
    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok(Json.toJson(result));
    }
    // Get static layer info
    String[][] layerInfoTable = getLayerInfoTable(layerIdx, gi, i18N, noData, ss, wid);
    result.put("layerInfo", layerInfoTable);
    //First: get all data, and subsample it if necessary, to avoid returning too many points...
    List<Persistable> updates =
            (noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
    List<Integer> iterationCounts = null;
    boolean needToHandleLegacyIterCounts = false;
    if (updates != null && updates.size() > maxChartPoints) {
        //Too many points: keep roughly every subsamplingFrequency-th report (plus the last one)
        int subsamplingFrequency = updates.size() / maxChartPoints;
        List<Persistable> subsampled = new ArrayList<>();
        iterationCounts = new ArrayList<>();
        int pCount = -1;
        int lastUpdateIdx = updates.size() - 1;
        int lastIterCount = -1;
        for (Persistable p : updates) {
            if (!(p instanceof StatsReport))
                continue;;
            StatsReport sr = (StatsReport) p;
            pCount++;
            int iterCount = sr.getIterationCount();
            //Non-increasing counts indicate legacy (reset) iteration counting; repaired below
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            lastIterCount = iterCount;
            if (pCount > 0 && subsamplingFrequency > 1 && pCount % subsamplingFrequency != 0) {
                //Skip this to subsample the data
                if (pCount != lastUpdateIdx)
                    continue; //Always keep the most recent value
            }
            subsampled.add(p);
            iterationCounts.add(iterCount);
        }
        updates = subsampled;
    } else if (updates != null) {
        //Few enough points: keep everything, just record the iteration counts
        int offset = 0; //NOTE(review): 'offset' appears unused - confirm it can be removed
        iterationCounts = new ArrayList<>(updates.size());
        //NOTE(review): lastIterCount is never updated inside this loop, so legacy detection
        //only triggers for negative iteration counts here - confirm whether intended
        int lastIterCount = -1;
        for (Persistable p : updates) {
            if (!(p instanceof StatsReport))
                continue;;
            StatsReport sr = (StatsReport) p;
            int iterCount = sr.getIterationCount();
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            iterationCounts.add(iterCount);
        }
    }
    //Legacy issue - Spark training - iteration counts are used to be reset... which means: could go 0,1,2,0,1,2, etc...
    //Or, it could equally go 4,8,4,8,... or 5,5,5,5 - depending on the collection and averaging frequencies
    //Now, it should use the proper iteration counts
    if (needToHandleLegacyIterCounts) {
        cleanLegacyIterationCounts(iterationCounts);
    }
    //Get mean magnitudes line chart
    ModelType mt;
    if (conf.getFirst() != null)
        mt = ModelType.MLN;
    else if (conf.getSecond() != null)
        mt = ModelType.CG;
    else
        mt = ModelType.Layer;
    MeanMagnitudes mm = getLayerMeanMagnitudes(layerIdx, gi, updates, iterationCounts, mt);
    Map<String, Object> mmRatioMap = new HashMap<>();
    mmRatioMap.put("layerParamNames", mm.getRatios().keySet());
    mmRatioMap.put("iterCounts", mm.getIterations());
    mmRatioMap.put("ratios", mm.getRatios());
    mmRatioMap.put("paramMM", mm.getParamMM());
    mmRatioMap.put("updateMM", mm.getUpdateMM());
    result.put("meanMag", mmRatioMap);
    //Get activations line chart for layer
    Triple<int[], float[], float[]> activationsData = getLayerActivations(layerIdx, gi, updates, iterationCounts);
    Map<String, Object> activationMap = new HashMap<>();
    activationMap.put("iterCount", activationsData.getFirst());
    activationMap.put("mean", activationsData.getSecond());
    activationMap.put("stdev", activationsData.getThird());
    result.put("activations", activationMap);
    //Get learning rate vs. time chart for layer
    Map<String, Object> lrs = getLayerLearningRates(layerIdx, gi, updates, iterationCounts, mt);
    result.put("learningRates", lrs);
    //Parameters histogram data (from the most recent update only)
    Persistable lastUpdate = (updates != null && updates.size() > 0 ? updates.get(updates.size() - 1) : null);
    Map<String, Object> paramHistograms = getHistograms(layerIdx, gi, StatsType.Parameters, lastUpdate);
    result.put("paramHist", paramHistograms);
    //Updates histogram data
    Map<String, Object> updateHistograms = getHistograms(layerIdx, gi, StatsType.Updates, lastUpdate);
    result.put("updateHist", updateHistograms);
    return ok(Json.toJson(result));
}
/**
 * Builds the JSON payload for the system page: memory usage over (roughly)
 * the last 5 minutes of updates, plus hardware and software info tables, for
 * the current session.
 */
public Result getSystemData() {
    Long lastUpdate = lastUpdateForSession.get(currentSessionID);
    if (lastUpdate == null)
        lastUpdate = -1L;
    I18N i18n = I18NProvider.getInstance();
    //First: get the MOST RECENT update...
    //Then get all updates from most recent - 5 minutes -> TODO make this configurable...
    boolean noData = currentSessionID == null;
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    List<Persistable> allStatic = (noData ? Collections.EMPTY_LIST
            : ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID));
    List<Persistable> latestUpdates = (noData ? Collections.EMPTY_LIST
            : ss.getLatestUpdateAllWorkers(currentSessionID, StatsListener.TYPE_ID));
    long lastUpdateTime = -1;
    if (latestUpdates == null || latestUpdates.size() == 0) {
        noData = true;
    } else {
        //Most recent update time across all workers
        for (Persistable p : latestUpdates) {
            lastUpdateTime = Math.max(lastUpdateTime, p.getTimeStamp());
        }
    }
    long fromTime = lastUpdateTime - 5 * 60 * 1000; //TODO Make configurable
    List<Persistable> lastNMinutes =
            (noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, fromTime));
    Map<String, Object> mem = getMemory(allStatic, lastNMinutes, i18n);
    Pair<Map<String, Object>, Map<String, Object>> hwSwInfo = getHardwareSoftwareInfo(allStatic, i18n);
    Map<String, Object> ret = new HashMap<>();
    ret.put("updateTimestamp", lastUpdate);
    ret.put("memory", mem);
    ret.put("hardware", hwSwInfo.getFirst());
    ret.put("software", hwSwInfo.getSecond());
    return ok(Json.toJson(ret));
}
/**
 * Human-readable layer type: the simple class name with any trailing "Layer"
 * suffix stripped, or "n/a" when the layer is null or the name cannot be
 * derived.
 */
private static String getLayerType(Layer layer) {
    if (layer == null) {
        return "n/a";
    }
    try {
        return layer.getClass().getSimpleName().replaceAll("Layer$", "");
    } catch (Exception ignored) {
        //Best effort only: fall through to the placeholder
        return "n/a";
    }
}
/**
 * Builds the static "layer information" table rows for the given vertex
 * index: name, type, and (where available from the stored configuration)
 * size, parameter count, weight init, updater, CNN kernel/stride/padding,
 * and activation function.
 *
 * @param layerIdx vertex index into the graph info
 * @param gi       graph structure for the current model
 * @param i18N     i18n provider for row labels
 * @param noData   true when no session/worker data is available
 * @param ss       storage for the current session (may be null when noData)
 * @param wid      worker ID (may be null when noData)
 * @return table rows as {label, value} string pairs
 */
private String[][] getLayerInfoTable(int layerIdx, TrainModuleUtils.GraphInfo gi, I18N i18N, boolean noData,
        StatsStorage ss, String wid) {
    List<String[]> layerInfoRows = new ArrayList<>();
    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerName"),
            gi.getVertexNames().get(layerIdx)});
    //Placeholder row for the layer type; filled in at the end (row index 1)
    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerType"), ""});
    if (!noData) {
        Persistable p = ss.getStaticInfo(currentSessionID, StatsListener.TYPE_ID, wid);
        if (p != null) {
            StatsInitializationReport initReport = (StatsInitializationReport) p;
            String configJson = initReport.getModelConfigJson();
            String modelClass = initReport.getModelClassName();
            //TODO error handling...
            String layerType = "";
            Layer layer = null;
            NeuralNetConfiguration nnc = null;
            if (modelClass.endsWith("MultiLayerNetwork")) {
                MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(configJson);
                int confIdx = layerIdx - 1; //-1 because of input
                if (confIdx >= 0) {
                    nnc = conf.getConf(confIdx);
                    layer = nnc.getLayer();
                } else {
                    //Input layer
                    layerType = "Input";
                }
            } else if (modelClass.endsWith("ComputationGraph")) {
                ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(configJson);
                String vertexName = gi.getVertexNames().get(layerIdx);
                Map<String, GraphVertex> vertices = conf.getVertices();
                if (vertices.containsKey(vertexName) && vertices.get(vertexName) instanceof LayerVertex) {
                    LayerVertex lv = (LayerVertex) vertices.get(vertexName);
                    nnc = lv.getLayerConf();
                    layer = nnc.getLayer();
                } else if (conf.getNetworkInputs().contains(vertexName)) {
                    layerType = "Input";
                } else {
                    //Non-layer vertex (e.g. merge/subset): report its class name as the type
                    GraphVertex gv = conf.getVertices().get(vertexName);
                    if (gv != null) {
                        layerType = gv.getClass().getSimpleName();
                    }
                }
            } else if (modelClass.endsWith("VariationalAutoencoder")) {
                //VAE case: type and extra rows come straight from the graph info
                layerType = gi.getVertexTypes().get(layerIdx);
                Map<String, String> map = gi.getVertexInfo().get(layerIdx);
                for (Map.Entry<String, String> entry : map.entrySet()) {
                    layerInfoRows.add(new String[] {entry.getKey(), entry.getValue()});
                }
            }
            if (layer != null) {
                layerType = getLayerType(layer);
            }
            if (layer != null) {
                String activationFn = null;
                if (layer instanceof FeedForwardLayer) {
                    FeedForwardLayer ffl = (FeedForwardLayer) layer;
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerNIn"),
                            String.valueOf(ffl.getNIn())});
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerSize"),
                            String.valueOf(ffl.getNOut())});
                    activationFn = layer.getActivationFn().toString();
                }
                int nParams = layer.initializer().numParams(nnc);
                layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerNParams"),
                        String.valueOf(nParams)});
                if (nParams > 0) {
                    WeightInit wi = layer.getWeightInit();
                    String str = wi.toString();
                    if (wi == WeightInit.DISTRIBUTION) {
                        //Distribution-based init: append the distribution details
                        str += layer.getDist();
                    }
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerWeightInit"),
                            str});
                    IUpdater u = layer.getIUpdater();
                    String us = (u == null ? "" : u.getClass().getSimpleName());
                    layerInfoRows.add(
                            new String[] {i18N.getMessage("train.model.layerinfotable.layerUpdater"), us});
                    //TODO: Maybe L1/L2, dropout, updater-specific values etc
                }
                if (layer instanceof ConvolutionLayer || layer instanceof SubsamplingLayer) {
                    int[] kernel;
                    int[] stride;
                    int[] padding;
                    if (layer instanceof ConvolutionLayer) {
                        ConvolutionLayer cl = (ConvolutionLayer) layer;
                        kernel = cl.getKernelSize();
                        stride = cl.getStride();
                        padding = cl.getPadding();
                    } else {
                        SubsamplingLayer ssl = (SubsamplingLayer) layer;
                        kernel = ssl.getKernelSize();
                        stride = ssl.getStride();
                        padding = ssl.getPadding();
                        //Subsampling layers get a pooling-type row instead of an activation row
                        activationFn = null;
                        layerInfoRows.add(new String[] {
                                i18N.getMessage("train.model.layerinfotable.layerSubsamplingPoolingType"),
                                ssl.getPoolingType().toString()});
                    }
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerCnnKernel"),
                            Arrays.toString(kernel)});
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerCnnStride"),
                            Arrays.toString(stride)});
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerCnnPadding"),
                            Arrays.toString(padding)});
                }
                if (activationFn != null) {
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerActivationFn"),
                            activationFn});
                }
            }
            //Fill in the type placeholder added at the top
            layerInfoRows.get(1)[1] = layerType;
        }
    }
    return layerInfoRows.toArray(new String[layerInfoRows.size()][0]);
}
//TODO float precision for smaller transfers?
//First: iteration. Second: ratios, by parameter
/**
 * Collects, for one layer, the per-iteration mean magnitudes of its parameters and updates,
 * plus the update/parameter ratio for each parameter.
 *
 * @param layerIdx        Index of the layer/vertex in the graph info
 * @param gi              Graph structure info; may be null (empty result returned)
 * @param updates         Stats reports to read from; may be null
 * @param iterationCounts Externally supplied iteration counts (parallel to updates), or null to
 *                        read the iteration count from each report
 * @param modelType       Determines the parameter-name prefix convention
 * @return MeanMagnitudes holding iteration counts and, per parameter name, the ratio,
 *         parameter and update mean-magnitude series
 */
private MeanMagnitudes getLayerMeanMagnitudes(int layerIdx, TrainModuleUtils.GraphInfo gi,
                List<Persistable> updates, List<Integer> iterationCounts, ModelType modelType) {
    if (gi == null) {
        return new MeanMagnitudes(Collections.emptyList(), Collections.emptyMap(), Collections.emptyMap(),
                        Collections.emptyMap());
    }

    String layerName = gi.getVertexNames().get(layerIdx);
    if (modelType != ModelType.CG) {
        //Get the original name, for the index...
        layerName = gi.getOriginalVertexName().get(layerIdx);
    }

    String layerType = gi.getVertexTypes().get(layerIdx);
    if ("input".equalsIgnoreCase(layerType)) { //TODO better checking - other vertices, etc
        return new MeanMagnitudes(Collections.emptyList(), Collections.emptyMap(), Collections.emptyMap(),
                        Collections.emptyMap());
    }

    List<Integer> iterCounts = new ArrayList<>();
    Map<String, List<Double>> ratioValues = new HashMap<>();
    Map<String, List<Double>> outParamMM = new HashMap<>();
    Map<String, List<Double>> outUpdateMM = new HashMap<>();

    if (updates != null) {
        //MLN/CG parameter keys are "<layerName>_<paramName>"; single-layer keys have no separator.
        //Prefix is loop-invariant, so compute it once rather than per key as before.
        String prefix = (modelType == ModelType.Layer ? layerName : layerName + "_");

        int pCount = -1;
        for (Persistable u : updates) {
            pCount++;
            if (!(u instanceof StatsReport))
                continue;
            StatsReport sp = (StatsReport) u;
            if (iterationCounts != null) {
                iterCounts.add(iterationCounts.get(pCount));
            } else {
                iterCounts.add(sp.getIterationCount());
            }

            //Info we want, for each parameter in this layer: mean magnitudes for parameters, updates AND the ratio of these
            Map<String, Double> paramMM = sp.getMeanMagnitudes(StatsType.Parameters);
            Map<String, Double> updateMM = sp.getMeanMagnitudes(StatsType.Updates);
            for (String s : paramMM.keySet()) {
                if (!s.startsWith(prefix))
                    continue;

                //Relevant parameter for this layer...
                String layerParam = s.substring(prefix.length());
                double pmm = paramMM.getOrDefault(s, 0.0);
                double umm = updateMM.getOrDefault(s, 0.0);
                //Replace non-finite values so the front-end charts don't break on NaN/Infinity
                if (!Double.isFinite(pmm)) {
                    pmm = NAN_REPLACEMENT_VALUE;
                }
                if (!Double.isFinite(umm)) {
                    umm = NAN_REPLACEMENT_VALUE;
                }
                double ratio;
                if (umm == 0.0 && pmm == 0.0) {
                    ratio = 0.0; //To avoid NaN from 0/0
                } else {
                    ratio = umm / pmm;
                }
                //computeIfAbsent replaces the previous get/null-check/put boilerplate
                ratioValues.computeIfAbsent(layerParam, k -> new ArrayList<>()).add(ratio);
                outParamMM.computeIfAbsent(layerParam, k -> new ArrayList<>()).add(pmm);
                outUpdateMM.computeIfAbsent(layerParam, k -> new ArrayList<>()).add(umm);
            }
        }
    }

    return new MeanMagnitudes(iterCounts, ratioValues, outParamMM, outUpdateMM);
}
private static Triple<int[], float[], float[]> EMPTY_TRIPLE = new Triple<>(new int[0], new float[0], new float[0]);
/**
 * Extracts per-iteration activation statistics (mean and standard deviation) for a single layer.
 *
 * @param index           Vertex index of the layer in the graph info
 * @param gi              Graph structure info; may be null (empty result returned)
 * @param updates         Stats reports to read activation stats from; may be null
 * @param iterationCounts Externally supplied iteration counts (parallel to updates), or null to
 *                        read the iteration count from each report
 * @return Triple of (iteration counts, means, stdevs), trimmed to the number of reports that
 *         actually contained activation data for this layer
 */
private Triple<int[], float[], float[]> getLayerActivations(int index, TrainModuleUtils.GraphInfo gi,
                List<Persistable> updates, List<Integer> iterationCounts) {
    if (gi == null) {
        return EMPTY_TRIPLE;
    }
    String type = gi.getVertexTypes().get(index); //Index may be for an input, for example
    if ("input".equalsIgnoreCase(type)) {
        return EMPTY_TRIPLE;
    }
    List<String> origNames = gi.getOriginalVertexName();
    if (index < 0 || index >= origNames.size()) {
        return EMPTY_TRIPLE;
    }
    String layerName = origNames.get(index);

    int size = (updates == null ? 0 : updates.size());
    int[] iterCounts = new int[size];
    float[] mean = new float[size];
    float[] stdev = new float[size];
    int used = 0;
    if (updates != null) {
        int uCount = -1;
        for (Persistable u : updates) {
            uCount++;
            if (!(u instanceof StatsReport))
                continue;
            StatsReport sp = (StatsReport) u;
            //Written at index 'used'; simply overwritten by the next report if this one is skipped below
            if (iterationCounts == null) {
                iterCounts[used] = sp.getIterationCount();
            } else {
                iterCounts[used] = iterationCounts.get(uCount);
            }

            Map<String, Double> means = sp.getMean(StatsType.Activations);
            Map<String, Double> stdevs = sp.getStdev(StatsType.Activations);

            //TODO PROPER VALIDATION ETC, ERROR HANDLING
            //Fix: also require a matching stdev entry - previously a report with a mean but a
            //null/incomplete stdev map caused a NullPointerException
            if (means != null && means.containsKey(layerName) && stdevs != null && stdevs.containsKey(layerName)) {
                mean[used] = means.get(layerName).floatValue();
                stdev[used] = stdevs.get(layerName).floatValue();
                //Replace non-finite values so the front-end charts don't break on NaN/Infinity
                if (!Float.isFinite(mean[used])) {
                    mean[used] = (float) NAN_REPLACEMENT_VALUE;
                }
                if (!Float.isFinite(stdev[used])) {
                    stdev[used] = (float) NAN_REPLACEMENT_VALUE;
                }
                used++;
            }
        }
    }

    //Trim the arrays if any reports were skipped
    if (used != iterCounts.length) {
        iterCounts = Arrays.copyOf(iterCounts, used);
        mean = Arrays.copyOf(mean, used);
        stdev = Arrays.copyOf(stdev, used);
    }

    return new Triple<>(iterCounts, mean, stdev);
}
//Shared empty result for getLayerLearningRates (input vertices / out-of-range layer indices).
//Typed Collections.emptyList()/emptyMap() replace the raw EMPTY_LIST/EMPTY_MAP constants.
private static final Map<String, Object> EMPTY_LR_MAP = new HashMap<>();
static {
    EMPTY_LR_MAP.put("iterCounts", new int[0]);
    EMPTY_LR_MAP.put("paramNames", Collections.emptyList());
    EMPTY_LR_MAP.put("lrs", Collections.emptyMap());
}
/**
 * Extracts per-iteration learning rates for every parameter of a single layer.
 *
 * @param layerIdx        Index of the layer/vertex in the graph info
 * @param gi              Graph structure info; may be null (empty map returned)
 * @param updates         Stats reports to read learning rates from; may be null
 * @param iterationCounts Externally supplied iteration counts (parallel to updates), or null to
 *                        read the iteration count from each report
 * @param modelType       Determines the parameter-name prefix convention
 * @return Map with keys "iterCounts" (int[]), "paramNames" (sorted List of String) and
 *         "lrs" (Map of parameter name to float[] learning rates)
 */
private Map<String, Object> getLayerLearningRates(int layerIdx, TrainModuleUtils.GraphInfo gi,
                List<Persistable> updates, List<Integer> iterationCounts, ModelType modelType) {
    if (gi == null) {
        return Collections.emptyMap();
    }
    List<String> origNames = gi.getOriginalVertexName();
    String type = gi.getVertexTypes().get(layerIdx); //Index may be for an input, for example
    if ("input".equalsIgnoreCase(type)) {
        return EMPTY_LR_MAP;
    }
    if (layerIdx < 0 || layerIdx >= origNames.size()) {
        return EMPTY_LR_MAP;
    }
    String layerName = origNames.get(layerIdx);

    int size = (updates == null ? 0 : updates.size());
    int[] iterCounts = new int[size];
    Map<String, float[]> byName = new HashMap<>();
    int used = 0;
    if (updates != null) {
        int uCount = -1;
        for (Persistable u : updates) {
            uCount++;
            if (!(u instanceof StatsReport))
                continue;
            StatsReport sp = (StatsReport) u;
            if (iterationCounts == null) {
                iterCounts[used] = sp.getIterationCount();
            } else {
                iterCounts[used] = iterationCounts.get(uCount);
            }

            //TODO PROPER VALIDATION ETC, ERROR HANDLING
            Map<String, Double> lrs = sp.getLearningRates();

            //MLN/CG parameter keys are "<layerName>_<paramName>"; single-layer keys have no separator
            String prefix;
            if (modelType == ModelType.Layer) {
                prefix = layerName;
            } else {
                prefix = layerName + "_";
            }

            for (String p : lrs.keySet()) {
                if (p.startsWith(prefix)) {
                    String layerParamName = p.substring(Math.min(p.length(), prefix.length()));
                    float[] lrThisParam = byName.computeIfAbsent(layerParamName, k -> new float[size]);
                    lrThisParam[used] = lrs.get(p).floatValue();
                }
            }
            used++;
        }
    }

    //Fix: trim trailing unused slots when some persistables were not StatsReports. Previously the
    //full-size, zero-padded arrays were returned, unlike getLayerActivations which trims its output
    if (used != size) {
        iterCounts = Arrays.copyOf(iterCounts, used);
        for (Map.Entry<String, float[]> e : byName.entrySet()) {
            e.setValue(Arrays.copyOf(e.getValue(), used));
        }
    }

    List<String> paramNames = new ArrayList<>(byName.keySet());
    Collections.sort(paramNames); //Sorted for consistency

    Map<String, Object> ret = new HashMap<>();
    ret.put("iterCounts", iterCounts);
    ret.put("paramNames", paramNames);
    ret.put("lrs", byName);
    return ret;
}
/**
 * Builds histogram data (min, max, number of bins, bin counts) for each parameter of one layer.
 *
 * @param layerIdx  Vertex index of the layer in the graph info
 * @param gi        Graph structure info
 * @param statsType Which histograms to extract (parameters, updates, activations)
 * @param p         The stats report to read from; may be null
 * @return Map from parameter name to histogram details, plus a "paramNames" list entry;
 *         null if p is not a usable stats report
 */
private static Map<String, Object> getHistograms(int layerIdx, TrainModuleUtils.GraphInfo gi, StatsType statsType,
                Persistable p) {
    //instanceof is false for null, so this also covers the p == null case
    if (!(p instanceof StatsReport))
        return null;
    StatsReport sr = (StatsReport) p;

    String layerName = gi.getOriginalVertexName().get(layerIdx);

    Map<String, Histogram> map = sr.getHistograms(statsType);
    List<String> paramNames = new ArrayList<>();

    Map<String, Object> ret = new HashMap<>();
    //Fix: also guard against a null histogram map, which previously caused a NullPointerException
    if (layerName != null && map != null) {
        for (String s : map.keySet()) {
            if (s.startsWith(layerName)) {
                String paramName;
                //Fix: length check prevents StringIndexOutOfBoundsException when the key
                //equals the layer name exactly (no parameter suffix)
                if (s.length() > layerName.length() && s.charAt(layerName.length()) == '_') {
                    //MLN or CG parameter naming convention
                    paramName = s.substring(layerName.length() + 1);
                } else {
                    //Pretrain layer (VAE, RBM) naming convention
                    paramName = s.substring(layerName.length());
                }
                paramNames.add(paramName);
                Histogram h = map.get(s);
                Map<String, Object> thisHist = new HashMap<>();
                double min = h.getMin();
                double max = h.getMax();
                if (Double.isNaN(min)) {
                    //If either is NaN, both will be
                    min = NAN_REPLACEMENT_VALUE;
                    max = NAN_REPLACEMENT_VALUE;
                }
                thisHist.put("min", min);
                thisHist.put("max", max);
                thisHist.put("bins", h.getNBins());
                thisHist.put("counts", h.getBinCounts());
                ret.put(paramName, thisHist);
            }
        }
    }
    ret.put("paramNames", paramNames);

    return ret;
}
/**
 * Assembles memory-utilization chart data, one entry per unique JVM.
 * Workers are first grouped by JVM UID; for each JVM the first worker (by sorted worker ID) is
 * used as the representative, and its JVM heap, off-heap and (if present) per-device memory
 * utilization over time are collected from the recent stats reports.
 *
 * @param staticInfoAllWorkers Static initialization reports for all workers
 * @param updatesLastNMinutes  Recent stats reports to chart
 * @param i18n                 Internationalization provider for series labels
 * @return Map keyed by JVM index ("0", "1", ...); each value holds timestamps, series names,
 *         utilization fractions, an isDevice flag per series, and the most recent current/max
 *         byte counts
 */
private static Map<String, Object> getMemory(List<Persistable> staticInfoAllWorkers,
List<Persistable> updatesLastNMinutes, I18N i18n) {
Map<String, Object> ret = new HashMap<>();
//First: map workers to JVMs
Set<String> jvmIDs = new HashSet<>();
Map<String, String> workersToJvms = new HashMap<>();
Map<String, Integer> workerNumDevices = new HashMap<>();
Map<String, String[]> deviceNames = new HashMap<>();
for (Persistable p : staticInfoAllWorkers) {
//TODO validation/checks
StatsInitializationReport init = (StatsInitializationReport) p;
String jvmuid = init.getSwJvmUID();
workersToJvms.put(p.getWorkerID(), jvmuid);
jvmIDs.add(jvmuid);
int nDevices = init.getHwNumDevices();
workerNumDevices.put(p.getWorkerID(), nDevices);
if (nDevices > 0) {
String[] deviceNamesArr = init.getHwDeviceDescription();
deviceNames.put(p.getWorkerID(), deviceNamesArr);
}
}
//Sort JVM UIDs so output ordering (keys "0", "1", ...) is deterministic
List<String> jvmList = new ArrayList<>(jvmIDs);
Collections.sort(jvmList);
//For each unique JVM, collect memory info
//Do this by selecting the first worker
int count = 0;
for (String jvm : jvmList) {
//Gather all workers belonging to this JVM, then pick the first (sorted) as representative
List<String> workersForJvm = new ArrayList<>();
for (String s : workersToJvms.keySet()) {
if (workersToJvms.get(s).equals(jvm)) {
workersForJvm.add(s);
}
}
Collections.sort(workersForJvm);
String wid = workersForJvm.get(0);
int numDevices = workerNumDevices.get(wid);
Map<String, Object> jvmData = new HashMap<>();
List<Long> timestamps = new ArrayList<>();
List<Float> fracJvm = new ArrayList<>();
List<Float> fracOffHeap = new ArrayList<>();
//lastBytes/lastMaxBytes layout: [0]=JVM heap, [1]=off-heap, [2..]=one slot per device
long[] lastBytes = new long[2 + numDevices];
long[] lastMaxBytes = new long[2 + numDevices];
List<List<Float>> fracDeviceMem = null;
if (numDevices > 0) {
fracDeviceMem = new ArrayList<>(numDevices);
for (int i = 0; i < numDevices; i++) {
fracDeviceMem.add(new ArrayList<>());
}
}
for (Persistable p : updatesLastNMinutes) {
//TODO single pass
//Only process reports from the representative worker for this JVM
if (!p.getWorkerID().equals(wid))
continue;
if (!(p instanceof StatsReport))
continue;
StatsReport sp = (StatsReport) p;
timestamps.add(sp.getTimeStamp());
long jvmCurrentBytes = sp.getJvmCurrentBytes();
long jvmMaxBytes = sp.getJvmMaxBytes();
long ohCurrentBytes = sp.getOffHeapCurrentBytes();
long ohMaxBytes = sp.getOffHeapMaxBytes();
double jvmFrac = jvmCurrentBytes / ((double) jvmMaxBytes);
double offheapFrac = ohCurrentBytes / ((double) ohMaxBytes);
//0/0 yields NaN (e.g. when the max bytes value is reported as 0); chart as 0 utilization
if (Double.isNaN(jvmFrac))
jvmFrac = 0.0;
if (Double.isNaN(offheapFrac))
offheapFrac = 0.0;
fracJvm.add((float) jvmFrac);
fracOffHeap.add((float) offheapFrac);
//Keep only the most recent raw byte counts; overwritten by every report
lastBytes[0] = jvmCurrentBytes;
lastBytes[1] = ohCurrentBytes;
lastMaxBytes[0] = jvmMaxBytes;
lastMaxBytes[1] = ohMaxBytes;
if (numDevices > 0) {
long[] devBytes = sp.getDeviceCurrentBytes();
long[] devMaxBytes = sp.getDeviceMaxBytes();
for (int i = 0; i < numDevices; i++) {
double frac = devBytes[i] / ((double) devMaxBytes[i]);
if (Double.isNaN(frac))
frac = 0.0;
fracDeviceMem.get(i).add((float) frac);
lastBytes[2 + i] = devBytes[i];
lastMaxBytes[2 + i] = devMaxBytes[i];
}
}
}
//Assemble the series: JVM heap first, then off-heap, then one series per device
List<List<Float>> fracUtilized = new ArrayList<>();
fracUtilized.add(fracJvm);
fracUtilized.add(fracOffHeap);
String[] seriesNames = new String[2 + numDevices];
seriesNames[0] = i18n.getMessage("train.system.hwTable.jvmCurrent");
seriesNames[1] = i18n.getMessage("train.system.hwTable.offHeapCurrent");
boolean[] isDevice = new boolean[2 + numDevices];
String[] devNames = deviceNames.get(wid);
for (int i = 0; i < numDevices; i++) {
//Fall back to an empty series name if the device description is missing or too short
seriesNames[2 + i] = devNames != null && devNames.length > i ? devNames[i] : "";
fracUtilized.add(fracDeviceMem.get(i));
isDevice[2 + i] = true;
}
jvmData.put("times", timestamps);
jvmData.put("isDevice", isDevice);
jvmData.put("seriesNames", seriesNames);
jvmData.put("values", fracUtilized);
jvmData.put("currentBytes", lastBytes);
jvmData.put("maxBytes", lastMaxBytes);
ret.put(String.valueOf(count), jvmData);
count++;
}
return ret;
}
/**
 * Builds the hardware and software information tables, one entry per unique JVM.
 *
 * @param staticInfoAllWorkers Static initialization reports from all workers
 * @param i18n                 Internationalization provider for table labels
 * @return Pair of (hardware info, software info); each map is keyed by JVM index ("0", "1", ...)
 *         with a list of {label, value} rows as the value
 */
private static Pair<Map<String, Object>, Map<String, Object>> getHardwareSoftwareInfo(
                List<Persistable> staticInfoAllWorkers, I18N i18n) {
    Map<String, Object> retHw = new HashMap<>();
    Map<String, Object> retSw = new HashMap<>();

    //First: map workers to JVMs
    Set<String> jvmIDs = new HashSet<>();
    Map<String, StatsInitializationReport> staticByJvm = new HashMap<>();
    for (Persistable p : staticInfoAllWorkers) {
        //TODO validation/checks
        StatsInitializationReport init = (StatsInitializationReport) p;
        String jvmuid = init.getSwJvmUID();
        jvmIDs.add(jvmuid);
        staticByJvm.put(jvmuid, init);
    }

    //Sort JVM UIDs so output ordering (keys "0", "1", ...) is deterministic
    List<String> jvmList = new ArrayList<>(jvmIDs);
    Collections.sort(jvmList);

    //For each unique JVM, collect hardware info
    int count = 0;
    for (String jvm : jvmList) {
        StatsInitializationReport sr = staticByJvm.get(jvm);

        //---- Hardware Info ----
        List<String[]> hwInfo = new ArrayList<>();
        int numDevices = sr.getHwNumDevices();
        String[] deviceDescription = sr.getHwDeviceDescription();
        long[] devTotalMem = sr.getHwDeviceTotalMemory();
        hwInfo.add(new String[] {i18n.getMessage("train.system.hwTable.jvmMax"),
                        String.valueOf(sr.getHwJvmMaxMemory())});
        hwInfo.add(new String[] {i18n.getMessage("train.system.hwTable.offHeapMax"),
                        String.valueOf(sr.getHwOffHeapMaxMemory())});
        hwInfo.add(new String[] {i18n.getMessage("train.system.hwTable.jvmProcs"),
                        String.valueOf(sr.getHwJvmAvailableProcessors())});
        hwInfo.add(new String[] {i18n.getMessage("train.system.hwTable.computeDevices"),
                        String.valueOf(numDevices)});
        for (int i = 0; i < numDevices; i++) {
            String label = i18n.getMessage("train.system.hwTable.deviceName") + " (" + i + ")";
            String name = (deviceDescription == null || i >= deviceDescription.length ? String.valueOf(i)
                            : deviceDescription[i]);
            hwInfo.add(new String[] {label, name});
            String memLabel = i18n.getMessage("train.system.hwTable.deviceMemory") + " (" + i + ")";
            //Fix: short-circuit || instead of bitwise |. The old non-short-circuit form evaluated
            //devTotalMem.length even when devTotalMem was null, throwing a NullPointerException
            String memBytes =
                            (devTotalMem == null || i >= devTotalMem.length ? "-" : String.valueOf(devTotalMem[i]));
            hwInfo.add(new String[] {memLabel, memBytes});
        }
        retHw.put(String.valueOf(count), hwInfo);

        //---- Software Info -----
        //Map the ND4J backend factory class name to a friendly display name
        String nd4jBackend = sr.getSwNd4jBackendClass();
        if (nd4jBackend != null && nd4jBackend.contains(".")) {
            int idx = nd4jBackend.lastIndexOf('.');
            nd4jBackend = nd4jBackend.substring(idx + 1);
            String temp;
            switch (nd4jBackend) {
                case "CpuNDArrayFactory":
                    temp = "CPU";
                    break;
                case "JCublasNDArrayFactory":
                    temp = "CUDA";
                    break;
                default:
                    temp = nd4jBackend;
            }
            nd4jBackend = temp;
        }

        String datatype = sr.getSwNd4jDataTypeName();
        if (datatype == null)
            datatype = "";
        else
            datatype = datatype.toLowerCase();

        List<String[]> swInfo = new ArrayList<>();
        swInfo.add(new String[] {i18n.getMessage("train.system.swTable.os"), sr.getSwOsName()});
        swInfo.add(new String[] {i18n.getMessage("train.system.swTable.hostname"), sr.getSwHostName()});
        swInfo.add(new String[] {i18n.getMessage("train.system.swTable.osArch"), sr.getSwArch()});
        swInfo.add(new String[] {i18n.getMessage("train.system.swTable.jvmName"), sr.getSwJvmName()});
        swInfo.add(new String[] {i18n.getMessage("train.system.swTable.jvmVersion"), sr.getSwJvmVersion()});
        swInfo.add(new String[] {i18n.getMessage("train.system.swTable.nd4jBackend"), nd4jBackend});
        swInfo.add(new String[] {i18n.getMessage("train.system.swTable.nd4jDataType"), datatype});
        retSw.put(String.valueOf(count), swInfo);

        count++;
    }

    return new Pair<>(retHw, retSw);
}
/**
 * Holder for per-layer mean-magnitude time series: iteration counts plus, keyed by parameter
 * name, the update/parameter ratios and the raw parameter and update mean magnitudes.
 * Lombok's @AllArgsConstructor/@Data generate the constructor, accessors, equals/hashCode
 * and toString.
 */
@AllArgsConstructor
@Data
private static class MeanMagnitudes {
//Iteration numbers, one entry per stats report processed
private List<Integer> iterations;
//Per-parameter ratio of update mean magnitude to parameter mean magnitude
private Map<String, List<Double>> ratios;
//Per-parameter mean magnitudes of the parameter values
private Map<String, List<Double>> paramMM;
//Per-parameter mean magnitudes of the updates
private Map<String, List<Double>> updateMM;
}
}