package hex; import javafx.application.Platform; import javafx.beans.value.ChangeListener; import javafx.beans.value.ObservableValue; import javafx.collections.FXCollections; import javafx.collections.ObservableList; import javafx.embed.swing.JFXPanel; import javafx.event.ActionEvent; import javafx.event.EventHandler; import javafx.scene.Scene; import javafx.scene.chart.LineChart; import javafx.scene.chart.NumberAxis; import javafx.scene.control.Button; import javafx.scene.control.CheckBox; import javafx.scene.control.ScrollPane; import javafx.scene.control.ToolBar; import javafx.scene.layout.BorderPane; import javafx.scene.layout.HBox; import javafx.scene.layout.VBox; import javafx.stage.Stage; import javax.swing.*; import java.util.ArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; public class Histograms extends LineChart { private static final int SLICES = 64; private static final ArrayList<Histograms> _instances = new ArrayList<Histograms>(); private static final ScheduledExecutorService _executor = Executors.newSingleThreadScheduledExecutor(); private static CheckBox _auto; private final float[] _data; private final ObservableList<Data<Float, Float>> _list = FXCollections.observableArrayList(); public static void init() { final CountDownLatch latch = new CountDownLatch(1); SwingUtilities.invokeLater(new Runnable() { public void run() { initFromSwingThread(); latch.countDown(); } }); try { latch.await(); } catch( InterruptedException e ) { throw new RuntimeException(e); } } static void initFromSwingThread() { new JFXPanel(); // initializes JavaFX environment } public static void build(final Layer[] ls) { Platform.runLater(new Runnable() { @Override public void run() { VBox v = new VBox(); for( int i = ls.length - 1; i > 0; i-- ) { HBox h = new HBox(); h.getChildren().add(new Histograms("Layer " + i + " weight", ls[i]._w)); h.getChildren().add(new Histograms("Layer " + i + " bias", ls[i]._b)); h.getChildren().add(new Histograms("Layer " + i + " activity", ls[i]._a)); h.getChildren().add(new Histograms("Layer " + i + " error", ls[i]._e)); h.getChildren().add(new Histograms("Layer " + i + " weight momentum", ls[i]._wm)); h.getChildren().add(new Histograms("Layer " + i + " bias momentum", ls[i]._bm)); v.getChildren().add(h); } Stage stage = new Stage(); BorderPane root = new BorderPane(); ToolBar toolbar = new ToolBar(); Button refresh = new Button("Refresh"); refresh.setOnAction(new EventHandler<ActionEvent>() { @Override public void handle(ActionEvent e) { refresh(); } }); toolbar.getItems().add(refresh); _auto = new CheckBox("Auto"); _auto.selectedProperty().addListener(new ChangeListener<Boolean>() { public void changed(ObservableValue<? extends Boolean> ov, Boolean old_val, Boolean new_val) { refresh(); } }); toolbar.getItems().add(_auto); root.setTop(toolbar); ScrollPane scroll = new ScrollPane(); scroll.setContent(v); root.setCenter(scroll); Scene scene = new Scene(root); stage.setScene(scene); stage.setWidth(2450); stage.setHeight(1500); stage.show(); scene.getWindow().onCloseRequestProperty().addListener(new ChangeListener() { @Override public void changed(ObservableValue arg0, Object arg1, Object arg2) { _auto.selectedProperty().set(false); } }); refresh(); } }); } public Histograms(String title, float[] data) { super(new NumberAxis(), new NumberAxis()); _data = data; ObservableList<Series<Float, Float>> series = FXCollections.observableArrayList(); for( int i = 0; i < SLICES; i++ ) _list.add(new Data<Float, Float>(0f, 0f)); series.add(new LineChart.Series<Float, Float>(title, _list)); setData(series); setPrefWidth(600); setPrefHeight(250); _instances.add(this); } public Histograms(String title, double[] data) { super(new NumberAxis(), new NumberAxis()); _data = new float[data.length]; for (int i=0; i<data.length; ++i) _data[i] = (float)data[i]; ObservableList<Series<Float, Float>> series = FXCollections.observableArrayList(); for( int i = 0; i < SLICES; i++ ) _list.add(new Data<Float, Float>(0f, 0f)); series.add(new LineChart.Series<Float, Float>(title, _list)); setData(series); setPrefWidth(600); setPrefHeight(250); _instances.add(this); } static void refresh() { for( Histograms h : _instances ) { if( h._data != null ) { float[] data = h._data.clone(); float min = Float.MAX_VALUE, max = Float.MIN_VALUE; for( int i = 0; i < data.length; i++ ) { max = Math.max(max, data[i]); min = Math.min(min, data[i]); } int[] counts = new int[SLICES]; float inc = (max - min) / (SLICES - 1); for( int i = 0; i < data.length; i++ ) counts[(int) Math.floor((data[i] - min) / inc)]++; for( int i = 0; i < SLICES; i++ ) { Data<Float, Float> point = h._list.get(i); point.setXValue(min + inc * i); point.setYValue((float) counts[i] / data.length); } } } if( _auto.selectedProperty().get() ) { _executor.schedule(new Runnable() { @Override public void run() { Platform.runLater(new Runnable() { @Override public void run() { refresh(); } }); } }, 1000, TimeUnit.MILLISECONDS); } } }