/* * RapidMiner * * Copyright (C) 2001-2008 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.gui.plotter; import java.awt.BasicStroke; import java.awt.Color; import java.awt.Graphics; import java.awt.Graphics2D; import java.awt.Stroke; import java.awt.geom.Rectangle2D; //import javax.swing.JComponent; //import javax.swing.JLabel; //import javax.swing.JSlider; //import javax.swing.event.ChangeEvent; //import javax.swing.event.ChangeListener; import org.jfree.chart.ChartFactory; import org.jfree.chart.JFreeChart; import org.jfree.chart.axis.NumberAxis; import org.jfree.chart.block.BlockBorder; import org.jfree.chart.plot.CategoryPlot; import org.jfree.chart.plot.PlotOrientation; import org.jfree.chart.plot.XYPlot; import org.jfree.chart.renderer.category.BarRenderer; import org.jfree.chart.renderer.xy.DeviationRenderer; import org.jfree.chart.title.LegendTitle; import org.jfree.data.category.CategoryDataset; import org.jfree.data.category.DefaultCategoryDataset; import org.jfree.data.xy.XYDataset; import org.jfree.data.xy.XYSeries; import org.jfree.data.xy.XYSeriesCollection; import org.jfree.ui.HorizontalAlignment; import org.jfree.ui.RectangleEdge; import org.jfree.ui.RectangleInsets; import com.rapidminer.datatable.DataTable; import 
com.rapidminer.datatable.DataTableExampleSetAdapter;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.operator.MissingIOObjectException;
import com.rapidminer.operator.OperatorCreationException;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.bayes.DistributionModel;
import com.rapidminer.operator.learner.bayes.NaiveBayes;
import com.rapidminer.tools.LogService;
import com.rapidminer.tools.OperatorService;
import com.rapidminer.tools.math.distribution.ContinuousDistribution;
import com.rapidminer.tools.math.distribution.DiscreteDistribution;

/**
 * This plotter can be used in order to plot a distribution model
 * like the one which can be delivered by NaiveBayes.
 *
 * For the selected plot column, one series per class is drawn: a density
 * line chart if the model reports the column as numerical, or a bar chart of
 * value probabilities if it reports the column as discrete. The model is either
 * passed in directly or learned with {@link NaiveBayes} from a {@link DataTable},
 * using the selected group column as the (nominal) label.
 *
 * @author Sebastian Land, Ingo Mierswa, Tobias Malbrecht
 * @version $Id: DistributionPlotter.java,v 1.9 2008/08/07 09:01:18 tobiasmalbrecht Exp $
 */
public class DistributionPlotter extends PlotterAdapter {

	private static final long serialVersionUID = 2923008541302883925L;

	/** Number of sample points used to draw each numerical density curve. */
	private static final int NUMBER_OF_STEPS = 300;

	/** Whether a chart should be drawn on the next paint. */
	private boolean plot = true;

	/** Index of the column whose distribution is plotted; -1 while nothing is selected. */
	private int plotColumn = -1;

	/** Index of the data table column used as class label when learning from a data table; -1 while unset. */
	private int groupColumn = -1;

	/** The plotted model; either supplied by the constructor or learned in {@link #preparePlots()}. May be null. */
	private transient DistributionModel model;

	/** Source data used to learn the model when {@link #createFromDataTable} is set. */
	private transient DataTable dataTable;

	/** True if the model has to be learned from {@link #dataTable} before painting. */
	private boolean createFromDataTable = false;

	// private JSlider kernelSlider = new JSlider(1,200,1);

	public DistributionPlotter() {
		// kernelSlider.addChangeListener(new ChangeListener() {
		// 	public void stateChanged(ChangeEvent e) {
		// 		if (!kernelSlider.getValueIsAdjusting()) {
		// 			repaint();
		// 		}
		// 	}
		// });
	}

	/**
	 * Creates a plotter for an already built distribution model; the first
	 * column is selected for plotting.
	 */
	public DistributionPlotter(DistributionModel model) {
		this();
		this.model = model;
		createFromDataTable = false;
		this.plotColumn = 0;
	}

	/** Creates a plotter whose model is learned from the given data table. */
	public DistributionPlotter(DataTable dataTable) {
		this();
		setDataTable(dataTable);
	}

	/** Stores the data table and switches to "learn the model from the data table" mode. */
	@Override
	public void setDataTable(DataTable dataTable) {
		super.setDataTable(dataTable);
		this.dataTable = dataTable;
		this.createFromDataTable = true;
	}

