/* * CorrelationPanel.java * * Copyright (C) 2002-2009 Alexei Drummond and Andrew Rambaut * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.app.tracer.traces; import dr.app.gui.chart.*; import dr.inference.trace.TraceDistribution; import dr.inference.trace.TraceFactory; import dr.inference.trace.TraceList; import dr.stats.Variate; import jam.framework.Exportable; import javax.swing.*; import java.awt.*; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; /** * A panel that displays correlation plots of 2 traces * * @author Andrew Rambaut * @author Alexei Drummond * @version $Id: CorrelationPanel.java,v 1.1.1.2 2006/04/25 23:00:09 rambaut Exp $ */ public class JointDensityPanel extends JPanel implements Exportable { private ChartSetupDialog chartSetupDialog = null; private JIntervalsChart correlationChart = new JIntervalsChart(new LinearAxis(), new LinearAxis()); private JChartPanel chartPanel = new JChartPanel(correlationChart, null, "", ""); private TableScrollPane tableScrollPane = new TableScrollPane(); private JComboBox cateTableProbTypeCombo = new JComboBox(CateTableProbType.values()); private JCheckBox defaultNumberFormatCheckBox = new JCheckBox("Use default number format"); private JLabel messageLabel = new JLabel("No data loaded"); private JCheckBox sampleCheckBox = new JCheckBox("Sample only"); private JCheckBox pointsCheckBox = new JCheckBox("Draw as points"); private JCheckBox translucencyCheckBox = new JCheckBox("Use translucency"); private TraceList tl1 = null; private TraceList tl2 = null; private int traceIndex1 = -1; private int traceIndex2 = -1; private String name1; private String name2; public enum CateTableProbType { JOINT_PRO("Joint Probability"), COND_PRO_X("Conditional Prob (?|row)"), COND_PRO_Y("Conditional Prob (?|column)"), COUNT("Count"); CateTableProbType(String name) { this.name = name; } public String toString() { return name; } private final String name; } /** * Creates new CorrelationPanel */ public JointDensityPanel(final JFrame frame) { setOpaque(false); setMinimumSize(new Dimension(300, 150)); setLayout(new BorderLayout()); // add(messageLabel, BorderLayout.NORTH); // add(chartPanel, BorderLayout.CENTER); JToolBar toolBar = new JToolBar(); toolBar.setOpaque(false); toolBar.setLayout(new FlowLayout(FlowLayout.LEFT)); toolBar.setFloatable(false); JButton chartSetupButton = new JButton("Axes..."); chartSetupButton.putClientProperty( "Quaqua.Button.style", "placard" ); chartSetupButton.setFont(UIManager.getFont("SmallSystemFont")); toolBar.add(chartSetupButton); sampleCheckBox.setOpaque(false); sampleCheckBox.setFont(UIManager.getFont("SmallSystemFont")); sampleCheckBox.setSelected(true); toolBar.add(sampleCheckBox); pointsCheckBox.setOpaque(false); pointsCheckBox.setFont(UIManager.getFont("SmallSystemFont")); toolBar.add(pointsCheckBox); translucencyCheckBox.setOpaque(false); translucencyCheckBox.setFont(UIManager.getFont("SmallSystemFont")); toolBar.add(translucencyCheckBox); cateTableProbTypeCombo.setOpaque(false); cateTableProbTypeCombo.setFont(UIManager.getFont("SmallSystemFont")); toolBar.add(cateTableProbTypeCombo); defaultNumberFormatCheckBox.setOpaque(false); defaultNumberFormatCheckBox.setFont(UIManager.getFont("SmallSystemFont")); defaultNumberFormatCheckBox.setSelected(true); toolBar.add(defaultNumberFormatCheckBox); toolBar.add(new JToolBar.Separator(new Dimension(8, 8))); add(messageLabel, BorderLayout.NORTH); add(toolBar, BorderLayout.SOUTH); add(chartPanel, BorderLayout.CENTER); chartSetupButton.addActionListener( new java.awt.event.ActionListener() { public void actionPerformed(ActionEvent actionEvent) { if (chartSetupDialog == null) { chartSetupDialog = new ChartSetupDialog(frame, true, true, Axis.AT_MAJOR_TICK, Axis.AT_MAJOR_TICK, Axis.AT_MAJOR_TICK, Axis.AT_MAJOR_TICK); } chartSetupDialog.showDialog(correlationChart); validate(); repaint(); } } ); ActionListener listener = new java.awt.event.ActionListener() { public void actionPerformed(java.awt.event.ActionEvent ev) { setupChartOrTable(); } }; sampleCheckBox.addActionListener(listener); pointsCheckBox.addActionListener(listener); translucencyCheckBox.addActionListener(listener); cateTableProbTypeCombo.addActionListener(listener); defaultNumberFormatCheckBox.addActionListener(listener); } public void setCombinedTraces() { chartPanel.setXAxisTitle(""); chartPanel.setYAxisTitle(""); messageLabel.setText("Can't show correlation of combined traces"); } public void setTraces(TraceList[] traceLists, java.util.List<String> traceNames) { // correlationChart.removeAllPlots(); if (traceLists != null && traceNames != null && traceLists.length == 2 && traceNames.size() == 1) { tl1 = traceLists[0]; name1 = tl1.getName(); tl2 = traceLists[1]; name2 = tl2.getName(); traceIndex1 = tl1.getTraceIndex(traceNames.get(0)); traceIndex2 = tl2.getTraceIndex(traceNames.get(0)); name1 = name1 + " - " + tl1.getTraceName(traceIndex1); name2 = name2 + " - " + tl2.getTraceName(traceIndex2); } else if (traceLists != null && traceNames != null && traceLists.length == 1 && traceNames.size() == 2) { tl1 = traceLists[0]; tl2 = traceLists[0]; traceIndex1 = tl1.getTraceIndex(traceNames.get(0)); traceIndex2 = tl2.getTraceIndex(traceNames.get(1)); name1 = tl1.getTraceName(traceIndex1); name2 = tl2.getTraceName(traceIndex2); } else { tl1 = null; tl2 = null; } setupChartOrTable(); } private void setupChartOrTable() { correlationChart.removeAllIntervals(); if (tl1 == null || tl2 == null) { // correlationChart.removeAllPlots(); chartPanel.remove(tableScrollPane); chartPanel.setXAxisTitle(""); chartPanel.setYAxisTitle(""); messageLabel.setText("Select two statistics or traces from the table to view their correlation"); return; } TraceDistribution td1 = tl1.getDistributionStatistics(traceIndex1); TraceDistribution td2 = tl2.getDistributionStatistics(traceIndex2); if (td1 == null || td2 == null) { // correlationChart.removeAllPlots(); chartPanel.remove(tableScrollPane); chartPanel.setXAxisTitle(""); chartPanel.setYAxisTitle(""); messageLabel.setText("Waiting for analysis to complete"); return; } messageLabel.setText(""); if (td1.getTraceType() != TraceFactory.TraceType.DOUBLE && td2.getTraceType() != TraceFactory.TraceType.DOUBLE) { chartPanel.remove(correlationChart); chartPanel.add(tableScrollPane, "Table"); sampleCheckBox.setVisible(false); pointsCheckBox.setVisible(false); translucencyCheckBox.setVisible(false); cateTableProbTypeCombo.setVisible(true); defaultNumberFormatCheckBox.setVisible(true); Object[] rowNames = td1.getRange().toArray(); Object[] colNames = td2.getRange().toArray(); double[][] data = categoricalPlot(td1, td2); tableScrollPane.setTable(rowNames, colNames, data, defaultNumberFormatCheckBox.isSelected()); } else { chartPanel.remove(tableScrollPane); chartPanel.add(correlationChart, "Chart"); // correlationChart.removeAllPlots(); cateTableProbTypeCombo.setVisible(false); defaultNumberFormatCheckBox.setVisible(false); if (td1.getTraceType() == TraceFactory.TraceType.STRING) { mixedCategoricalPlot(td1, false); // isFirstTraceListNumerical sampleCheckBox.setVisible(false); pointsCheckBox.setVisible(false); translucencyCheckBox.setVisible(false); } else if (td2.getTraceType() == TraceFactory.TraceType.STRING) { mixedCategoricalPlot(td2, true); // isFirstTraceListNumerical sampleCheckBox.setVisible(false); pointsCheckBox.setVisible(false); translucencyCheckBox.setVisible(false); String swapName = name1; name1 = name2; name2 = swapName; } else { numericalPlot(td1, td2); sampleCheckBox.setVisible(true); pointsCheckBox.setVisible(true); translucencyCheckBox.setVisible(true); } } chartPanel.setXAxisTitle(name1); chartPanel.setYAxisTitle(name2); validate(); repaint(); } private void mixedCategoricalPlot(TraceDistribution td, boolean isFirstTraceListNumerical) { correlationChart.setXAxis(new DiscreteAxis(true, true)); List<String> categoryValues = td.getRange(); Map<String, TraceDistribution> categoryTdMap = new HashMap<String, TraceDistribution>(); if (categoryValues == null || categoryValues.size() < 1) return; int maxCount = Math.max(tl1.getStateCount(), tl2.getStateCount()); int minCount = Math.min(tl1.getStateCount(), tl2.getStateCount()); int sampleSize = minCount; double samples1[] = new double[sampleSize]; int k = 0; List values; if (isFirstTraceListNumerical) { values = tl1.getValues(traceIndex1); } else { values = tl2.getValues(traceIndex2); } for (int i = 0; i < sampleSize; i++) { samples1[i] = ((Number) values.get(k)).doubleValue(); k += minCount / sampleSize; } String samples2[] = new String[sampleSize]; k = 0; List values2; if (isFirstTraceListNumerical) { values2 = tl2.getValues(traceIndex2); } else { values2 = tl1.getValues(traceIndex1); } for (int i = 0; i < sampleSize; i++) { samples2[i] = values2.get(k).toString(); k += minCount / sampleSize; } // separate samples into categoryTdMap ArrayList[] sepValues = new ArrayList[categoryValues.size()]; for (int i = 0; i < categoryValues.size(); i++) { sepValues[i] = new ArrayList<Double>(); for (int j = 0; j < samples2.length; j++) { if (categoryValues.get(i).equals(samples2[j])) { sepValues[i].add(samples1[j]); } } TraceDistribution categoryTd = new TraceDistribution(sepValues[i], TraceFactory.TraceType.DOUBLE); // todo ? categoryTdMap.put(categoryValues.get(i), categoryTd); } for (String categoryValue : categoryValues) { TraceDistribution categoryTd = categoryTdMap.get(categoryValue); correlationChart.addIntervals(categoryValue, categoryTd.getMean(), categoryTd.getUpperHPD(), categoryTd.getLowerHPD(), false); } } private double[][] categoricalPlot(TraceDistribution td1, TraceDistribution td2) { List<String> rowNames = td1.getRange(); List<String> colNames = td2.getRange(); double[][] data = new double[rowNames.size()][colNames.size()]; int maxCount = Math.max(tl1.getStateCount(), tl2.getStateCount()); int minCount = Math.min(tl1.getStateCount(), tl2.getStateCount()); int sampleSize = minCount; if (sampleSize <= 0) System.err.println("sampleSize cannot be 0. sampleSize = " + sampleSize); String samples1[] = new String[sampleSize]; int k = 0; List values = tl1.getValues(traceIndex1); TraceFactory.TraceType type = tl1.getTrace(traceIndex1).getTraceType(); for (int i = 0; i < sampleSize; i++) { if (type == TraceFactory.TraceType.INTEGER) { // as Integer is stored as Double in Trace samples1[i] = Integer.toString( ((Number) values.get(k)).intValue() ); } else { samples1[i] = values.get(k).toString(); } k += minCount / sampleSize; // = 1 for non-continous vs non-continous } String samples2[] = new String[sampleSize]; k = 0; values = tl2.getValues(traceIndex2); type = tl2.getTrace(traceIndex2).getTraceType(); for (int i = 0; i < sampleSize; i++) { if (type == TraceFactory.TraceType.INTEGER) { // as Integer is stored as Double in Trace samples2[i] = Integer.toString( ((Number) values.get(k)).intValue() ); } else { samples2[i] = values.get(k).toString(); } k += minCount / sampleSize; } // calculate count for (int i = 0; i < sampleSize; i++) { if (rowNames.contains(samples1[i]) && colNames.contains(samples2[i])) { data[rowNames.indexOf(samples1[i])][colNames.indexOf(samples2[i])] += 1; } else { // System.err.println("Not find row or column name. i = " + i); } } if (cateTableProbTypeCombo.getSelectedItem() == CateTableProbType.JOINT_PRO) { for (int r = 0; r < data.length; r++) { for (int c = 0; c < data[0].length; c++) { data[r][c] = data[r][c] / sampleSize; } } } else if (cateTableProbTypeCombo.getSelectedItem() == CateTableProbType.COND_PRO_X) { for (int r = 0; r < data.length; r++) { double count = 0; for (int c = 0; c < data[0].length; c++) { count = count + data[r][c]; } for (int c = 0; c < data[0].length; c++) { if (count != 0) data[r][c] = data[r][c] / count; } } } else if (cateTableProbTypeCombo.getSelectedItem() == CateTableProbType.COND_PRO_Y) { for (int c = 0; c < data[0].length; c++) { double count = 0; for (int r = 0; r < data.length; r++) { count = count + data[r][c]; } for (int r = 0; r < data.length; r++) { if (count != 0) data[r][c] = data[r][c] / count; } } } // else COUNT return data; } private void numericalPlot(TraceDistribution td1, TraceDistribution td2) { int maxCount = Math.max(tl1.getStateCount(), tl2.getStateCount()); int minCount = Math.min(tl1.getStateCount(), tl2.getStateCount()); int sampleSize = minCount; if (sampleCheckBox.isSelected()) { if (td1.getESS() < td2.getESS()) { sampleSize = (int) td1.getESS(); } else { sampleSize = (int) td2.getESS(); } if (sampleSize < 20) { sampleSize = 20; messageLabel.setText("One of the traces has an ESS < 20 so a sample size of 20 will be used"); } if (sampleSize > 500) { messageLabel.setText("This plot has been sampled down to 500 points"); sampleSize = 500; } } int k = 0; if (td1.getTraceType() == TraceFactory.TraceType.INTEGER) { correlationChart.setXAxis(new DiscreteAxis(true, true)); } else { correlationChart.setXAxis(new LinearAxis()); } List values = tl1.getValues(traceIndex1); List<Double> samples1 = new ArrayList<Double>(); for (int i = 0; i < sampleSize; i++) { samples1.add(i, ((Number) values.get(k)).doubleValue()); k += minCount / sampleSize; } k = 0; if (td2.getTraceType() == TraceFactory.TraceType.INTEGER) { correlationChart.setYAxis(new DiscreteAxis(true, true)); } else { correlationChart.setYAxis(new LinearAxis()); } values = tl2.getValues(traceIndex2); List<Double> samples2 = new ArrayList<Double>(); for (int i = 0; i < sampleSize; i++) { samples2.add(i, ((Number) values.get(k)).doubleValue()); k += minCount / sampleSize; } ScatterPlot plot = new ScatterPlot(samples1, samples2); plot.setMarkStyle(pointsCheckBox.isSelected() ? Plot.POINT_MARK : Plot.CIRCLE_MARK, pointsCheckBox.isSelected() ? 1.0 : 3.0, new BasicStroke(2.0f, BasicStroke.CAP_ROUND, BasicStroke.JOIN_MITER), new Color(16, 16, 64, translucencyCheckBox.isSelected() ? 32 : 255), new Color(16, 16, 64, translucencyCheckBox.isSelected() ? 32 : 255)); correlationChart.addPlot(plot); } // private double[] removeNaN(double[] sample) { // List<Double> selectedValuesList = new ArrayList<Double>(); // // for (int i = 0; i < sample.length; i++) { // if (sample[i] != Double.NaN) { // selectedValuesList.add(sample[i]); // } // } // // double[] dest = new double[selectedValuesList.size()]; // for (int i = 0; i < dest.length; i++) { // dest[i] = selectedValuesList.get(i).doubleValue(); // } // // return dest; // } public JComponent getExportableComponent() { return chartPanel; } public String toString() { if (correlationChart.getPlotCount() == 0) { return "no plot available"; } StringBuffer buffer = new StringBuffer(); Plot plot = correlationChart.getPlot(0); Variate xData = plot.getXData(); Variate yData = plot.getYData(); buffer.append(chartPanel.getXAxisTitle()); buffer.append("\t"); buffer.append(chartPanel.getYAxisTitle()); buffer.append("\n"); for (int i = 0; i < xData.getCount(); i++) { buffer.append(String.valueOf(xData.get(i))); buffer.append("\t"); buffer.append(String.valueOf(yData.get(i))); buffer.append("\n"); } return buffer.toString(); } }