/* * RapidMiner * * Copyright (C) 2001-2011 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.gui.plotter.charts; import java.awt.BasicStroke; import java.awt.Color; import java.awt.Stroke; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Map.Entry; import java.util.TreeMap; import javax.swing.JComponent; import org.jfree.chart.ChartFactory; import org.jfree.chart.JFreeChart; import org.jfree.chart.axis.CategoryLabelPositions; import org.jfree.chart.axis.DateAxis; import org.jfree.chart.axis.NumberAxis; import org.jfree.chart.block.BlockBorder; import org.jfree.chart.plot.CategoryPlot; import org.jfree.chart.plot.Plot; import org.jfree.chart.plot.PlotOrientation; import org.jfree.chart.plot.XYPlot; import org.jfree.chart.renderer.category.BarRenderer; import org.jfree.chart.renderer.xy.DeviationRenderer; import org.jfree.chart.title.LegendTitle; import org.jfree.data.Range; import org.jfree.data.category.CategoryDataset; import org.jfree.data.category.DefaultCategoryDataset; import org.jfree.data.xy.XYDataset; import org.jfree.data.xy.XYSeries; import org.jfree.data.xy.XYSeriesCollection; import org.jfree.ui.HorizontalAlignment; import org.jfree.ui.RectangleEdge; import 
com.rapidminer.datatable.DataTable;
import com.rapidminer.datatable.DataTableExampleSetAdapter;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.gui.plotter.PlotterConfigurationModel;
import com.rapidminer.gui.plotter.RangeablePlotterAdapter;
import com.rapidminer.operator.MissingIOObjectException;
import com.rapidminer.operator.OperatorCreationException;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.bayes.DistributionModel;
import com.rapidminer.operator.learner.bayes.NaiveBayes;
import com.rapidminer.tools.LogService;
import com.rapidminer.tools.OperatorService;
import com.rapidminer.tools.math.distribution.ContinuousDistribution;
import com.rapidminer.tools.math.distribution.DiscreteDistribution;

/**
 * This plotter can be used in order to plot a distribution model like the one
 * which can be delivered by NaiveBayes.
 *
 * <p>It works in one of two modes:
 * <ul>
 * <li>Model mode: a {@link DistributionModel} is passed to the constructor and
 * plotted directly. In this mode the model attribute index is used as the plot
 * column.</li>
 * <li>Data table mode: the user selects a plot column and a nominal class
 * column. A {@link NaiveBayes} model is learned from the data table, and the
 * class-conditional distributions of the plot column are drawn.</li>
 * </ul>
 *
 * <p>Discrete attributes are drawn as a bar chart of probabilities. Continuous
 * attributes are drawn as sampled density curves, one per class.
 *
 * @author Sebastian Land, Ingo Mierswa, Tobias Malbrecht
 */
public class DistributionPlotter extends RangeablePlotterAdapter {

    /**
     * Label of the range (y) axis.
     *
     * <p>Also used as the dimension name when looking up a user-defined range.
     */
    public static final String RANGE_AXIS_NAME = "Density";

    /** Domain axis label used when no data table is available (model mode). */
    public static final String MODEL_DOMAIN_AXIS_NAME = "Value";

    private static final long serialVersionUID = 2923008541302883925L;

    /** Number of points at which each continuous distribution is sampled. */
    private static final int NUMBER_OF_STEPS = 300;

    /** Index of the column whose distribution is plotted, or -1 if none is selected. */
    private int plotColumn = -1;

    /** Index of the data table column used as class (grouping) column, or -1 if none is selected. */
    private int groupColumn = -1;

    /** The model being plotted: either supplied by the caller or learned in {@link #preparePlots()}. */
    private transient DistributionModel model;

    /** Maps data table column names to attribute indices of the learned model (data table mode only). */
    private transient HashMap<String, Integer> dataTableModelColumnMap = new HashMap<String, Integer>();

    /** The data table the plotter currently shows, or null in model mode. */
    private transient DataTable dataTable;

    /** True if the model was supplied by the caller instead of being learned from a data table. */
    private boolean createFromModel = false;

    /**
     * Creates an empty distribution plotter; a data table must be set later.
     *
     * @param settings the plotter configuration
     */
    public DistributionPlotter(PlotterConfigurationModel settings) {
        super(settings);
    }