	/**
	 * If this plotter works on a data table, learns a NaiveBayes distribution
	 * model using the group column as label and enables plotting only if the
	 * label is nominal and learning succeeded. Does nothing when the model was
	 * supplied directly.
	 *
	 * NOTE(review): this is invoked from every paint, so the model is re-learned
	 * on each repaint; consider caching it per (dataTable, groupColumn).
	 */
	public void preparePlots() {
		if (!createFromDataTable) {
			return;
		}
		plot = false;
		if ((groupColumn < 0) || (plotColumn < 0) || (groupColumn == plotColumn)) {
			return;
		}
		ExampleSet wrappedExampleSet = DataTableExampleSetAdapter.createExampleSetFromDataTable(this.dataTable);
		Attribute[] attributes = Tools.createRegularAttributeArray(wrappedExampleSet);
		if (groupColumn >= attributes.length) {
			// selection refers to a column that does not exist (anymore)
			return;
		}
		Attribute label = attributes[groupColumn];
		if (!label.isNominal()) {
			return;
		}
		wrappedExampleSet.getAttributes().setLabel(label);
		// Forget any model from a previous configuration so a failed learning run
		// cannot leave a stale (wrong) model to be plotted.
		this.model = null;
		// NOTE(review): after setLabel the label presumably no longer counts as a regular
		// attribute; confirm that plotColumn (a data table index) still matches the
		// model's attribute index when plotColumn > groupColumn.
		try {
			NaiveBayes modelLearner = (NaiveBayes) OperatorService.createOperator(NaiveBayes.class);
			// modelLearner.setParameter(NaiveBayes.PARAMETER_USE_KERNEL, "true");
			// modelLearner.setParameter(NaiveBayes.PARAMETER_NUMBER_OF_KERNELS, kernelSlider.getValue() + "");
			this.model = (DistributionModel) modelLearner.learn(wrappedExampleSet);
			plot = true;
		} catch (OperatorCreationException e) {
			LogService.getGlobal().logWarning("Cannot create distribution model generator. Skip plot...");
		} catch (MissingIOObjectException e) {
			LogService.getGlobal().logWarning("No distribution model was created from data. Skip plot...");
		} catch (OperatorException e) {
			LogService.getGlobal().logWarning("Error during creation of distribution model. Skip plot...");
		}
	}

	@Override
	public void paintComponent(Graphics graphics) {
		super.paintComponent(graphics);
		paintComponent(graphics, getWidth(), getHeight());
	}

	/**
	 * Prepares the model (if needed) and draws the chart for the current plot
	 * column into the rectangle (0, 0, width, height). Draws nothing if plotting
	 * is disabled, no model is available, or the chart cannot be built.
	 */
	public void paintComponent(Graphics graphics, int width, int height) {
		preparePlots();
		if (!plot || (model == null)) {
			return;
		}
		JFreeChart chart = null;
		try {
			if (model.isDiscrete(plotColumn)) {
				chart = createNominalChart();
			} else {
				chart = createNumericalChart();
			}
		} catch (Exception e) {
			// Best effort: a column that cannot be charted simply yields no chart.
			// Not logged because this runs on every repaint.
			chart = null;
		}
		if (chart == null) {
			return;
		}
		// set the background color for the chart...
		chart.setBackgroundPaint(Color.white);

		// legend settings
		LegendTitle legend = chart.getLegend();
		if (legend != null) {
			legend.setPosition(RectangleEdge.TOP);
			legend.setFrame(BlockBorder.NONE);
			legend.setHorizontalAlignment(HorizontalAlignment.LEFT);
		}

		Rectangle2D drawRect = new Rectangle2D.Double(0, 0, width, height);
		chart.draw((Graphics2D) graphics, drawRect);
	}

	/**
	 * Samples each class's continuous distribution of the plot column at up to
	 * {@link #NUMBER_OF_STEPS} equidistant points between the model's lower and
	 * upper bound. NaN probabilities are skipped.
	 */
	private XYDataset createNumericalDataSet() {
		XYSeriesCollection dataSet = new XYSeriesCollection();
		double start = model.getLowerBound(plotColumn);
		double end = model.getUpperBound(plotColumn);
		double stepSize = (end - start) / (NUMBER_OF_STEPS - 1);

		// Use an integer step counter instead of accumulating stepSize: the
		// accumulation never terminated for a zero-width range (stepSize == 0) and
		// floating-point drift could drop the final sample.
		int numberOfSamples;
		if (end > start) {
			numberOfSamples = NUMBER_OF_STEPS;
		} else if (end == start) {
			numberOfSamples = 1;
		} else {
			// reversed or NaN bounds: nothing sensible to sample
			numberOfSamples = 0;
		}

		for (int classIndex : model.getClassIndices()) {
			XYSeries series = new XYSeries(model.getClassName(classIndex));
			ContinuousDistribution distribution = (ContinuousDistribution) model.getDistribution(classIndex, plotColumn);
			for (int step = 0; step < numberOfSamples; step++) {
				double currentValue = (step == NUMBER_OF_STEPS - 1) ? end : start + step * stepSize;
				double probability = distribution.getProbability(currentValue);
				if (!Double.isNaN(probability)) {
					series.add(currentValue, probability);
				}
			}
			dataSet.addSeries(series);
		}
		return dataSet;
	}

