package org.deeplearning4j.examples.userInterface.util;
import javafx.animation.AnimationTimer;
import javafx.application.Application;
import javafx.scene.Group;
import javafx.scene.Scene;
import javafx.scene.paint.Color;
import javafx.scene.paint.PhongMaterial;
import javafx.scene.shape.Sphere;
import javafx.stage.Modality;
import javafx.stage.Stage;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
/**
 * JavaFX window that draws a grid of spheres, one row per layer of a {@link MultiLayerNetwork},
 * where each sphere's colour tracks one randomly sampled element of that layer's input array
 * ({@code network.getLayer(i).input()}).
 *
 * <p>Call {@link #initialize(MultiLayerNetwork, int)} before the application is launched, because
 * {@link #start(Stage)} reads the values it sets. Once started, {@link #staticInstance} refers to the
 * live viewer, and a non-JavaFX thread can call {@link #requestIterationUpdate(int)} to repaint.
 *
 * <p>Created by Don Smith on 3/27/2017.
 */
public class ActivationsViewer extends Application {
    private static final int WIDTH = 1400;
    private static final int HEIGHT = 900;
    /** Radius of each sphere. */
    private static final double SPHERE_RADIUS = 10;
    /** Per-layer cap on how many duplicate sample coordinates are redrawn before duplicates are kept. */
    private static final int MAX_DUPLICATE_RETRIES = 20;
    /** Fixed seed so the same coordinates are sampled on every run. */
    private static final long SAMPLE_SEED = 1234L;

    /** The running viewer; assigned in {@link #start(Stage)} (output). */
    public static ActivationsViewer staticInstance;
    /** Network whose layer inputs are displayed; assigned by {@link #initialize} (input). */
    public static MultiLayerNetwork network;
    private static int sampleSizePerLayer; // input
    private static int numberOfLayers; // input

    /** Number format configured for exactly two fraction digits. */
    private static final NumberFormat numberFormat = NumberFormat.getInstance();

    static {
        numberFormat.setMaximumFractionDigits(2);
        numberFormat.setMinimumFractionDigits(2);
    }

    private Stage stage;
    private final Group root = new Group();
    private List<int[]>[] sampleCoordinatesByLayer;
    private ActivationShape[][] shapesByLayerAndSample;
    private final List<ActivationShape> allActivationShapes = new ArrayList<>();
    /**
     * Latch published by {@link #requestIterationUpdate(int)} and released by the animation timer
     * after it has repainted; {@code null} when no repaint has been requested.
     */
    private volatile CountDownLatch needUpdate;

    /** A sphere whose colour follows one sampled element of a layer's input array. */
    private class ActivationShape extends Sphere {
        private final int[] coordinatesInLayerInput;
        private final int layerIndex;

        /**
         * Creates the sphere, positions it in its grid cell, colours it and registers it for updates.
         *
         * @param layerIndex  layer whose input is sampled; also selects the row
         * @param sampleIndex position of the sphere within its row
         * @param coordinates index of the sampled element in {@code network.getLayer(layerIndex).input()}
         */
        public ActivationShape(int layerIndex, int sampleIndex, int[] coordinates) {
            super(SPHERE_RADIUS);
            this.layerIndex = layerIndex;
            this.coordinatesInLayerInput = coordinates;
            // Divide as doubles: integer division truncates the cell size, so the grid would not
            // span the whole window.
            double deltaWidth = (double) WIDTH / sampleSizePerLayer;
            double deltaHeight = (double) HEIGHT / numberOfLayers;
            // Centre the sphere inside its cell.
            setTranslateX(deltaWidth / 2 + sampleIndex * deltaWidth);
            setTranslateY(deltaHeight / 2 + layerIndex * deltaHeight);
            updateFromNeuralInput();
            allActivationShapes.add(this);
        }

        /** Recolours this sphere from the current value of its sampled layer-input element. */
        public final void updateFromNeuralInput() {
            INDArray input = network.getLayer(layerIndex).input();
            double value = input.getDouble(coordinatesInLayerInput);
            // Logistic squash of 10*value onto hues [0, 360]; an input of 0 maps to hue 180 (cyan).
            // NOTE(review): hues 0 and 360 are both red, so strongly negative and strongly positive
            // inputs are rendered in the same colour.
            double hue = 360.0 / (1 + Math.exp(-10 * value));
            setMaterial(new PhongMaterial(Color.hsb(hue, 1, 1)));
        }
    }