    /**
     * Creates a plotter that directly shows the given distribution model.
     *
     * <p>The plot column is initialised to the model's first attribute.
     *
     * @param settings the plotter configuration
     * @param model the model whose distributions are plotted
     */
    public DistributionPlotter(PlotterConfigurationModel settings, DistributionModel model) {
        this(settings);
        this.model = model;
        this.createFromModel = true;
        this.plotColumn = 0;
        updatePlotter();
    }

    /**
     * Creates a plotter that learns its distribution model from the given data table.
     *
     * @param settings the plotter configuration
     * @param dataTable the data to plot
     */
    public DistributionPlotter(PlotterConfigurationModel settings, DataTable dataTable) {
        this(settings);
        setDataTable(dataTable);
    }

    /** Performs the inherited preparation and then rebuilds the chart. */
    @Override
    public void prepareRendering() {
        super.prepareRendering();
        updatePlotter();
    }

    /** Performs the inherited clean-up and then rebuilds the chart. */
    @Override
    public void finishRendering() {
        super.finishRendering();
        updatePlotter();
    }

    /**
     * Learns a NaiveBayes distribution model from the current data table, if
     * possible.
     *
     * <p>This method does nothing in model mode. In data table mode a model is
     * learned only when all of the following hold:
     * <ul>
     * <li>a plot column and a class column are selected;</li>
     * <li>the two columns differ;</li>
     * <li>the class column is nominal.</li>
     * </ul>
     *
     * <p>On success, the column-name to model-index map is rebuilt. If learning
     * fails, a warning is logged and the previously learned model (if any) is
     * kept.
     */
    public void preparePlots() {
        if (!createFromModel) {
            if ((groupColumn >= 0) && (plotColumn >= 0) && (groupColumn != plotColumn)) {
                ExampleSet wrappedExampleSet = DataTableExampleSetAdapter.createExampleSetFromDataTable(this.dataTable);
                Attribute[] attributes = Tools.createRegularAttributeArray(wrappedExampleSet);
                Attribute label = attributes[groupColumn];
                if (label.isNominal()) {
                    wrappedExampleSet.getAttributes().setLabel(label);
                    try {
                        NaiveBayes modelLearner = OperatorService.createOperator(NaiveBayes.class);
                        this.model = (DistributionModel) modelLearner.doWork(wrappedExampleSet);

                        // Model attribute indices follow the iteration order of the example set's
                        // attributes. Remember them by column name.
                        dataTableModelColumnMap.clear();
                        int modelColumn = 0;
                        for (Attribute attribute : wrappedExampleSet.getAttributes()) {
                            dataTableModelColumnMap.put(attribute.getName(), modelColumn);
                            modelColumn++;
                        }
                    } catch (OperatorCreationException e) {
                        LogService.getGlobal().logWarning("Cannot create distribution model generator. Skip plot...");
                    } catch (MissingIOObjectException e) {
                        LogService.getGlobal().logWarning("No distribution model was created from data. Skip plot...");
                    } catch (OperatorException e) {
                        LogService.getGlobal().logWarning("Error during creation of distribution model. Skip plot...");
                    }
                }
            }
        }
    }

    /**
     * Translates the plot column selected in the original data table to an
     * attribute index of the model.
     *
     * <p>The two indices can differ because, during model construction, the
     * label is moved from its original position to the end. All subsequent
     * attributes are therefore shifted one position to the left.
     *
     * @param plotColumn the data table column index (or, in model mode, the model index)
     * @return the corresponding model attribute index
     * @throws IllegalStateException if the column is not part of the current model,
     *         e.g. because no model has been learned for the current selection
     */
    private int translateToModelColumn(int plotColumn) {
        if (createFromModel) {
            return plotColumn;
        }
        String columnName = dataTable.getColumnName(plotColumn);
        Integer modelColumn = dataTableModelColumnMap.get(columnName);
        if (modelColumn == null) {
            throw new IllegalStateException("Column '" + columnName + "' is not part of the current distribution model.");
        }
        return modelColumn.intValue();
    }