	/** Builds a line chart of the per-class densities with translucent, round-stroked series. */
	private JFreeChart createNumericalChart() {
		JFreeChart chart;
		XYDataset dataset = createNumericalDataSet();
		// create the chart...
		chart = ChartFactory.createXYLineChart(null, // chart title
				"value", // x axis label
				"density", // y axis label
				dataset, // data
				PlotOrientation.VERTICAL, true, // include legend
				true, // tooltips
				false // urls
				);

		chart.setBackgroundPaint(Color.white);

		// get a reference to the plot for further customisation...
		XYPlot plot = (XYPlot) chart.getPlot();
		plot.setBackgroundPaint(Color.WHITE);
		plot.setAxisOffset(new RectangleInsets(5.0, 5.0, 5.0, 5.0));
		plot.setDomainGridlinePaint(Color.LIGHT_GRAY);
		plot.setRangeGridlinePaint(Color.LIGHT_GRAY);

		DeviationRenderer renderer = new DeviationRenderer(true, false);
		Stroke stroke = new BasicStroke(2.0f, BasicStroke.CAP_ROUND, BasicStroke.JOIN_ROUND);
		if (dataset.getSeriesCount() == 1) {
			renderer.setSeriesStroke(0, stroke);
			renderer.setSeriesPaint(0, Color.RED);
			renderer.setSeriesFillPaint(0, Color.RED);
		} else {
			for (int i = 0; i < dataset.getSeriesCount(); i++) {
				renderer.setSeriesStroke(i, stroke);
				// spread the series evenly over the plotter's color scale
				Color color = getPointColor((double) i / (double) (dataset.getSeriesCount() - 1));
				renderer.setSeriesPaint(i, color);
				renderer.setSeriesFillPaint(i, color);
			}
		}
		renderer.setAlpha(0.12f);
		plot.setRenderer(renderer);

		NumberAxis xAxis = (NumberAxis) plot.getDomainAxis();
		xAxis.setStandardTickUnits(NumberAxis.createIntegerTickUnits());

		return chart;
	}

	/** Builds a bar chart of the per-class value probabilities of a discrete column. */
	private JFreeChart createNominalChart() {
		JFreeChart chart;
		CategoryDataset dataset = createNominalDataSet();
		// create the chart...
		chart = ChartFactory.createBarChart(null, // chart title
				"value", // x axis label
				"density", // y axis label
				dataset, // data
				PlotOrientation.VERTICAL, true, // include legend
				true, // tooltips
				false // urls
				);
		CategoryPlot plot = (CategoryPlot) chart.getPlot();

		BarRenderer renderer = new BarRenderer();
		if (dataset.getRowCount() == 1) {
			renderer.setSeriesPaint(0, Color.RED);
			renderer.setSeriesFillPaint(0, Color.RED);
		} else {
			for (int i = 0; i < dataset.getRowCount(); i++) {
				Color color = getPointColor((double) i / (double) (dataset.getRowCount() - 1));
				renderer.setSeriesPaint(i, color);
				renderer.setSeriesFillPaint(i, color);
			}
		}
		plot.setRenderer(renderer);

		return chart;
	}

	/**
	 * Collects, for every class, the probability of each value of the discrete
	 * plot column. Rows are class names, columns are value names; a NaN value is
	 * shown as "unknown".
	 */
	private CategoryDataset createNominalDataSet() {
		DefaultCategoryDataset dataset = new DefaultCategoryDataset();
		for (int classIndex : model.getClassIndices()) {
			DiscreteDistribution distribution = (DiscreteDistribution) model.getDistribution(classIndex, plotColumn);
			String labelName = model.getClassName(classIndex);
			for (Double value : distribution.getValues()) {
				String valueName;
				if (Double.isNaN(value)) {
					valueName = "unknown";
				} else {
					valueName = distribution.mapValue(value);
				}
				dataset.addValue(distribution.getProbability(value), labelName, valueName);
			}
		}
		return dataset;
	}

	/** Selects the column to plot (or deselects it if {@code plot} is false) and repaints. */
	public void setPlotColumn(int column, boolean plot) {
		this.plotColumn = column;
		this.plot = plot;
		repaint();
	}

	/** Returns true if the given column is the one currently plotted. */
	public boolean getPlotColumn(int column) {
		return (column == this.plotColumn && plot);
	}

	public String getPlotName() {
		return "Plot Column:";
	}

	public int getNumberOfAxes() {
		return 1;
	}

	/** Sets the group (class) column; repaints only if it actually changed. */
	public void setAxis(int index, int dimension) {
		if (groupColumn != dimension) {
			groupColumn = dimension;
			repaint();
		}
	}

	public int getAxis(int index) {
		return groupColumn;
	}

	public String getAxisName(int axis) {
		return "Class Column:";
	}

	// public JComponent getOptionsComponent(int index) {
	// 	switch (index) {
	// //	case 0:
	// //		JLabel label = new JLabel("Number of Kernels:");
	// //		label.setToolTipText("Select the number of kernels used for the estimation of the distribution of numerical attributes.");
	// //		return label;
	// //	case 1:
	// //		return kernelSlider;
	// 	default:
	// 		return null;
	// 	}
	// }
}