/* * ARX: Powerful Data Anonymization * Copyright 2012 - 2017 Fabian Prasser, Florian Kohlmayer and contributors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.deidentifier.arx.gui.view.impl.utility; import java.util.ArrayList; import java.util.List; import org.deidentifier.arx.ARXLogisticRegressionConfiguration; import org.deidentifier.arx.aggregates.StatisticsBuilderInterruptible; import org.deidentifier.arx.aggregates.StatisticsClassification; import org.deidentifier.arx.aggregates.StatisticsClassification.PrecisionRecallMatrix; import org.deidentifier.arx.gui.Controller; import org.deidentifier.arx.gui.model.ModelEvent; import org.deidentifier.arx.gui.model.ModelEvent.ModelPart; import org.deidentifier.arx.gui.resources.Resources; import org.deidentifier.arx.gui.view.SWTUtil; import org.deidentifier.arx.gui.view.impl.common.ClipboardHandlerTable; import org.deidentifier.arx.gui.view.impl.common.ComponentStatusLabelProgressProvider; import org.deidentifier.arx.gui.view.impl.common.async.Analysis; import org.deidentifier.arx.gui.view.impl.common.async.AnalysisContext; import org.deidentifier.arx.gui.view.impl.common.async.AnalysisManager; import org.eclipse.swt.SWT; import org.eclipse.swt.custom.SashForm; import org.eclipse.swt.events.ControlAdapter; import org.eclipse.swt.events.ControlEvent; import org.eclipse.swt.events.DisposeEvent; import org.eclipse.swt.events.DisposeListener; import org.eclipse.swt.graphics.Color; import org.eclipse.swt.graphics.Font; import org.eclipse.swt.graphics.FontData; import org.eclipse.swt.graphics.Point; import org.eclipse.swt.graphics.Rectangle; import org.eclipse.swt.layout.FillLayout; import org.eclipse.swt.widgets.Composite; import org.eclipse.swt.widgets.Control; import org.eclipse.swt.widgets.Display; import org.eclipse.swt.widgets.Event; import org.eclipse.swt.widgets.Listener; import org.eclipse.swt.widgets.TableColumn; import org.eclipse.swt.widgets.TableItem; import org.swtchart.Chart; import org.swtchart.IAxis; import org.swtchart.IAxisSet; import org.swtchart.ILineSeries; import org.swtchart.ILineSeries.PlotSymbolType; import org.swtchart.ISeries; import org.swtchart.ISeries.SeriesType; import org.swtchart.ISeriesSet; import org.swtchart.ITitle; import org.swtchart.Range; import de.linearbits.swt.table.DynamicTable; import de.linearbits.swt.table.DynamicTableColumn; /** * This view displays a statistics about the performance of logistic regression classifiers * * @author Fabian Prasser */ public abstract class ViewStatisticsLogisticRegression extends ViewStatistics<AnalysisContextClassification> { /** Minimal width of a category label. */ private static final int MIN_CATEGORY_WIDTH = 10; /** Internal stuff. */ private AnalysisManager manager; /** View */ private DynamicTable table; /** View */ private Composite root; /** View */ private SashForm sash; /** View */ private Chart chart; /** * Creates a new instance. * * @param parent * @param controller * @param part */ public ViewStatisticsLogisticRegression(final Composite parent, final Controller controller, final ModelPart part) { super(parent, controller, part, null, false); this.manager = new AnalysisManager(parent.getDisplay()); controller.addListener(ModelPart.SELECTED_FEATURES_OR_CLASSES, this); controller.addListener(ModelPart.DATA_TYPE, this); controller.addListener(ModelPart.SELECTED_ATTRIBUTE, this); } @Override public LayoutUtility.ViewUtilityType getType() { return LayoutUtility.ViewUtilityType.LOGISTIC_REGRESSION; } @Override public void update(ModelEvent event) { super.update(event); if (event.part == ModelPart.SELECTED_FEATURES_OR_CLASSES || event.part == ModelPart.DATA_TYPE) { if (getModel() != null && (getModel().getSelectedFeatures().isEmpty() || getModel().getSelectedClasses().isEmpty())) { doReset(); return; } else { triggerUpdate(); } } if (event.part == ModelPart.SELECTED_ATTRIBUTE) { int index = 0; for (TableItem item : table.getItems()) { if (item.getText(0).equals(super.getModel().getSelectedAttribute())) { table.select(index); if (item.getData() != null && item.getData() instanceof PrecisionRecallMatrix) { setChartSeries((PrecisionRecallMatrix) item.getData()); } return; } index++; } } } /** * Resets the chart */ private void resetChart() { if (chart != null) { chart.dispose(); } chart = new Chart(this.sash, SWT.NONE); chart.setOrientation(SWT.HORIZONTAL); // Show/Hide axis chart.addControlListener(new ControlAdapter(){ @Override public void controlResized(ControlEvent arg0) { updateCategories(); } }); // Update font FontData[] fd = chart.getFont().getFontData(); fd[0].setHeight(8); final Font font = new Font(chart.getDisplay(), fd[0]); chart.setFont(font); chart.addDisposeListener(new DisposeListener(){ public void widgetDisposed(DisposeEvent arg0) { if (font != null && !font.isDisposed()) { font.dispose(); } } }); // Update title ITitle graphTitle = chart.getTitle(); graphTitle.setText(""); //$NON-NLS-1$ graphTitle.setFont(chart.getFont()); // Set colors chart.setBackground(root.getBackground()); chart.setForeground(root.getForeground()); // OSX workaround if (System.getProperty("os.name").toLowerCase().contains("mac")){ //$NON-NLS-1$ //$NON-NLS-2$ int r = chart.getBackground().getRed()-13; int g = chart.getBackground().getGreen()-13; int b = chart.getBackground().getBlue()-13; r = r>0 ? r : 0; r = g>0 ? g : 0; r = b>0 ? b : 0; final Color background = new Color(chart.getDisplay(), r, g, b); chart.setBackground(background); chart.addDisposeListener(new DisposeListener(){ public void widgetDisposed(DisposeEvent arg0) { if (background != null && !background.isDisposed()) { background.dispose(); } } }); } // Initialize axes IAxisSet axisSet = chart.getAxisSet(); IAxis yAxis = axisSet.getYAxis(0); IAxis xAxis = axisSet.getXAxis(0); ITitle xAxisTitle = xAxis.getTitle(); xAxisTitle.setText(""); //$NON-NLS-1$ xAxis.getTitle().setFont(chart.getFont()); yAxis.getTitle().setFont(chart.getFont()); xAxis.getTick().setFont(chart.getFont()); yAxis.getTick().setFont(chart.getFont()); xAxis.getTick().setForeground(chart.getForeground()); yAxis.getTick().setForeground(chart.getForeground()); xAxis.getTitle().setForeground(chart.getForeground()); yAxis.getTitle().setForeground(chart.getForeground()); // Initialize axes ITitle yAxisTitle = yAxis.getTitle(); yAxisTitle.setText(Resources.getMessage("ViewStatisticsClassificationInput.17")); //$NON-NLS-1$ xAxisTitle.setText(Resources.getMessage("ViewStatisticsClassificationInput.14")); //$NON-NLS-1$ chart.setEnabled(false); updateCategories(); } /** * Updates the chart with a new matrix * @param matrix */ private void setChartSeries(PrecisionRecallMatrix matrix) { // Init data String[] xAxisLabels = new String[matrix.getConfidenceThresholds().length]; double[] ySeriesPrecision = new double[matrix.getConfidenceThresholds().length]; double[] ySeriesRecall = new double[matrix.getConfidenceThresholds().length]; for (int i = 0; i < xAxisLabels.length; i++) { xAxisLabels[i] = SWTUtil.getPrettyString(matrix.getConfidenceThresholds()[i] * 100d); ySeriesPrecision[i] = matrix.getPrecision()[i] * 100d; ySeriesRecall[i] = matrix.getRecall()[i] * 100d; } chart.setRedraw(false); ISeriesSet seriesSet = chart.getSeriesSet(); ILineSeries series1 = (ILineSeries) seriesSet.createSeries(SeriesType.LINE, Resources.getMessage("ViewStatisticsClassificationInput.15")); //$NON-NLS-1$ series1.getLabel().setVisible(false); series1.getLabel().setFont(chart.getFont()); series1.setLineColor(Display.getDefault().getSystemColor(SWT.COLOR_RED)); series1.setYSeries(ySeriesPrecision); series1.setAntialias(SWT.ON); series1.setSymbolType(PlotSymbolType.NONE); series1.enableArea(true); ILineSeries series2 = (ILineSeries) seriesSet.createSeries(SeriesType.LINE, Resources.getMessage("ViewStatisticsClassificationInput.16")); //$NON-NLS-1$ series2.getLabel().setVisible(false); series2.getLabel().setFont(chart.getFont()); series2.setLineColor(Display.getDefault().getSystemColor(SWT.COLOR_BLUE)); series2.setYSeries(ySeriesRecall); series2.setSymbolType(PlotSymbolType.NONE); series2.enableArea(true); seriesSet.bringToFront(Resources.getMessage("ViewStatisticsClassificationInput.16")); //$NON-NLS-1$ chart.getLegend().setVisible(true); chart.getLegend().setPosition(SWT.TOP); IAxisSet axisSet = chart.getAxisSet(); IAxis yAxis = axisSet.getYAxis(0); yAxis.setRange(new Range(0d, 100d)); IAxis xAxis = axisSet.getXAxis(0); xAxis.setCategorySeries(xAxisLabels); xAxis.adjustRange(); updateCategories(); chart.setRedraw(true); chart.updateLayout(); chart.update(); chart.redraw(); } /** * Makes the chart show category labels or not. */ private void updateCategories(){ if (chart != null){ IAxisSet axisSet = chart.getAxisSet(); if (axisSet != null) { IAxis xAxis = axisSet.getXAxis(0); if (xAxis != null) { String[] series = xAxis.getCategorySeries(); if (series != null) { boolean enoughSpace = chart.getPlotArea().getSize().x / series.length >= MIN_CATEGORY_WIDTH; xAxis.enableCategory(enoughSpace); xAxis.getTick().setVisible(enoughSpace); } } } } } @Override protected Control createControl(Composite parent) { // Root this.root = new Composite(parent, SWT.NONE); this.root.setLayout(new FillLayout()); // Shash this.sash = new SashForm(this.root, SWT.VERTICAL); // Table this.table = SWTUtil.createTableDynamic(this.sash, SWT.BORDER | SWT.V_SCROLL | SWT.H_SCROLL | SWT.FULL_SELECTION); this.table.setHeaderVisible(true); this.table.setLinesVisible(true); this.table.setMenu(new ClipboardHandlerTable(table).getMenu()); // Columns String[] columns = getColumnHeaders(); String width = String.valueOf(Math.round(100d / ((double) columns.length + 2) * 100d) / 100d) + "%"; //$NON-NLS-1$ DynamicTableColumn c = new DynamicTableColumn(table, SWT.LEFT); c.setWidth(width, "100px"); //$NON-NLS-1$ c.setText(Resources.getMessage("ViewStatisticsClassificationInput.0")); //$NON-NLS-1$ c = new DynamicTableColumn(table, SWT.LEFT); c.setWidth(width, "100px"); //$NON-NLS-1$ c.setText(Resources.getMessage("ViewStatisticsClassificationInput.2")); //$NON-NLS-1$ for (String column : columns) { c = new DynamicTableColumn(table, SWT.LEFT); SWTUtil.createColumnWithBarCharts(table, c); c.setWidth(width, "100px"); //$NON-NLS-1$ c.setText(column); } for (final TableColumn col : table.getColumns()) { col.pack(); } SWTUtil.createGenericTooltip(table); // Chart and sash resetChart(); this.sash.setWeights(new int[] {2, 2}); // Tool tip final StringBuilder builder = new StringBuilder(); this.sash.addListener(SWT.MouseMove, new Listener() { @Override public void handleEvent(Event event) { if (chart != null) { IAxisSet axisSet = chart.getAxisSet(); if (axisSet != null) { IAxis xAxis = axisSet.getXAxis(0); if (xAxis != null) { Point cursor = chart.getPlotArea().toControl(Display.getCurrent().getCursorLocation()); if (cursor.x >= 0 && cursor.x < chart.getPlotArea().getSize().x && cursor.y >= 0 && cursor.y < chart.getPlotArea().getSize().y) { String[] series = xAxis.getCategorySeries(); ISeries[] data = chart.getSeriesSet().getSeries(); if (data != null && data.length>0 && series != null) { int x = (int) Math.round(xAxis.getDataCoordinate(cursor.x)); if (x >= 0 && x < series.length && !series[x].equals("")) { builder.setLength(0); builder.append("("); //$NON-NLS-1$ builder.append(Resources.getMessage("ViewStatisticsClassificationInput.14")).append(": "); //$NON-NLS-1$ //$NON-NLS-2$ builder.append(series[x]); builder.append("%, ").append(Resources.getMessage("ViewStatisticsClassificationInput.15")).append(": "); //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$ builder.append(SWTUtil.getPrettyString(data[0].getYSeries()[x])); builder.append("%, ").append(Resources.getMessage("ViewStatisticsClassificationInput.16")).append(": "); //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$ builder.append(SWTUtil.getPrettyString(data[1].getYSeries()[x])); builder.append("%)"); //$NON-NLS-1$ sash.setToolTipText(builder.toString()); return; } } } } } sash.setToolTipText(null); } } }); // Update matrix table.addListener(SWT.MouseDown, new Listener() { public void handleEvent(Event event) { Rectangle clientArea = table.getClientArea(); Point pt = new Point(event.x, event.y); int index = table.getTopIndex(); while (index < table.getItemCount()) { boolean visible = false; TableItem item = table.getItem(index); for (int i = 0; i < table.getColumnCount(); i++) { Rectangle rect = item.getBounds(i); if (rect.contains(pt)) { if (item.getData() != null && item.getData() instanceof PrecisionRecallMatrix) { setChartSeries((PrecisionRecallMatrix) item.getData()); } getModel().setSelectedAttribute(item.getText(0)); getController().update(new ModelEvent(ViewStatisticsLogisticRegression.this, ModelPart.SELECTED_ATTRIBUTE, item.getText(0))); return; } if (!visible && rect.intersects(clientArea)) { visible = true; } } if (!visible) return; index++; } } }); return this.root; } @Override protected AnalysisContextClassification createViewConfig(AnalysisContext context) { return new AnalysisContextClassification(context); } @Override protected void doReset() { if (this.manager != null) { this.manager.stop(); } table.setRedraw(false); for (final TableItem i : table.getItems()) { i.dispose(); } table.setRedraw(true); resetChart(); setStatusEmpty(); } @Override protected void doUpdate(final AnalysisContextClassification context) { // The statistics builder final StatisticsBuilderInterruptible builder = context.handle.getStatistics().getInterruptibleInstance(); final String[] features = context.model.getSelectedFeatures().toArray(new String[0]); final String[] classes = context.model.getSelectedClasses().toArray(new String[0]); final ARXLogisticRegressionConfiguration config = context.model.getClassificationModel().getARXLogisticRegressionConfiguration(); // Break, if nothing do if (context.model.getSelectedFeatures().isEmpty() || context.model.getSelectedClasses().isEmpty()) { doReset(); return; } // Create an analysis Analysis analysis = new Analysis(){ private boolean stopped = false; private List<List<Double>> values = new ArrayList<>(); private List<Integer> numClasses = new ArrayList<>(); private List<PrecisionRecallMatrix> matrixes = new ArrayList<>(); private int progress = 0; @Override public int getProgress() { double result = 0d; double perBatch = 100d / (double)classes.length; result += (double)progress * perBatch; result += (double)builder.getProgress() / 100d * perBatch; result = result <= 100d ? result : 100d; return (int)result; } @Override public void onError() { setStatusEmpty(); } @Override public void onFinish() { // Check if (stopped || !isEnabled() || getModel().getSelectedFeatures().isEmpty() || getModel().getSelectedClasses().isEmpty()) { setStatusEmpty(); return; } // Update chart for (final TableItem i : table.getItems()) { i.dispose(); } // Create entries for (int i = 0; i < classes.length; i++) { TableItem item = new TableItem(table, SWT.NONE); item.setText(0, classes[i]); item.setText(1, String.valueOf(numClasses.get(i))); for (int j = 0; j<values.get(i).size(); j++) { item.setData(String.valueOf(2+j), values.get(i).get(j)); } item.setData(matrixes.get(i)); } table.setFocus(); table.select(0); setChartSeries(matrixes.get(0)); // Status root.layout(); sash.setWeights(new int[] {2, 2}); setStatusDone(); } @Override public void onInterrupt() { if (!isEnabled() || getModel().getSelectedFeatures().isEmpty() || getModel().getSelectedClasses().isEmpty()) { setStatusEmpty(); } else { setStatusWorking(); } } @Override public void run() throws InterruptedException { // Timestamp long time = System.currentTimeMillis(); // Do work for (String clazz : classes) { // Compute StatisticsClassification result = builder.getClassificationPerformance(features, clazz, config); progress++; if (stopped) { break; } numClasses.add(result.getNumClasses()); values.add(getColumnValues(result)); matrixes.add(result.getPrecisionRecall()); } // Our users are patient while (System.currentTimeMillis() - time < MINIMAL_WORKING_TIME && !stopped){ Thread.sleep(10); } } @Override public void stop() { builder.interrupt(); this.stopped = true; } }; this.manager.start(analysis); } /** * Returns all column headers * @return */ protected abstract String[] getColumnHeaders(); /** * Returns all values for one row * @param result * @return */ protected abstract List<Double> getColumnValues(StatisticsClassification result); @Override protected ComponentStatusLabelProgressProvider getProgressProvider() { return new ComponentStatusLabelProgressProvider(){ public int getProgress() { if (manager == null) { return 0; } else { return manager.getProgress(); } } }; } /** * Is an analysis running */ protected boolean isRunning() { return manager != null && manager.isRunning(); } }