    /**
     * Learns the model if needed, builds the matching chart, applies the common
     * styling and installs the chart in the plotter panel.
     *
     * <p>If no column is selected in data table mode, an empty placeholder bar
     * chart is shown. If the chart cannot be created, the panel is left
     * unchanged.
     */
    @Override
    public void updatePlotter() {
        preparePlots();

        JFreeChart chart = null;
        if (!createFromModel && ((groupColumn < 0) || (plotColumn < 0))) {
            CategoryDataset dataset = new DefaultCategoryDataset();
            chart = ChartFactory.createBarChart(null, // chart title
                    "Not defined", // x axis label
                    RANGE_AXIS_NAME, // y axis label
                    dataset, // data
                    PlotOrientation.VERTICAL, true, // include legend
                    true, // tooltips
                    false // urls
                    );
        } else {
            try {
                if (model.isDiscrete(translateToModelColumn(plotColumn))) {
                    chart = createNominalChart();
                } else {
                    chart = createNumericalChart();
                }
            } catch (Exception e) {
                // Deliberately best-effort. The model may be missing or may not
                // contain the selected column, for example when the plot column
                // equals the class column or the class column is not nominal.
                // In that case simply do not draw a new chart.
            }
        }

        if (chart != null) {
            chart.setBackgroundPaint(Color.white);

            // get a reference to the plot for further customization...
            Plot commonPlot = chart.getPlot();
            commonPlot.setBackgroundPaint(Color.WHITE);
            if (commonPlot instanceof XYPlot) {
                XYPlot plot = (XYPlot) commonPlot;

                plot.setDomainGridlinePaint(Color.LIGHT_GRAY);
                plot.setRangeGridlinePaint(Color.LIGHT_GRAY);

                // domain axis: date/time columns get a date axis, everything else a number axis
                if (dataTable != null) {
                    if ((dataTable.isDate(plotColumn)) || (dataTable.isDateTime(plotColumn))) {
                        DateAxis domainAxis = new DateAxis(dataTable.getColumnName(plotColumn));
                        domainAxis.setTimeZone(com.rapidminer.tools.Tools.getPreferredTimeZone());
                        plot.setDomainAxis(domainAxis);
                    } else {
                        NumberAxis numberAxis = new NumberAxis(dataTable.getColumnName(plotColumn));
                        plot.setDomainAxis(numberAxis);
                    }
                }

                plot.getDomainAxis().setLabelFont(LABEL_FONT_BOLD);
                plot.getDomainAxis().setTickLabelFont(LABEL_FONT);

                plot.getRangeAxis().setLabelFont(LABEL_FONT_BOLD);
                plot.getRangeAxis().setTickLabelFont(LABEL_FONT);

                // ranging: apply user-defined ranges for the plotted column and the density axis
                if (dataTable != null) {
                    Range range = getRangeForDimension(plotColumn);
                    if (range != null) {
                        plot.getDomainAxis().setRange(range, true, false);
                    }

                    range = getRangeForName(RANGE_AXIS_NAME);
                    if (range != null) {
                        plot.getRangeAxis().setRange(range, true, false);
                    }
                }

                // rotate labels
                if (isLabelRotating()) {
                    plot.getDomainAxis().setTickLabelsVisible(true);
                    plot.getDomainAxis().setVerticalTickLabels(true);
                }
            } else if (commonPlot instanceof CategoryPlot) {
                CategoryPlot plot = (CategoryPlot) commonPlot;

                plot.setDomainGridlinePaint(Color.LIGHT_GRAY);
                plot.setRangeGridlinePaint(Color.LIGHT_GRAY);

                plot.getRangeAxis().setLabelFont(LABEL_FONT_BOLD);
                plot.getRangeAxis().setTickLabelFont(LABEL_FONT);

                plot.getDomainAxis().setLabelFont(LABEL_FONT_BOLD);
                plot.getDomainAxis().setTickLabelFont(LABEL_FONT);
            }

            // legend settings
            LegendTitle legend = chart.getLegend();
            if (legend != null) {
                legend.setPosition(RectangleEdge.TOP);
                legend.setFrame(BlockBorder.NONE);
                legend.setHorizontalAlignment(HorizontalAlignment.LEFT);
                legend.setItemFont(LABEL_FONT);
            }

            // Chart Panel Settings: create the panel lazily, afterwards only swap the chart
            AbstractChartPanel panel = getPlotterPanel();
            if (panel == null) {
                panel = createPanel(chart);
            } else {
                panel.setChart(chart);
            }

            // ATTENTION: WITHOUT THIS WE GET SEVERE MEMORY LEAKS!!!
            panel.getChartRenderingInfo().setEntityCollection(null);
        }
    }

