/* * RapidMiner * * Copyright (C) 2001-2008 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.gui.plotter.som; import java.awt.Color; import java.awt.Graphics; import java.awt.Graphics2D; import java.awt.GridLayout; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.ItemEvent; import java.awt.event.ItemListener; import java.awt.geom.Rectangle2D; import java.awt.image.BufferedImage; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.Vector; import javax.swing.JButton; import javax.swing.JComboBox; import javax.swing.JComponent; import javax.swing.JLabel; import javax.swing.JPanel; import javax.swing.JProgressBar; import javax.swing.JTextField; import com.rapidminer.datatable.DataTable; import com.rapidminer.datatable.DataTableRow; import com.rapidminer.gui.plotter.ExamplePlotterPoint; import com.rapidminer.gui.plotter.PlotterAdapter; import com.rapidminer.gui.plotter.conditions.BasicPlotterCondition; import com.rapidminer.gui.plotter.conditions.PlotterCondition; import com.rapidminer.gui.tools.SwingTools; import com.rapidminer.tools.RandomGenerator; import com.rapidminer.tools.math.MathFunctions; import com.rapidminer.tools.math.som.KohonenNet; import com.rapidminer.tools.math.som.ProgressListener; import com.rapidminer.tools.math.som.RandomDataContainer; import com.rapidminer.tools.math.som.RitterAdaptation; /** * This is the main class for the SOMPlotter. It uses the KohonenNet class for generating a self-organizing map. * Different properties of the resulting map may be shown as background while the examples are shown as points. There * are different styled visualizations of the properties. * * @author Sebastian Land, Ingo Mierswa * @version $Id: SOMPlotter.java,v 1.12 2008/07/18 15:50:45 ingomierswa Exp $ */ public class SOMPlotter extends PlotterAdapter implements ProgressListener { private static final long serialVersionUID = -1936359032703929998L; private static final String[] MATRIX_TYPES = new String[] { "U-Matrix", "P-Matrix", "U*-Matrix" }; public static final int MATRIX_U = 0; public static final int MATRIX_P = 1; public static final int MATRIX_U_STAR = 2; protected static final int IMAGE_WIDTH = 400; protected static final int IMAGE_HEIGHT = 300; protected int[] dimensions = { 40, 30 }; private ArrayList<ExamplePlotterPoint> exampleCoordinates = new ArrayList<ExamplePlotterPoint>(); private boolean examplesApplied = false; private double[][] uMatrix; private double maxU; private double[][] pMatrix; private double maxP; private double[][] uStarMatrix; private double maxUStar; protected transient DataTable dataTable; protected boolean show = false; private String currentToolTip = null; private double toolTipX = 0.0d; private double toolTipY = 0.0d; private int showMatrix = 0; private int showColor = 0; protected int colorColumn = -1; private transient RandomDataContainer data = new RandomDataContainer(); protected transient KohonenNet net; private JButton approveButton = new JButton("Calculate"); private JComboBox matrixSelection = new JComboBox(MATRIX_TYPES); private JComboBox colorSelection = new JComboBox(new String[] { "Landscape", "GrayScale", "Fire and Ice" }); private JTextField roundSelection = new JTextField("25"); private JTextField radiusSelection = new JTextField("15"); private JTextField dimensionX = new JTextField("40"); private JTextField dimensionY = new JTextField("30"); private JProgressBar progressBar = new JProgressBar(); private boolean coloredPoints = true; private transient SOMMatrixColorizer[] colorizer = new SOMMatrixColorizer[] { new SOMLandscapeColorizer(), new SOMGreyColorizer(), new SOMFireColorizer() }; private int jitterAmount = 0; protected transient BufferedImage image = null; public SOMPlotter() { setBackground(Color.WHITE); approveButton.setToolTipText("Start the calculation of the SOM (may take a while)."); approveButton.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { startCalculation(true); } }); matrixSelection.setToolTipText("Select the matrix type which should be visualized."); matrixSelection.addItemListener(new ItemListener() { public void itemStateChanged(ItemEvent event) { showMatrix = matrixSelection.getSelectedIndex(); if (showMatrix < 0) { showMatrix = 0; } recalculateBackgroundImage(); repaint(); } }); colorSelection.setToolTipText("Select the color scheme used for the visualization of the matrix values."); colorSelection.addItemListener(new ItemListener() { public void itemStateChanged(ItemEvent event) { showColor = colorSelection.getSelectedIndex(); if (showColor < 0) { showColor = 0; } recalculateBackgroundImage(); repaint(); } }); progressBar.setToolTipText("Shows the progress of the SOM calculation."); progressBar.setMinimum(0); progressBar.setMaximum(100); this.setDoubleBuffered(true); } public void forcePlotGeneration() { startCalculation(false); } protected Object readResolve() { this.data = new RandomDataContainer(); this.colorizer = new SOMMatrixColorizer[] { new SOMLandscapeColorizer(), new SOMGreyColorizer(), new SOMFireColorizer() }; return this; } public void setColoredPoints(boolean coloredPoints) { this.coloredPoints = coloredPoints; } public void setMatrixType(int matrixType) { this.showMatrix = matrixType; } public void startCalculation(boolean threadMode) { show = false; try { dimensions[0] = Integer.parseInt(dimensionX.getText()); } catch (NumberFormatException ex) { SwingTools.showVerySimpleErrorMessage("Only numbers are allowed for SOM width."); return; } try { dimensions[1] = Integer.parseInt(dimensionY.getText()); } catch (NumberFormatException ex) { SwingTools.showVerySimpleErrorMessage("Only numbers are allowed for SOM height."); return; } int adaptationRadius = 15; try { adaptationRadius = Integer.parseInt(radiusSelection.getText()); } catch (NumberFormatException ex) { SwingTools.showVerySimpleErrorMessage("Only numbers are allowed for radius."); return; } int trainRounds = 25; try { trainRounds = Integer.parseInt(roundSelection.getText()); } catch (NumberFormatException ex) { SwingTools.showVerySimpleErrorMessage("Only numbers are allowed for number of training rounds."); return; } prepareSOM(dataTable, adaptationRadius, trainRounds, threadMode); } public PlotterCondition getPlotterCondition() { return new BasicPlotterCondition(); } public void paintComponent(Graphics graphics) { super.paintComponent(graphics); // painting only if approved if (show) { paintSom(graphics); } } public void paintSom(Graphics graphics) { // init graphics Graphics2D g = (Graphics2D) graphics; int pixWidth = getWidth() - 2 * MARGIN; int pixHeight = getHeight() - 2 * MARGIN; // painting background g.drawImage(this.image, MARGIN, MARGIN, pixWidth, pixHeight, Color.WHITE, null); // painting points drawPoints(g); // painting Legend drawLegend(graphics, this.dataTable, colorColumn); // paint Tooltip drawToolTip((Graphics2D) graphics); } protected void drawPoints(Graphics2D g) { int pixWidth = getWidth() - 2 * MARGIN; int pixHeight = getHeight() - 2 * MARGIN; // painting data points if (colorColumn >= 0) { Iterator iterator = dataTable.iterator(); // Finding Scalevalue for Color Column double minColorValue = Double.POSITIVE_INFINITY; double maxColorValue = Double.NEGATIVE_INFINITY; while (iterator.hasNext()) { double value = ((DataTableRow) iterator.next()).getValue(colorColumn); minColorValue = MathFunctions.robustMin(minColorValue, value); maxColorValue = MathFunctions.robustMax(maxColorValue, value); } // remember point positions if (!examplesApplied) { int[][] exampleCount = new int[dimensions[0]][dimensions[1]]; exampleCoordinates.clear(); iterator = dataTable.iterator(); int index = 0; while (iterator.hasNext()) { DataTableRow row = (DataTableRow) iterator.next(); // deleting special attributes of a row (ID, Class) double[] example = getDoubleArrayFromRow(row, dataTable); int[] coords = net.apply(example); exampleCoordinates.add(new ExamplePlotterPoint(index, coords[0], coords[1])); exampleCount[coords[0]][coords[1]]++; index++; } examplesApplied = true; // painting if already applied } // draw points double fieldWidth = pixWidth / (double)dimensions[0]; double fieldHeight = pixHeight / (double)dimensions[1]; Iterator<ExamplePlotterPoint> exampleIterator = exampleCoordinates.iterator(); RandomGenerator random = RandomGenerator.getRandomGenerator(2001); while (exampleIterator.hasNext()) { ExamplePlotterPoint point = exampleIterator.next(); double color = 1.0d; Color borderColor = Color.BLACK; if (coloredPoints) { color = getPointColorValue(this.dataTable, dataTable.getRow(point.getDataTableIndex()), colorColumn, minColorValue, maxColorValue); borderColor = getPointBorderColor(this.dataTable, dataTable.getRow(point.getDataTableIndex()), colorColumn); } double pertX = 0.0d; double pertY = 0.0d; if (jitterAmount > 0) { pertX = random.nextDoubleInRange(-fieldWidth / 2.0d, fieldWidth / 2.0d) * (jitterAmount / 50.0d); pertY = random.nextDoubleInRange(-fieldHeight / 2.0d, fieldHeight / 2.0d) * (jitterAmount / 50.0d); } point.setCurrentPertubatedX((int)(MARGIN + pertX + point.getX() * fieldWidth + fieldWidth / 2.0d)); point.setCurrentPertubatedY((int)(MARGIN + pertY + point.getY() * fieldHeight + fieldHeight / 2.0d)); drawPoint(g, point.getCurrentPertubatedX(), point.getCurrentPertubatedY(), getPointColor(color), borderColor); } } } public void setDataTable(DataTable dataTable) { super.setDataTable(dataTable); this.dataTable = dataTable; } /** Returns true. */ public boolean canHandleJitter() { return true; } /** Sets the level of jitter and initiates a repaint. */ public void setJitter(int jitter) { this.jitterAmount = jitter; repaint(); } public void prepareSOM(DataTable dataTable, double adaptationRadius, int trainRounds, boolean threadMode) { // reseting Data already applied flag examplesApplied = false; // generating data for SOM int dataDimension = 0; synchronized (dataTable) { Iterator iterator = dataTable.iterator(); dataDimension = dataTable.getNumberOfColumns() - dataTable.getNumberOfSpecialColumns(); iterator = dataTable.iterator(); while (iterator.hasNext()) { data.addData(getDoubleArrayFromRow((DataTableRow) iterator.next(), dataTable)); } } // generating SOM net = new KohonenNet(data); RitterAdaptation adaptationFunction = new RitterAdaptation(); adaptationFunction.setAdaptationRadiusStart(adaptationRadius); adaptationFunction.setLearnRateStart(0.8); net.setAdaptationFunction(adaptationFunction); net.init(dataDimension, dimensions, false); net.setTrainingRounds(trainRounds); // train SOM if (threadMode) { // registering this as ProgressListener net.addProgressListener(this); Thread trainThread = new Thread() { public void run() { net.train(); } }; trainThread.start(); } else { net.train(); createMatrices(); try { // necessary for preventing graphical errors in reporting Thread.sleep(1000); } catch (InterruptedException e) { // do nothing } } } protected void createMatrices() { uMatrix = getUMatrix(net, dimensions); pMatrix = getPMatrix(net, data, dimensions); // searching maximum distance between nodes maxU = 0; for (int i = 0; i < dimensions[0]; i++) { for (int j = 0; j < dimensions[1]; j++) { maxU = Math.max(maxU, uMatrix[i][j]); } } // searching maximum density and mean density maxP = 0; double meanP = 0; for (int i = 0; i < dimensions[0]; i++) { for (int j = 0; j < dimensions[1]; j++) { maxP = (maxP < pMatrix[i][j]) ? pMatrix[i][j] : maxP; meanP += pMatrix[i][j]; } } meanP /= dimensions[0] * dimensions[1]; // calculating u* Matrix uStarMatrix = getUStarMatrix(uMatrix, pMatrix, meanP, maxP, dimensions); // searching maximum of U* matrix maxUStar = 0; for (int i = 0; i < dimensions[0]; i++) { for (int j = 0; j < dimensions[1]; j++) { maxUStar = Math.max(maxUStar, uStarMatrix[i][j]); } } // create background image recalculateBackgroundImage(); this.show = true; } protected void recalculateBackgroundImage() { Vector<double[][]> printMatrix = new Vector<double[][]>(); double[] printScale = { maxU, maxP, maxUStar }; printMatrix.add(uMatrix); printMatrix.add(pMatrix); printMatrix.add(uStarMatrix); // painting Matrix int image_width = (IMAGE_WIDTH / dimensions[0]) * dimensions[0]; int image_height = (IMAGE_HEIGHT / dimensions[1]) * dimensions[1]; this.image = new BufferedImage(image_width, image_height, BufferedImage.TYPE_INT_RGB); int width = image_width / dimensions[0]; int height = image_height / dimensions[1]; for (int i = 0; i < dimensions[0]; i++) { for (int j = 0; j < dimensions[1]; j++) { interpolateRect(image, width * i, height * j, width, height, printMatrix.elementAt(showMatrix), i, j, printScale[showMatrix], colorizer[showColor]); } } } protected void interpolateRect(BufferedImage image, int posX, int posY, double width, double height, double[][] matrix, int matrixX, int matrixY, double colorScale, SOMMatrixColorizer colorizer) { // top-left if (matrix != null) { double p11 = matrix[matrixX][matrixY]; double p21 = matrix[(matrixX + 1) % this.dimensions[0]][matrixY]; double p12 = matrix[matrixX][(matrixY + 1) % this.dimensions[1]]; double p22 = matrix[(matrixX + 1) % this.dimensions[0]][(matrixY + 1) % this.dimensions[1]]; for (int i = 0; i < width; i+=1) { for (int j = 0; j < height; j+=1) { double interpolatedValue = (( p11 * (width - i) * (height - j)) + ( p21 * i * (height - j)) + ( p12 * (width - i) * j) + ( p22 * i * j)) / (height * width); double colorValue = interpolatedValue / colorScale; //colorValue = Math.min(1.0d, Math.max(0.0d, colorValue)); int rgbColor = colorizer.getPointColor(colorValue).getRGB(); image.setRGB(posX + i, posY + j, rgbColor); } } } } private double[] getDoubleArrayFromRow(DataTableRow row, DataTable table) { double[] doubleRow = new double[table.getNumberOfColumns() - table.getNumberOfSpecialColumns()]; int index = 0; for (int i = 0; i < row.getNumberOfValues(); i++) { if (!table.isSpecial(i)) { doubleRow[index] = row.getValue(i); index++; } } return doubleRow; } private double[][] getUMatrix(KohonenNet net, int[] dimensions) { double[][] uMatrix = new double[dimensions[0]][dimensions[1]]; // getting distances between nodes for (int i = 0; i < dimensions[0]; i++) { for (int j = 0; j < dimensions[1]; j++) { uMatrix[i][j] = net.getNodeDistance(net.getIndexOfCoordinates(new int[] { i, j })); } } return uMatrix; } private double[][] getPMatrix(KohonenNet net, RandomDataContainer data, int[] dimensions) { // calculating real paretoradius int n = data.countData(); double optimalMedian = 0.2013 * n; double estimatedRadius = 0; // calculating distances between every example double[] distances = new double[n * n]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { distances[i * n + j] = net.getDistance(data.get(i), data.get(j)); } } Arrays.sort(distances); double percentilSetDifference = Double.POSITIVE_INFINITY; // finding percentil, closest to paretoradius double radius; for (int percentil = 0; percentil < 100; percentil++) { int[] nn = new int[n]; radius = distances[(int) Math.round(((double) (percentil * n * n)) / 100)]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { if (net.getDistance(data.get(i), data.get(j)) <= radius) { nn[i]++; } } } Arrays.sort(nn); int currentMedian = nn[n / 2] - 1; //point himself is no real neighbour, but always nearest point if (Math.abs(currentMedian - optimalMedian) <= percentilSetDifference) { percentilSetDifference = Math.abs(currentMedian - optimalMedian); } else { estimatedRadius = radius; break; } } // generating P Matrix double[][] pMatrix = new double[dimensions[0]][dimensions[1]]; for (int i = 0; i < dimensions[0]; i++) { for (int j = 0; j < dimensions[1]; j++) { double nodeWeight[] = net.getNodeWeights(new int[] { i, j }); int neighbours = 0; for (int x = 0; x < n; x++) { if (net.getDistance(data.get(x), nodeWeight) < estimatedRadius) { neighbours++; } } pMatrix[i][j] = ((double) neighbours) / n; } } return pMatrix; } private double[][] getUStarMatrix(double[][] uMatrix, double[][] pMatrix, double meanP, double maxP, int[] dimensions) { double[][] uStarMatrix = new double[dimensions[0]][dimensions[1]]; for (int i = 0; i < dimensions[0]; i++) { for (int j = 0; j < dimensions[1]; j++) { uStarMatrix[i][j] = uMatrix[i][j] * (((pMatrix[i][j] - meanP) / (meanP - maxP)) + 1); } } return uStarMatrix; } public JComponent getOptionsComponent(int index) { switch (index) { case 0: JLabel label = new JLabel("Matrix"); label.setToolTipText("Select the matrix type which should be visualized."); return label; case 1: return matrixSelection; case 2: label = new JLabel("Style"); label.setToolTipText("Select the color scheme used for the visualization of the matrix values."); return label; case 3: return colorSelection; case 4: JPanel dimensionLabelPanel = new JPanel(); dimensionLabelPanel.setToolTipText("Set the dimensions of the Kohonen net."); dimensionLabelPanel.setLayout(new GridLayout()); dimensionLabelPanel.add(new JLabel("Net Width")); dimensionLabelPanel.add(new JLabel("Net Height")); return dimensionLabelPanel; case 5: JPanel dimensionPanel = new JPanel(); dimensionPanel.setLayout(new GridLayout()); dimensionX.setToolTipText("Set the dimensions of the Kohonen net."); dimensionY.setToolTipText("Set the dimensions of the Kohonen net."); dimensionPanel.add(dimensionX); dimensionPanel.add(dimensionY); return dimensionPanel; case 6: JPanel roundPanel = new JPanel(); roundPanel.setToolTipText("Set the number of training rounds of the Kohonen net."); roundSelection.setToolTipText("Set the number of training rounds of the Kohonen net."); roundPanel.setLayout(new GridLayout()); roundPanel.add(new JLabel("Training Rounds")); roundPanel.add(roundSelection); return roundPanel; case 7: JPanel radiusPanel = new JPanel(); radiusPanel.setToolTipText("Set the adaptation radius of the Kohonen net."); radiusSelection.setToolTipText("Set the adaptation radius of the Kohonen net."); radiusPanel.setLayout(new GridLayout()); radiusPanel.add(new JLabel("Adaptation Radius")); radiusPanel.add(radiusSelection); return radiusPanel; case 8: return progressBar; case 9: return approveButton; } return null; } public void setPlotColumn(int column, boolean plot) { if (plot) { colorColumn = column; repaint(); } } public boolean getPlotColumn(int dimension) { if (dimension == colorColumn) { return true; } return false; } public String getPlotName() { return "Point Color"; } public void setProgress(int value) { progressBar.setValue(value); } public void progressFinished() { createMatrices(); setProgress(100); repaint(); } public String getIdForPos(int x, int y) { if (this.show) { ExamplePlotterPoint point = getPlotterPointForPos(x,y); if (point != null) { return dataTable.getRow(point.getDataTableIndex()).getId(); } } return(null); } private ExamplePlotterPoint getPlotterPointForPos(int x, int y) { Iterator<ExamplePlotterPoint> exampleIterator = exampleCoordinates.iterator(); while (exampleIterator.hasNext()) { ExamplePlotterPoint point = exampleIterator.next(); if (point.contains(x,y)) { return point; } } return null; } /** Sets the mouse position in the shown data space. */ public void setMousePosInDataSpace(int x, int y) { if (show) { ExamplePlotterPoint point = getPlotterPointForPos(x, y); if (point != null) { String id = dataTable.getRow(point.getDataTableIndex()).getId(); if (id != null) { setToolTip(id, point.getCurrentPertubatedX(), point.getCurrentPertubatedY()); } else { setToolTip(null, 0.0d, 0.0d); } } else { setToolTip(null, 0.0d, 0.0d); } } } private void setToolTip(String toolTip, double x, double y) { this.currentToolTip = toolTip; this.toolTipX = x; this.toolTipY = y; repaint(); } protected void drawToolTip(Graphics2D g) { if (currentToolTip != null) { g.setFont(LABEL_FONT); Rectangle2D stringBounds = LABEL_FONT.getStringBounds(currentToolTip, g.getFontRenderContext()); g.setColor(TOOLTIP_COLOR); Rectangle2D bg = new Rectangle2D.Double((toolTipX) - stringBounds.getWidth() - 15, (toolTipY - (stringBounds.getHeight() / 2)), stringBounds.getWidth() + 6, Math.abs(stringBounds.getHeight()) + 4); g.fill(bg); g.setColor(Color.black); g.draw(bg); g.drawString(currentToolTip, (float) ((toolTipX ) - stringBounds.getWidth() - 12 ), (float) ((toolTipY + stringBounds.getHeight() * 0.5) + 1)); } } }