    /** No-arg constructor used by the JavaFX launcher. */
    public ActivationsViewer() {
    }

    /**
     * Sets the network and sample size used when the viewer starts.
     *
     * @param network            network whose layer inputs will be displayed
     * @param sampleSizePerLayer number of input elements sampled (and drawn) for each layer
     * @throws IllegalArgumentException if {@code sampleSizePerLayer} is negative
     */
    public static void initialize(MultiLayerNetwork network, int sampleSizePerLayer) {
        if (sampleSizePerLayer < 0) {
            throw new IllegalArgumentException(
                    "sampleSizePerLayer must be >= 0 but was " + sampleSizePerLayer);
        }
        ActivationsViewer.network = network;
        ActivationsViewer.sampleSizePerLayer = sampleSizePerLayer;
        ActivationsViewer.numberOfLayers = network.getnLayers();
    }

    /** Chooses the sampled coordinates for every layer, then builds one sphere per sample. */
    @SuppressWarnings("unchecked") // generic array creation; every slot is filled with a List<int[]>
    private void makeLayerViews() {
        sampleCoordinatesByLayer = new List[numberOfLayers];
        chooseSampleCoordinates();
        makeInputShapes();
    }

    /** Creates an {@link ActivationShape} for each sampled coordinate and adds it to the scene. */
    private void makeInputShapes() {
        shapesByLayerAndSample = new ActivationShape[numberOfLayers][sampleSizePerLayer];
        for (int layerIndex = 0; layerIndex < numberOfLayers; layerIndex++) {
            List<int[]> coordinates = sampleCoordinatesByLayer[layerIndex];
            for (int i = 0; i < sampleSizePerLayer; i++) {
                ActivationShape activationShape = new ActivationShape(layerIndex, i, coordinates.get(i));
                shapesByLayerAndSample[layerIndex][i] = activationShape;
                root.getChildren().add(activationShape);
            }
        }
    }

    /**
     * Draws {@code sampleSizePerLayer} random coordinates into each layer's current input array and
     * prints them.
     *
     * <p>Index 0 is always left at 0, because some mini-batches are smaller and a larger index
     * could be out of range. Duplicate coordinates are redrawn, but after
     * {@link #MAX_DUPLICATE_RETRIES} redraws in a layer duplicates are accepted. That guarantees
     * termination even when a layer has fewer distinct coordinates than requested samples.
     */
    private void chooseSampleCoordinates() {
        Random random = new Random(SAMPLE_SEED);
        for (int layerIndex = 0; layerIndex < sampleCoordinatesByLayer.length; layerIndex++) {
            List<int[]> list = new ArrayList<>(sampleSizePerLayer);
            sampleCoordinatesByLayer[layerIndex] = list;
            Layer layer = network.getLayer(layerIndex);
            INDArray input = layer.input();
            int[] shape = input.shape();
            int repeats = 0;
            while (list.size() < sampleSizePerLayer) {
                int[] sampleCoordinates = new int[shape.length];
                for (int i = 1; i < shape.length; i++) {
                    sampleCoordinates[i] = random.nextInt(shape[i]);
                }
                // Check the cheap retry budget first so the O(n) scan is skipped once it is spent.
                if (repeats < MAX_DUPLICATE_RETRIES && isDuplicate(sampleCoordinates, list)) {
                    repeats++;
                } else {
                    list.add(sampleCoordinates);
                }
            }
        }
        showSampleCoordinates();
    }

    /** Prints each layer index followed by its sampled coordinates to standard output. */
    private void showSampleCoordinates() {
        System.out.println("Sample coordindates");
        for (int layer = 0; layer < sampleCoordinatesByLayer.length; layer++) {
            System.out.println(layer);
            for (int[] sample : sampleCoordinatesByLayer[layer]) {
                System.out.println(" " + ActivationsListener.toString(sample));
            }
        }
        System.out.println();
    }