    /**
     * Samples the continuous distribution of the plot column for every class.
     *
     * <p>Each distribution is evaluated at {@link #NUMBER_OF_STEPS} evenly
     * spaced points between the model's lower and upper bound, both inclusive.
     * Sample positions are computed from an integer index instead of repeatedly
     * adding the step size, so rounding errors do not accumulate and the last
     * sample lies exactly on the upper bound. Points whose density is NaN are
     * skipped. If the bounds do not span a non-empty finite interval, the series
     * are left empty.
     *
     * @return one XY series per class, keyed by class name
     */
    private XYDataset createNumericalDataSet() {
        XYSeriesCollection dataSet = new XYSeriesCollection();
        int translatedPlotColumn = translateToModelColumn(plotColumn);
        double start = model.getLowerBound(translatedPlotColumn);
        double end = model.getUpperBound(translatedPlotColumn);
        double stepSize = (end - start) / (NUMBER_OF_STEPS - 1);
        // only a non-empty, finite interval can be sampled meaningfully
        boolean sampleable = end > start && !Double.isInfinite(stepSize) && !Double.isNaN(stepSize);

        for (int classIndex : model.getClassIndices()) {
            XYSeries series = new XYSeries(model.getClassName(classIndex));
            ContinuousDistribution distribution = (ContinuousDistribution) model.getDistribution(classIndex, translatedPlotColumn);
            if (sampleable) {
                for (int step = 0; step < NUMBER_OF_STEPS; step++) {
                    double currentValue = (step == NUMBER_OF_STEPS - 1) ? end : start + step * stepSize;
                    double probability = distribution.getProbability(currentValue);
                    if (!Double.isNaN(probability)) {
                        series.add(currentValue, probability);
                    }
                }
            }
            dataSet.addSeries(series);
        }

        return dataSet;
    }

    /**
     * Builds a line chart of the class-conditional densities of a continuous
     * plot column.
     *
     * <p>A single series is drawn in red. Multiple series are colored by the
     * plotter's color provider, spread evenly over its color range.
     *
     * @return the configured chart
     */
    private JFreeChart createNumericalChart() {
        JFreeChart chart;
        XYDataset dataset = createNumericalDataSet();
        // create the chart...
        String domainName = dataTable == null ? MODEL_DOMAIN_AXIS_NAME : dataTable.getColumnName(plotColumn);
        chart = ChartFactory.createXYLineChart(null, // chart title
                domainName, // x axis label
                RANGE_AXIS_NAME, // y axis label
                dataset, // data
                PlotOrientation.VERTICAL, true, // include legend
                true, // tooltips
                false // urls
                );

        DeviationRenderer renderer = new DeviationRenderer(true, false);
        Stroke stroke = new BasicStroke(2.0f, BasicStroke.CAP_ROUND, BasicStroke.JOIN_ROUND);
        if (dataset.getSeriesCount() == 1) {
            renderer.setSeriesStroke(0, stroke);
            renderer.setSeriesPaint(0, Color.RED);
            renderer.setSeriesFillPaint(0, Color.RED);
        } else {
            for (int i = 0; i < dataset.getSeriesCount(); i++) {
                renderer.setSeriesStroke(i, stroke);
                Color color = getColorProvider().getPointColor((double) i / (double) (dataset.getSeriesCount() - 1));
                renderer.setSeriesPaint(i, color);
                renderer.setSeriesFillPaint(i, color);
            }
        }
        renderer.setAlpha(0.12f);

        XYPlot plot = (XYPlot) chart.getPlot();
        plot.setRenderer(renderer);

        return chart;
    }

    /**
     * Builds a bar chart of the class-conditional value probabilities of a
     * discrete plot column.
     *
     * <p>Colors follow the same scheme as {@link #createNumericalChart()}. Tick
     * labels are rotated when label rotation is enabled.
     *
     * @return the configured chart
     */
    private JFreeChart createNominalChart() {
        JFreeChart chart;
        CategoryDataset dataset = createNominalDataSet();
        // create the chart...
        String domainName = dataTable == null ? MODEL_DOMAIN_AXIS_NAME : dataTable.getColumnName(plotColumn);
        chart = ChartFactory.createBarChart(null, // chart title
                domainName, // x axis label
                RANGE_AXIS_NAME, // y axis label
                dataset, // data
                PlotOrientation.VERTICAL, true, // include legend
                true, // tooltips
                false // urls
                );

        CategoryPlot plot = chart.getCategoryPlot();
        BarRenderer renderer = new BarRenderer();
        if (dataset.getRowCount() == 1) {
            renderer.setSeriesPaint(0, Color.RED);
            renderer.setSeriesFillPaint(0, Color.RED);
        } else {
            for (int i = 0; i < dataset.getRowCount(); i++) {
                Color color = getColorProvider().getPointColor((double) i / (double) (dataset.getRowCount() - 1));
                renderer.setSeriesPaint(i, color);
                renderer.setSeriesFillPaint(i, color);
            }
        }
        renderer.setBarPainter(new RapidBarPainter());
        renderer.setDrawBarOutline(true);
        plot.setRenderer(renderer);

        // rotate labels
        if (isLabelRotating()) {
            plot.getDomainAxis().setTickLabelsVisible(true);
            plot.getDomainAxis().setCategoryLabelPositions(CategoryLabelPositions.createUpRotationLabelPositions(Math.PI / 2.0d));
        }

        return chart;
    }

