/* * Encog(tm) Workbench v3.4 * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-workbench * * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.workbench.tabs.visualize.structure; import java.awt.BorderLayout; import java.awt.Color; import java.awt.Dimension; import java.awt.FlowLayout; import java.awt.Paint; import java.awt.Point; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.geom.Point2D; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import javax.swing.BorderFactory; import javax.swing.JButton; import javax.swing.JPanel; import javax.swing.border.Border; import org.apache.commons.collections15.Transformer; import org.encog.engine.network.activation.ActivationBipolarSteepenedSigmoid; import org.encog.engine.network.activation.ActivationClippedLinear; import org.encog.engine.network.activation.ActivationGaussian; import org.encog.engine.network.activation.ActivationSIN; import org.encog.neural.neat.training.NEATGenome; import org.encog.neural.neat.training.NEATLinkGene; import org.encog.neural.neat.training.NEATNeuronGene; import org.encog.workbench.WorkBenchError; import org.encog.workbench.tabs.EncogCommonTab; import edu.uci.ics.jung.algorithms.layout.StaticLayout; import edu.uci.ics.jung.graph.Graph; import edu.uci.ics.jung.graph.SparseMultigraph; import edu.uci.ics.jung.graph.util.EdgeType; import edu.uci.ics.jung.visualization.GraphZoomScrollPane; import edu.uci.ics.jung.visualization.Layer; import edu.uci.ics.jung.visualization.VisualizationViewer; import edu.uci.ics.jung.visualization.control.AbstractModalGraphMouse; import edu.uci.ics.jung.visualization.control.CrossoverScalingControl; import edu.uci.ics.jung.visualization.control.DefaultModalGraphMouse; import edu.uci.ics.jung.visualization.control.ScalingControl; import edu.uci.ics.jung.visualization.decorators.ToStringLabeller; import edu.uci.ics.jung.visualization.renderers.Renderer; public class GenomeStructureTab extends EncogCommonTab { private VisualizationViewer<DrawnNeuron, DrawnConnection> vv; private NEATGenome genome; public GenomeStructureTab(NEATGenome genome) { super(null); this.genome = genome; // Graph<V, E> where V is the type of the vertices // and E is the type of the edges Graph<DrawnNeuron, DrawnConnection> g = null; g = buildGraph(genome); if (g == null) { throw new WorkBenchError("Can't visualize genome"); } Transformer<DrawnNeuron, Point2D> staticTranformer = new Transformer<DrawnNeuron, Point2D>() { public Point2D transform(DrawnNeuron n) { int x = (int) (n.getX() * 600); int y = (int) (n.getY() * 300); Point2D result = new Point(x + 32, y); return result; } }; Transformer<DrawnNeuron, Paint> vertexPaint = new Transformer<DrawnNeuron, Paint>() { public Paint transform(DrawnNeuron neuron) { switch (neuron.getType()) { case Bias: return Color.yellow; case Input: return Color.white; case Output: return Color.green; case Context: return Color.cyan; case Linear: return Color.blue; case Sigmoid: return Color.magenta; case Gaussian: return Color.cyan; case SIN: return Color.gray; default: return Color.red; } } }; Transformer<DrawnConnection, Paint> edgePaint = new Transformer<DrawnConnection, Paint>() { public Paint transform(DrawnConnection connection) { if (connection.isContext()) { return Color.lightGray; } else { return Color.black; } } }; // The Layout<V, E> is parameterized by the vertex and edge types StaticLayout<DrawnNeuron, DrawnConnection> layout = new StaticLayout<DrawnNeuron, DrawnConnection>( g, staticTranformer); layout.setSize(new Dimension(5000, 5000)); // sets the initial size of // the space // The BasicVisualizationServer<V,E> is parameterized by the edge types // BasicVisualizationServer<DrawnNeuron, DrawnConnection> vv = new // BasicVisualizationServer<DrawnNeuron, DrawnConnection>( // layout); // Dimension d = new Dimension(600,600); vv = new VisualizationViewer<DrawnNeuron, DrawnConnection>(layout); // vv.setPreferredSize(d); //Sets the viewing area size vv.getRenderer().getVertexLabelRenderer() .setPosition(Renderer.VertexLabel.Position.CNTR); vv.getRenderContext().setVertexLabelTransformer(new ToStringLabeller()); vv.getRenderContext().setVertexFillPaintTransformer(vertexPaint); vv.getRenderContext().setEdgeDrawPaintTransformer(edgePaint); vv.getRenderContext().setArrowDrawPaintTransformer(edgePaint); vv.getRenderContext().setArrowFillPaintTransformer(edgePaint); vv.setVertexToolTipTransformer(new ToStringLabeller()); vv.setVertexToolTipTransformer(new Transformer<DrawnNeuron, String>() { public String transform(DrawnNeuron edge) { return edge.getToolTip(); } }); vv.setEdgeToolTipTransformer(new Transformer<DrawnConnection, String>() { public String transform(DrawnConnection edge) { return edge.getToolTip(); } }); final GraphZoomScrollPane panel = new GraphZoomScrollPane(vv); this.setLayout(new BorderLayout()); add(panel, BorderLayout.CENTER); final AbstractModalGraphMouse graphMouse = new DefaultModalGraphMouse(); vv.setGraphMouse(graphMouse); vv.addKeyListener(graphMouse.getModeKeyListener()); final ScalingControl scaler = new CrossoverScalingControl(); JButton plus = new JButton("+"); plus.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { scaler.scale(vv, 1.1f, vv.getCenter()); } }); JButton minus = new JButton("-"); minus.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { scaler.scale(vv, 1 / 1.1f, vv.getCenter()); } }); JButton reset = new JButton("reset"); reset.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { vv.getRenderContext().getMultiLayerTransformer() .getTransformer(Layer.LAYOUT).setToIdentity(); vv.getRenderContext().getMultiLayerTransformer() .getTransformer(Layer.VIEW).setToIdentity(); } }); JPanel controls = new JPanel(); controls.setLayout(new FlowLayout(FlowLayout.LEFT)); controls.add(plus); controls.add(minus); controls.add(reset); Border border = BorderFactory.createEtchedBorder(); controls.setBorder(border); add(controls, BorderLayout.NORTH); add(new LegendPanel(true),BorderLayout.SOUTH); } private int calculateDepths(Map<Integer, DrawnNeuron> neuronMap) { List<DrawnNeuron> outputList = new ArrayList<DrawnNeuron>(); int maxDepth = 0; int maxOutputDepth = 0; for (int pass = 0; pass < 1; pass++) { boolean done = false; while (!done ) { done = true; for (NEATLinkGene neatLinkGene : genome.getLinksChromosome()) { if (neatLinkGene.getFromNeuronID() != neatLinkGene .getToNeuronID()) { DrawnNeuron fromNeuron = neuronMap .get((int) neatLinkGene.getFromNeuronID()); DrawnNeuron toNeuron = neuronMap.get((int) neatLinkGene .getToNeuronID()); // do not calculate a depth if the from is undefined if (fromNeuron.getDepth() != -1) { // if the to is depth 0 (bias or input) if (toNeuron.getDepth() != 0) { if( toNeuron.getDepth()==-1) { done = false; } int depth = fromNeuron.getDepth() + 1; toNeuron.setDepth(Math.max(toNeuron.getDepth(), depth)); maxDepth = Math.max(depth, maxDepth); if (toNeuron.getType() == DrawnNeuronType.Output) { maxOutputDepth = Math.max(maxOutputDepth, depth); outputList.add(toNeuron); } } } } } } } maxDepth++; // all output at the same level for (DrawnNeuron neuron : outputList) { neuron.setDepth(maxDepth); } // handle any unassigned neurons, these are hidden neurons with no input. // put them at depth zero, as they are basically bias-like neurons. for (NEATLinkGene neatLinkGene : genome.getLinksChromosome()) { DrawnNeuron fromNeuron = neuronMap .get((int) neatLinkGene.getFromNeuronID()); DrawnNeuron toNeuron = neuronMap.get((int) neatLinkGene .getToNeuronID()); if( fromNeuron.getDepth()==-1 ) { fromNeuron.setDepth(0); } if( toNeuron.getDepth()==-1 ) { toNeuron.setDepth(0); } } return maxDepth; } private void calculateXY(List<DrawnNeuron> neurons, int maxDepth) { int[] layerTotal = new int[maxDepth + 1]; int[] layerCurrent = new int[maxDepth + 1]; for (DrawnNeuron neuron : neurons) { if( neuron.getDepth()<0 ) { neuron.setDepth(0); } layerTotal[neuron.getDepth()]++; } for (DrawnNeuron neuron : neurons) { layerCurrent[neuron.getDepth()]++; neuron.setX(neuron.getDepth() * (1.0 / layerTotal.length)); neuron.setY(layerCurrent[neuron.getDepth()] * (1.0 / layerTotal[neuron.getDepth()])); } } private Graph<DrawnNeuron, DrawnConnection> buildGraph(NEATGenome genome) { int inputCount = 1; int outputCount = 1; int hiddenCount = 1; int biasCount = 1; List<DrawnNeuron> neurons = new ArrayList<DrawnNeuron>(); Graph<DrawnNeuron, DrawnConnection> result = new SparseMultigraph<DrawnNeuron, DrawnConnection>(); List<DrawnNeuron> connections = new ArrayList<DrawnNeuron>(); Map<Integer, DrawnNeuron> neuronMap = new HashMap<Integer, DrawnNeuron>(); // place all the neurons for (NEATNeuronGene neuronGene : genome.getNeuronsChromosome()) { String name = ""; int depth = -1; DrawnNeuronType t = DrawnNeuronType.Hidden; switch (neuronGene.getNeuronType()) { case Bias: depth = 0; t = DrawnNeuronType.Bias; name = "B" + (biasCount++); break; case Input: depth = 0; t = DrawnNeuronType.Input; name = "I" + (inputCount++); break; case Output: t = DrawnNeuronType.Output; name = "O" + (outputCount++); break; case Hidden: if( neuronGene.getActivationFunction() instanceof ActivationClippedLinear) { t = DrawnNeuronType.Linear; } else if( neuronGene.getActivationFunction() instanceof ActivationBipolarSteepenedSigmoid) { t = DrawnNeuronType.Sigmoid; } else if( neuronGene.getActivationFunction() instanceof ActivationGaussian) { t = DrawnNeuronType.Gaussian; } else if( neuronGene.getActivationFunction() instanceof ActivationSIN) { t = DrawnNeuronType.SIN; } name = "H" + (hiddenCount++); break; } DrawnNeuron neuron = new DrawnNeuron(t, name); neurons.add(neuron); neuron.setDepth(depth); neuronMap.put((int) neuronGene.getId(), neuron); } // place all the connections for (NEATLinkGene neatLinkGene : genome.getLinksChromosome()) { if (neatLinkGene.isEnabled()) { DrawnNeuron fromNeuron = neuronMap.get((int) neatLinkGene .getFromNeuronID()); DrawnNeuron toNeuron = neuronMap.get((int) neatLinkGene .getToNeuronID()); DrawnConnection connection = new DrawnConnection(fromNeuron, toNeuron, neatLinkGene.getWeight()); fromNeuron.getOutbound().add(connection); toNeuron.getInbound().add(connection); } } int maxDepth = calculateDepths(neuronMap); calculateXY(neurons, maxDepth); for (DrawnNeuron neuron : neurons) { result.addVertex(neuron); for (DrawnConnection connection : neuron.getOutbound()) { result.addEdge(connection, connection.getFrom(), connection.getTo(), EdgeType.DIRECTED); } } return result; } @Override public String getName() { return "NEAT Genome"; } }