    /** Returns whether {@code list} already contains an element-wise equal copy of {@code array}. */
    private static boolean isDuplicate(int[] array, List<int[]> list) {
        for (int[] ar : list) {
            if (equal(ar, array)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Compares two equal-length coordinate arrays element by element.
     *
     * @throws IllegalStateException if the arrays differ in length
     */
    private static boolean equal(int[] a1, int[] a2) {
        if (a1.length != a2.length) {
            throw new IllegalStateException(
                    "Coordinate arrays differ in length: " + a1.length + " vs " + a2.length);
        }
        for (int i = 0; i < a1.length; i++) {
            if (a1[i] != a2[i]) {
                return false;
            }
        }
        return true;
    }

    /*
    Example layer shapes reported by ActivationsListener for a convolutional network on 28x28
    single-channel input with mini-batch size 64:
    0: 520 params, input shape = [64,1,28,28], activation shape = [64,20,24,24] 64 is batch size, 20 is nOut, 20*24*24=11520
    1: 0 params, input shape = [64,20,24,24], activation shape = [64,20,12,12] max pooling, 20*12*12=2880
    2: 25050 params, input shape = [64,20,12,12], activation shape = [64,50,8,8] 64 is batch size, 50 is nOut, 50*8*8=3200
    3: 0 params, input shape = [64,50,8,8], activation shape = [64,50,4,4] max pooling, 50*4*4=800
    4: 400500 params, input shape = [64,800], activation shape = [64,500]
    5: 5010 params, input shape = [64,500], activation shape = [64,10]
    */

    /**
     * Blocks the calling thread until the animation timer has repainted every sphere from the
     * network's current layer inputs.
     *
     * <p>Must not be called on the JavaFX Application Thread: the timer that releases the wait runs
     * on that thread, so the call would never return. If the calling thread is interrupted, no
     * repaint is requested and the interrupt status is preserved for the caller.
     *
     * @param iteration training iteration number (not used)
     */
    public void requestIterationUpdate(int iteration) {
        if (Thread.currentThread().isInterrupted()) {
            // Publishing a latch we cannot wait on would let a repaint overlap whatever runs next.
            return;
        }
        CountDownLatch latch = new CountDownLatch(1);
        needUpdate = latch; // serviced by the AnimationTimer started in animate()
        try {
            latch.await();
        } catch (InterruptedException exc) {
            System.err.println("Warning: interrupted in requestIterationUpdate of ActivationsViewer");
            // Restore (rather than clear) the interrupt so callers further up can observe it.
            Thread.currentThread().interrupt();
        } finally {
            needUpdate = null;
        }
    }

    /** Recolours every sphere from the current layer inputs and requests a layout pass. */
    private void updateInternal() {
        for (ActivationShape activationShape : allActivationShapes) {
            activationShape.updateFromNeuralInput();
        }
        root.requestLayout();
    }

    /** Starts a timer that, on every frame, services a pending {@link #requestIterationUpdate}. */
    private void animate() {
        final AnimationTimer timer = new AnimationTimer() {
            @Override
            public void handle(long nowInNanoSeconds) {
                // Read the volatile field once: the requesting thread may null it between two reads.
                CountDownLatch latch = needUpdate;
                // Only this handler counts down, so getCount() > 0 means the request is still
                // pending; an already-released latch is not repainted again while the requester resumes.
                if (latch != null && latch.getCount() > 0) {
                    updateInternal();
                    latch.countDown();
                }
            }
        };
        timer.start();
    }

    /**
     * Builds the scene and shows it in a new non-modal stage, then starts the update timer.
     * The supplied {@code primaryStage} is not used. Closing the window exits the JVM.
     */
    @Override
    public void start(Stage primaryStage) throws Exception {
        stage = new Stage();
        stage.initModality(Modality.NONE);
        stage.setOnCloseRequest(r -> System.exit(0));
        stage.setTitle("Network viewer");
        stage.setMinWidth(850);
        stage.setMinHeight(600);
        Scene scene = new Scene(root, WIDTH, HEIGHT, true);
        scene.setFill(Color.DIMGREY);
        stage.setScene(scene);
        staticInstance = this;
        makeLayerViews();
        stage.show();
        animate();
    }
}