    /**
     * Collects the probability of every value of the discrete plot column,
     * for every class.
     *
     * <p>Rows are class names and columns are value names. Value names are
     * sorted alphabetically, and a NaN value is shown as "Unknown".
     *
     * @return the category dataset for the bar chart
     */
    private CategoryDataset createNominalDataSet() {
        DefaultCategoryDataset dataset = new DefaultCategoryDataset();
        for (Integer classIndex : model.getClassIndices()) {
            DiscreteDistribution distribution = (DiscreteDistribution) model.getDistribution(classIndex, translateToModelColumn(plotColumn));
            String labelName = model.getClassName(classIndex);

            // sort values by name
            TreeMap<String, Double> valueMap = new TreeMap<String, Double>();
            for (Double value : distribution.getValues()) {
                String valueName;
                if (Double.isNaN(value)) {
                    valueName = "Unknown";
                } else {
                    valueName = distribution.mapValue(value);
                }
                valueMap.put(valueName, value);
            }

            for (Entry<String, Double> entry : valueMap.entrySet()) {
                dataset.addValue(distribution.getProbability(entry.getValue()), labelName, entry.getKey());
            }
        }
        return dataset;
    }

    /**
     * Selects the column whose distribution is plotted and redraws the chart.
     *
     * <p>The {@code plot} flag is not evaluated by this plotter.
     */
    @Override
    public void setPlotColumn(int column, boolean plot) {
        this.plotColumn = column;
        updatePlotter();
    }

    /** Returns whether the given column is the currently plotted column. */
    @Override
    public boolean getPlotColumn(int column) {
        return (column == this.plotColumn);
    }

    @Override
    public String getPlotName() {
        return "Plot Column";
    }

    /** This plotter has exactly one axis: the class column. */
    @Override
    public int getNumberOfAxes() {
        return 1;
    }

    /** Sets the class (grouping) column and redraws only if it actually changed. */
    @Override
    public void setAxis(int index, int dimension) {
        if (groupColumn != dimension) {
            groupColumn = dimension;
            updatePlotter();
        }
    }

    @Override
    public int getAxis(int index) {
        return groupColumn;
    }

    @Override
    public String getAxisName(int axis) {
        return "Class Column";
    }

    /** The only option component (index 0) is the label-rotation control. */
    @Override
    public JComponent getOptionsComponent(int index) {
        if (index == 0) {
            return getRotateLabelComponent();
        } else {
            return null;
        }
    }

    @Override
    public String getPlotterName() {
        return PlotterConfigurationModel.DISTRIBUTION_PLOT;
    }

    /**
     * Caches the newly set data table.
     *
     * <p>The chart is redrawn unless the plotter shows a fixed model.
     */
    @Override
    public void dataTableSet() {
        this.dataTable = getDataTable();
        if (!createFromModel) {
            updatePlotter();
        }
    }

    /**
     * Returns the name of the plotted column as the x axis.
     *
     * <p>In model mode without a data table, {@link #MODEL_DOMAIN_AXIS_NAME} is
     * returned. If nothing is selected, the result is an empty list.
     */
    @Override
    public Collection<String> resolveXAxis(int axisIndex) {
        if (dataTable != null && plotColumn != -1) {
            return Collections.singletonList(dataTable.getColumnName(plotColumn));
        } else if (createFromModel) {
            return Collections.singletonList(MODEL_DOMAIN_AXIS_NAME);
        } else {
            return Collections.emptyList();
        }
    }

    /** The y axis always shows the density. */
    @Override
    public Collection<String> resolveYAxis(int axisIndex) {
        return Collections.singletonList(RANGE_AXIS_NAME);
    }
}