/* * Joinery -- Data frames for Java * Copyright (c) 2014, 2015 IBM Corp. * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ package joinery.impl; import java.awt.Color; import java.awt.Container; import java.awt.GridLayout; import java.util.ArrayList; import java.util.Arrays; import java.util.Calendar; import java.util.Date; import java.util.EnumSet; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import javax.swing.JFrame; import javax.swing.JScrollPane; import javax.swing.JTable; import javax.swing.SwingUtilities; import javax.swing.table.AbstractTableModel; import org.apache.commons.math3.stat.regression.SimpleRegression; import com.xeiam.xchart.Chart; import com.xeiam.xchart.ChartBuilder; import com.xeiam.xchart.Series; import com.xeiam.xchart.SeriesLineStyle; import com.xeiam.xchart.SeriesMarker; import com.xeiam.xchart.StyleManager.ChartType; import com.xeiam.xchart.XChartPanel; import joinery.DataFrame; import joinery.DataFrame.PlotType; public class Display { public static <C extends Container, V> C draw(final DataFrame<V> df, final C container, final PlotType type) { final List<XChartPanel> panels = new LinkedList<>(); final DataFrame<Number> numeric = df.numeric().fillna(0); final int rows = (int)Math.ceil(Math.sqrt(numeric.size())); final int cols = numeric.size() / rows + 1; final List<Object> xdata = new ArrayList<>(df.length()); final Iterator<Object> it = df.index().iterator(); for (int i = 0; i < df.length(); i++) { final Object value = it.hasNext() ? it.next(): i; if (value instanceof Number || value instanceof Date) { xdata.add(value); } else if (PlotType.BAR.equals(type)) { xdata.add(String.valueOf(value)); } else { xdata.add(i); } } if (EnumSet.of(PlotType.GRID, PlotType.GRID_WITH_TREND).contains(type)) { for (final Object col : numeric.columns()) { final Chart chart = new ChartBuilder() .chartType(chartType(type)) .width(800 / cols) .height(800 / cols) .title(String.valueOf(col)) .build(); final Series series = chart.addSeries(String.valueOf(col), xdata, numeric.col(col)); if (type == PlotType.GRID_WITH_TREND) { addTrend(chart, series, xdata); series.setLineStyle(SeriesLineStyle.NONE); } chart.getStyleManager().setLegendVisible(false); chart.getStyleManager().setDatePattern(dateFormat(xdata)); panels.add(new XChartPanel(chart)); } } else { final Chart chart = new ChartBuilder() .chartType(chartType(type)) .build(); chart.getStyleManager().setDatePattern(dateFormat(xdata)); switch (type) { case SCATTER: case SCATTER_WITH_TREND: case LINE_AND_POINTS: break; default: chart.getStyleManager().setMarkerSize(0); break; } for (final Object col : numeric.columns()) { final Series series = chart.addSeries(String.valueOf(col), xdata, numeric.col(col)); if (type == PlotType.SCATTER_WITH_TREND) { addTrend(chart, series, xdata); series.setLineStyle(SeriesLineStyle.NONE); } } panels.add(new XChartPanel(chart)); } if (panels.size() > 1) { container.setLayout(new GridLayout(rows, cols)); } for (final XChartPanel p : panels) { container.add(p); } return container; } public static <V> void plot(final DataFrame<V> df, final PlotType type) { SwingUtilities.invokeLater(new Runnable() { @Override public void run() { final JFrame frame = draw(df, new JFrame(title(df)), type); frame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE); frame.pack(); frame.setVisible(true); } }); } public static <V> void show(final DataFrame<V> df) { final List<Object> columns = new ArrayList<>(df.columns()); final List<Class<?>> types = df.types(); SwingUtilities.invokeLater(new Runnable() { @Override public void run() { final JFrame frame = new JFrame(title(df)); final JTable table = new JTable( new AbstractTableModel() { private static final long serialVersionUID = 1L; @Override public int getRowCount() { return df.length(); } @Override public int getColumnCount() { return df.size(); } @Override public Object getValueAt(final int row, final int col) { return df.get(row, col); } @Override public String getColumnName(final int col) { return String.valueOf(columns.get(col)); } @Override public Class<?> getColumnClass(final int col) { return types.get(col); } } ); table.setAutoResizeMode(JTable.AUTO_RESIZE_OFF); frame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE); frame.add(new JScrollPane(table)); frame.pack(); frame.setVisible(true); } }); } private static ChartType chartType(final PlotType type) { switch (type) { case AREA: return ChartType.Area; case BAR: return ChartType.Bar; case GRID: case SCATTER: return ChartType.Scatter; case SCATTER_WITH_TREND: case GRID_WITH_TREND: case LINE: default: return ChartType.Line; } } private static final String title(final DataFrame<?> df) { return String.format( "%s (%d rows x %d columns)", df.getClass().getCanonicalName(), df.length(), df.size() ); } private static final String dateFormat(final List<Object> xdata) { final int[] fields = new int[] { Calendar.YEAR, Calendar.MONTH, Calendar.DAY_OF_MONTH, Calendar.HOUR_OF_DAY, Calendar.MINUTE, Calendar.SECOND }; final String[] formats = new String[] { " yyy", "-MMM", "-d", " H", ":mm", ":ss" }; final Calendar c1 = Calendar.getInstance(), c2 = Calendar.getInstance(); if (!xdata.isEmpty() && xdata.get(0) instanceof Date) { String format = ""; int first = 0, last = 0; c1.setTime(Date.class.cast(xdata.get(0))); // iterate over all x-axis values comparing dates for (int i = 1; i < xdata.size(); i++) { // early exit for non-date elements if (!(xdata.get(i) instanceof Date)) return formats[0].substring(1); c2.setTime(Date.class.cast(xdata.get(i))); // check which components differ, those are the fields to output for (int j = 1; j < fields.length; j++) { if (c1.get(fields[j]) != c2.get(fields[j])) { first = Math.max(j - 1, first); last = Math.max(j, last); } } } // construct a format string for the fields that differ for (int i = first; i <= last && i < formats.length; i++) { format += format.isEmpty() ? formats[i].substring(1) : formats[i]; } return format; } return formats[0].substring(1); } private static void addTrend(final Chart chart, final Series series, final List<Object> xdata) { final SimpleRegression model = new SimpleRegression(); final Iterator<? extends Number> y = series.getYData().iterator(); for (int x = 0; y.hasNext(); x++) { model.addData(x, y.next().doubleValue()); } final Color mc = series.getMarkerColor(); final Color c = new Color(mc.getRed(), mc.getGreen(), mc.getBlue(), 0x60); final Series trend = chart.addSeries(series.getName() + " (trend)", Arrays.asList(xdata.get(0), xdata.get(xdata.size() - 1)), Arrays.asList(model.predict(0), model.predict(xdata.size() - 1)) ); trend.setLineColor(c); trend.setMarker(SeriesMarker.NONE); } }