package org.deeplearning4j.examples.dataexamples; import javafx.application.Application; import javafx.application.Platform; import javafx.scene.Scene; import javafx.scene.image.*; import javafx.scene.layout.HBox; import javafx.scene.paint.Color; import javafx.stage.Stage; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.stats.StatsListener; import org.deeplearning4j.ui.storage.InMemoryStatsStorage; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; /** * JavaFX application to show a neural network learning to draw an image. * Demonstrates how to feed an NN with externally originated data. * * This example uses JavaFX, which requires the Oracle JDK. Comment out this example if you use a different JDK. * OpenJDK and openjfx have been reported to work fine. * * TODO: sample does not shut down correctly. Process must be stopped from the IDE. * * @author Robert Altena * Many thanks to @tmanthey for constructive feedback and suggestions. */ public class ImageDrawer extends Application { private Image originalImage; //The source image displayed on the left. private WritableImage composition; // Destination image generated by the NN. private MultiLayerNetwork nn; // THE nn. private DataSet ds; //Training data generated (only once) from the Original, used to train. private INDArray xyOut; //x,y grid to calculate the output image. Needs to be calculated once, then re-used. /** * Training the NN and updating the current graphical output. */ private void onCalc(){ nn.fit(ds); drawImage(); Platform.runLater(this::onCalc); } @Override public void init(){ originalImage = new Image("/DataExamples/Mona_Lisa.png"); final int w = (int) originalImage.getWidth(); final int h = (int) originalImage.getHeight(); composition = new WritableImage(w, h); //Right image. ds = generateDataSet(originalImage); nn = createNN(); boolean fUseUI = false; // set to false if you do not want the web ui to track learning progress. if(fUseUI) { UIServer uiServer = UIServer.getInstance(); StatsStorage statsStorage = new InMemoryStatsStorage(); uiServer.attach(statsStorage); nn.setListeners(new StatsListener(statsStorage)); } // The x,y grid to calculate the NN output only needs to be calculated once. int numPoints = h * w; xyOut = Nd4j.zeros(numPoints, 2); for (int i = 0; i < w; i++) { double xp = scaleXY(i,w); for (int j = 0; j < h; j++) { int index = i + w * j; double yp = scaleXY(j,h); xyOut.put(index, 0, xp); //2 inputs. x and y. xyOut.put(index, 1, yp); } } drawImage(); } /** * Standard JavaFX start: Build the UI, display */ @Override public void start(Stage primaryStage) { final int w = (int) originalImage.getWidth(); final int h = (int) originalImage.getHeight(); final int zoom = 5; // Our images are a tad small, display them enlarged to have something to look at. ImageView iv1 = new ImageView(); //Left image iv1.setImage(originalImage); iv1.setFitHeight( zoom* h); iv1.setFitWidth(zoom*w); ImageView iv2 = new ImageView(); iv2.setImage(composition); iv2.setFitHeight( zoom* h); iv2.setFitWidth(zoom*w); HBox root = new HBox(); //build the scene. Scene scene = new Scene(root); root.getChildren().addAll(iv1, iv2); primaryStage.setTitle("Neural Network Drawing Demo."); primaryStage.setScene(scene); primaryStage.show(); Platform.setImplicitExit(true); //Allow JavaFX do to it's thing, Initialize the Neural network when it feels like it. Platform.runLater(this::onCalc); } public static void main( String[] args ) { launch(args); } /** * Build the Neural network. */ private static MultiLayerNetwork createNN() { int seed = 2345; int iterations = 25; //<-- Just the one iteration per call to fit. double learningRate = 0.1; int numInputs = 2; // x and y. int numHiddenNodes = 25; int numOutputs = 3 ; //R, G and B value. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(seed) .iterations(iterations) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .learningRate(learningRate) .weightInit(WeightInit.RELU) .updater(Updater.NESTEROVS).momentum(0.9) .list() .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes) .activation(Activation.LEAKYRELU) .build()) .layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes) .activation(Activation.LEAKYRELU) .build()) .layer(2, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes) .activation(Activation.LEAKYRELU) .build()) .layer(3, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes) .activation(Activation.LEAKYRELU) .build()) .layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.L2) .activation(Activation.IDENTITY) .nIn(numHiddenNodes).nOut(numOutputs).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); return net; } /** * Process a javafx Image to be consumed by DeepLearning4J * * @param img Javafx image to process * @return DeepLearning4J DataSet. */ private static DataSet generateDataSet(Image img) { int w = (int) img.getWidth(); int h = (int) img.getHeight(); int numPoints = h * w; PixelReader reader = img.getPixelReader(); INDArray xy = Nd4j.zeros(numPoints, 2); INDArray out = Nd4j.zeros(numPoints, 3); //Simplest implementation first. for (int i = 0; i < w; i++) { double xp = scaleXY(i,w); for (int j = 0; j < h; j++) { Color c = reader.getColor(i, j); int index = i + w * j; double yp = scaleXY(j,h); xy.put(index, 0, xp); //2 inputs. x and y. xy.put(index, 1, yp); out.put(index, 0, c.getRed()); //3 outputs. the RGB values. out.put(index, 1, c.getGreen()); out.put(index, 2, c.getBlue()); } } return new DataSet(xy, out); } /** * Make the Neural network draw the image. */ private void drawImage() { int w = (int) composition.getWidth(); int h = (int) composition.getHeight(); INDArray out = nn.output(xyOut); PixelWriter writer = composition.getPixelWriter(); for (int i = 0; i < w; i++) { for (int j = 0; j < h; j++) { int index = i + w * j; double red = capNNOutput(out.getDouble(index, 0)); double green = capNNOutput(out.getDouble(index, 1)); double blue = capNNOutput(out.getDouble(index, 2)); Color c = new Color(red, green, blue, 1.0); writer.setColor(i, j, c); } } } /** * Make sure the color values are >=0 and <=1 */ private static double capNNOutput(double x) { double tmp = (x<0.0) ? 0.0 : x; return (tmp > 1.0) ? 1.0 : tmp; } /** * scale x,y points */ private static double scaleXY(int i, int maxI){ return (double) i / (double) (maxI - 1) -0.5; } }