/* * Copyright 2004-2010 Information & Software Engineering Group (188/1) * Institute of Software Technology and Interactive Systems * Vienna University of Technology, Austria * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.ifs.tuwien.ac.at/dm/somtoolbox/license.html * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package at.tuwien.ifs.somtoolbox.apps.viewer; import java.awt.BorderLayout; import java.awt.Container; import java.awt.FlowLayout; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.image.BufferedImage; import java.util.Hashtable; import java.util.logging.Logger; import javax.swing.ButtonGroup; import javax.swing.JButton; import javax.swing.JCheckBox; import javax.swing.JFrame; import javax.swing.JLabel; import javax.swing.JPanel; import javax.swing.JRadioButton; import javax.swing.JSpinner; import javax.swing.SpinnerNumberModel; import javax.swing.event.ChangeEvent; import javax.swing.event.ChangeListener; import cern.colt.matrix.DoubleMatrix2D; import cern.colt.matrix.doublealgo.Statistic; import cern.colt.matrix.impl.DenseDoubleMatrix2D; import cern.colt.matrix.linalg.Algebra; import cern.jet.math.Functions; import edu.umd.cs.piccolo.PNode; import edu.umd.cs.piccolo.nodes.PImage; import edu.umd.cs.piccolo.nodes.PPath; import edu.umd.cs.piccolo.nodes.PText; import at.tuwien.ifs.somtoolbox.SOMToolboxException; import at.tuwien.ifs.somtoolbox.apps.viewer.fileutils.ExportUtils; import at.tuwien.ifs.somtoolbox.data.AbstractSOMLibSparseInputData; import at.tuwien.ifs.somtoolbox.data.InputData; import at.tuwien.ifs.somtoolbox.data.InputDatum; import at.tuwien.ifs.somtoolbox.data.SOMLibTemplateVector; import at.tuwien.ifs.somtoolbox.data.SOMVisualisationData; import at.tuwien.ifs.somtoolbox.data.SharedSOMVisualisationData; import at.tuwien.ifs.somtoolbox.layers.GrowingLayer; import at.tuwien.ifs.somtoolbox.layers.Unit; import at.tuwien.ifs.somtoolbox.models.GrowingSOM; import at.tuwien.ifs.somtoolbox.properties.PropertiesException; import at.tuwien.ifs.somtoolbox.properties.SOMProperties; import at.tuwien.ifs.somtoolbox.visualization.ComponentPlanesVisualizer; import at.tuwien.ifs.somtoolbox.visualization.Palettes; /** * This class implements ordered display and clustering of SOM Component Planes. The components planes are transformed * to vectors, and are subsequently either displayed in their order, or clustered on a new SOM. * * @author Arnaud Moreau * @author Peter Vorlaufer * @author Rudolf Mayer * @version $Id: ComponentPlaneClusteringFrame.java 3984 2010-12-21 16:30:25Z frank $ */ public class ComponentPlaneClusteringFrame extends JFrame implements ActionListener, ChangeListener { private static final String CLUSTER = "Clustering"; private static final String DISPLAY = "Display ordered"; private static final long serialVersionUID = 1L; private AbstractSOMLibSparseInputData input; private SOMProperties props; private String[] labels; private GrowingSOM orginalSom; private GenericPNodeScrollPane pane; private JSpinner spinnerXSize; private JSpinner spinnerYSize; private SpinnerNumberModel spinnerNumberModelXSize; private SpinnerNumberModel spinnerNumberModelYSize; /** A cache for already trained SOMs. */ private Hashtable<String, ComponentPlaneClustering> clusteredMapCache = new Hashtable<String, ComponentPlaneClustering>(); private PNode unclusteredComponentPNodeWithNames; private PNode unclusteredComponentPNodeWithOutNames; final int uHeight = MapPNode.DEFAULT_UNIT_HEIGHT; final int uWidth = MapPNode.DEFAULT_UNIT_WIDTH; private int dim; private SOMLibTemplateVector tv; private SOMViewer somViewer; private CommonSOMViewerStateData state; private ButtonGroup buttons; private int padding = 12; private JCheckBox checkboxShowComponentNames; public ComponentPlaneClusteringFrame(SOMViewer somViewer, GrowingSOM orginalSom, SOMLibTemplateVector tv) throws SOMToolboxException { super("Component Plane Clustering"); this.orginalSom = orginalSom; this.somViewer = somViewer; this.tv = tv; GrowingLayer layer = orginalSom.getLayer(); dim = tv.dim(); // create covariance matrix from CPs DoubleMatrix2D cov = this.getCov(layer); labels = new String[dim]; InputDatum[] newData = new InputDatum[dim]; // extract feature names and save new training vectors for (int i = 0; i < dim; i++) { labels[i] = tv.getLabel(i); newData[i] = new InputDatum(labels[i], cov.viewColumn(i), cov.viewColumn(i).cardinality()); } // compute new x=y SOM Size int newSOMAxisSize = (int) Math.ceil(Math.sqrt(dim)) + 1; // create new Input Data for the SOM input = AbstractSOMLibSparseInputData.create(newData, null); tv = new SOMLibTemplateVector(input.numVectors(), input.dim()); Container contentPane = getContentPane(); contentPane.setLayout(new BorderLayout()); JPanel topPanel = new JPanel(new FlowLayout(FlowLayout.LEFT)); JRadioButton radioButtonDisplay = new JRadioButton(DISPLAY); radioButtonDisplay.setActionCommand(DISPLAY); radioButtonDisplay.addActionListener(this); topPanel.add(radioButtonDisplay); JRadioButton radioButtonCluster = new JRadioButton(CLUSTER); radioButtonCluster.setActionCommand(CLUSTER); radioButtonCluster.addActionListener(this); topPanel.add(radioButtonCluster); topPanel.add(new JLabel("xSize")); spinnerNumberModelXSize = new SpinnerNumberModel(newSOMAxisSize, 1, 50, 1); spinnerXSize = new JSpinner(spinnerNumberModelXSize); spinnerXSize.setEnabled(false); spinnerXSize.addChangeListener(this); topPanel.add(spinnerXSize); topPanel.add(new JLabel("ySize")); spinnerNumberModelYSize = new SpinnerNumberModel(newSOMAxisSize, 1, 50, 1); spinnerYSize = new JSpinner(spinnerNumberModelYSize); spinnerYSize.setEnabled(false); spinnerYSize.addChangeListener(this); topPanel.add(spinnerYSize); buttons = new ButtonGroup(); buttons.add(radioButtonDisplay); buttons.add(radioButtonCluster); radioButtonDisplay.setSelected(true); JButton buttonSave = new JButton("Save"); buttonSave.setToolTipText("Save the component plane pane to an image file"); buttonSave.addActionListener(new ActionListener() { @Override public void actionPerformed(ActionEvent e) { ExportUtils.saveMapPaneAsImage(getParent(), ComponentPlaneClusteringFrame.this.state.getFileChooser(), pane, "Save MapPane as PNG"); } }); topPanel.add(buttonSave); checkboxShowComponentNames = new JCheckBox("Show component names", true); checkboxShowComponentNames.addActionListener(this); topPanel.add(checkboxShowComponentNames); contentPane.add(topPanel, BorderLayout.NORTH); state = new CommonSOMViewerStateData(somViewer.getSOMViewerState()); state.inputDataObjects = new SharedSOMVisualisationData(); state.inputDataObjects.setData(SOMVisualisationData.TEMPLATE_VECTOR, tv); unclusteredComponentPNodeWithOutNames = createUnclusteredPane(somViewer, tv, layer, false); unclusteredComponentPNodeWithNames = createUnclusteredPane(somViewer, tv, layer, true); pane = new GenericPNodeScrollPane(state, unclusteredComponentPNodeWithNames); // Set initial pane size... pane.setPreferredSize(unclusteredComponentPNodeWithNames.getFullBounds().getBounds().getSize()); contentPane.add(pane, BorderLayout.CENTER); } private ComponentPlaneClustering createClusteredPane(SOMViewer parent, SOMLibTemplateVector tv, GrowingLayer layer) throws SOMToolboxException { int xSize = spinnerNumberModelXSize.getNumber().intValue(); int ySize = spinnerNumberModelYSize.getNumber().intValue(); // check if SOM size can hold all CPs if (xSize * ySize < layer.getDim()) { throw new SOMToolboxException("Size of map (" + xSize + "x" + ySize + ") can't be smaller than number of dimensions (" + layer.getDim() + ") !"); } // specify Properties of new SOM try { int iterations = Math.max(1000, input.numVectors() * 100); props = new SOMProperties(xSize, ySize, 7, 0, iterations, 0.7, 0, 1, "", true); } catch (PropertiesException pe) { pe.printStackTrace(); } // create Layer and train it GrowingSOM cpsom = new GrowingSOM(false, props, input); cpsom.train(input, props); // check if there are multiple items on one Unit reStructureMap(cpsom); CommonSOMViewerStateData state = new CommonSOMViewerStateData(parent.getSOMViewerState()); state.inputDataObjects = new SharedSOMVisualisationData(); state.inputDataObjects.setData(SOMVisualisationData.TEMPLATE_VECTOR, tv); return new ComponentPlaneClustering(cpsom, makeComponentPNode(createComponentPlanesVisualizer(state), cpsom)); } public PNode makeComponentPNode(ComponentPlanesVisualizer visualizer, GrowingSOM cpsom) throws SOMToolboxException { final GrowingLayer layer = cpsom.getLayer(); PNode componentImages = createPNode(layer.getXSize(), layer.getYSize()); // make a map grid for (int x = 0; x < layer.getXSize(); x++) { for (int y = 0; y < layer.getYSize(); y++) { PPath rect = PPath.createRectangle((float) x * (uWidth + padding), (float) y * (uHeight + padding), (uWidth + padding), (uHeight + padding)); componentImages.addChild(rect); } } // draw all component images for (int i = 0; i < labels.length; i++) { Unit u = layer.getUnitForDatum(labels[i]); createComponentImage(visualizer, componentImages, i, u.getXPos(), u.getYPos(), false); } return componentImages; } private PNode createUnclusteredPane(SOMViewer parent, SOMLibTemplateVector tv, GrowingLayer layer, boolean showComponentNames) throws SOMToolboxException { int neededXSize = (int) Math.ceil(Math.sqrt(dim)); int neededYSize = (int) Math.floor(Math.sqrt(dim)); PNode componentImages = createPNode(neededXSize, neededYSize); // draw all component images for (int i = 0; i < labels.length; i++) { int xPos = i % neededXSize; int yPos = i / neededXSize; createComponentImage(createComponentPlanesVisualizer(state), componentImages, i, xPos, yPos, showComponentNames); } return componentImages; } private ComponentPlanesVisualizer createComponentPlanesVisualizer(CommonSOMViewerStateData state) { ComponentPlanesVisualizer vis = new ComponentPlanesVisualizer(); vis.setInputObjects(state.inputDataObjects); vis.setPalette(Palettes.getPaletteByName("RGB256")); return vis; } private void createComponentImage(ComponentPlanesVisualizer visualizer, PNode componentImages, int componentIndex, int xPos, int yPos, boolean showComponentNames) throws SOMToolboxException { BufferedImage bimg = visualizer.createVisualization(0, componentIndex, orginalSom, orginalSom.getLayer().getXSize() * 10, orginalSom.getLayer().getYSize() * 10); int textHeight = 15; if (showComponentNames) { // also display component names? PText componentName = new PText(labels[componentIndex]); double width2 = componentName.getWidth(); componentName.setHeight(textHeight); componentImages.addChild(componentName); componentName.moveToFront(); componentName.translate((uWidth - width2) / 2 + (uWidth + padding) * xPos + padding / 2, (uHeight + padding + textHeight) * yPos + padding / 2 - 0.2 * textHeight); } PImage img = new PImage(bimg); img.addAttribute("tooltip", "Component #" + componentIndex + ", '" + labels[componentIndex] + "'"); img.setWidth(uWidth); img.setHeight(uHeight); componentImages.addChild(img); img.moveToFront(); img.translate((uWidth + padding) * xPos + padding / 2, (uHeight + padding) * yPos + (showComponentNames ? textHeight * (yPos + 1) : 0) + padding / 2); } private PNode createPNode(int xSize, int ySize) { PNode componentImages = new PNode(); componentImages.setWidth(xSize * (uWidth + padding) + 2); componentImages.setHeight(ySize * (uHeight + padding) + 2); return componentImages; } @Override public void actionPerformed(ActionEvent e) { update(); } @Override public void stateChanged(ChangeEvent e) { update(); // set the minimum spinner value to prevent a map size smaller than the number of dimensions spinnerNumberModelXSize.setMinimum((int) Math.ceil(dim / spinnerNumberModelYSize.getNumber().doubleValue())); spinnerNumberModelYSize.setMinimum((int) Math.ceil(dim / spinnerNumberModelXSize.getNumber().doubleValue())); } private void update() { if (buttons.getSelection().getActionCommand() == DISPLAY) { if (checkboxShowComponentNames.isSelected()) { pane.setPNode(unclusteredComponentPNodeWithNames); } else { pane.setPNode(unclusteredComponentPNodeWithOutNames); } } else { String key = spinnerNumberModelXSize.getNumber() + "x" + spinnerNumberModelYSize.getNumber(); if (clusteredMapCache.get(key) == null) { try { clusteredMapCache.put(key, createClusteredPane(somViewer, tv, orginalSom.getLayer())); pane.setPNode(clusteredMapCache.get(key).vis); } catch (SOMToolboxException e1) { e1.printStackTrace(); Logger.getLogger("at.tuwien.ifs.somtoolbox").severe( "Error creating component plane clustering: " + e1.getMessage()); } } else { pane.setPNode(clusteredMapCache.get(key).vis); } } boolean enableSpinner = buttons.getSelection().getActionCommand() != DISPLAY; spinnerXSize.setEnabled(enableSpinner); spinnerYSize.setEnabled(enableSpinner); } private DoubleMatrix2D getCov(GrowingLayer layer) { // serialise CPs in new matrix (rows = number of units, columns = dimension) DenseDoubleMatrix2D matrix = new DenseDoubleMatrix2D(layer.getXSize() * layer.getYSize(), layer.getDim()); for (int i = 0; i < layer.getDim(); i++) { double[][] cp = layer.getComponentPlane(i); for (int n = 0; n < layer.getXSize(); n++) { for (int m = 0; m < layer.getYSize(); m++) { matrix.setQuick(n * layer.getYSize() + m, i, cp[n][m]); } } } // compute covariance matrix DoubleMatrix2D covariance = Statistic.covariance(matrix); // normalise covariance matrix, to have diagonals of 1 DoubleMatrix2D diagonal = new DenseDoubleMatrix2D(covariance.columns(), 1); for (int i = 0; i < covariance.columns(); i++) { diagonal.setQuick(i, 0, covariance.getQuick(i, i)); } Algebra algebra = new Algebra(); DoubleMatrix2D mult = algebra.mult(diagonal, algebra.transpose(diagonal)); mult.assign(Functions.sqrt); covariance.assign(mult, Functions.div); return covariance; } private void reStructureMap(GrowingSOM cpsom) { int doubleUnits = 0; int iter = 0; InputData d = cpsom.getLayer().getData(); do { Unit[] units = cpsom.getLayer().getAllUnits(); doubleUnits = 0; for (Unit unit : units) { String[] lab = unit.getMappedInputNames(); if (lab != null) { if (lab.length > 1) { doubleUnits++; for (int j = 1; j < lab.length; j++) { InputDatum da = d.getInputDatum(lab[j]); Unit[] winners = cpsom.getLayer().getWinners(da, 2 + iter); winners[1 + iter].addMappedInput(da, true); unit.removeMappedInput(lab[j]); Logger.getLogger("at.tuwien.ifs.somtoolbox").info( "Moving " + lab[j] + " from " + unit + " to " + winners[1 + iter]); } } } } iter++; } while (doubleUnits > 0); } private class ComponentPlaneClustering { private GrowingSOM som; private PNode vis; public ComponentPlaneClustering(GrowingSOM som, PNode vis) { this.som = som; this.vis = vis; } } }