/*
* Copyright (C) 2016 by Array Systems Computing Inc. http://www.array.ca
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the Free
* Software Foundation; either version 3 of the License, or (at your option)
* any later version.
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program; if not, see http://www.gnu.org/licenses/
*/
package org.esa.snap.classification.gpf.ui;
import org.esa.snap.classification.gpf.BaseClassifier;
import org.esa.snap.engine_utilities.util.VectorUtils;
import org.esa.snap.core.datamodel.Product;
import org.esa.snap.core.util.SystemUtils;
import org.esa.snap.core.util.io.FileUtils;
import org.esa.snap.graphbuilder.gpf.ui.BaseOperatorUI;
import org.esa.snap.graphbuilder.gpf.ui.OperatorUIUtils;
import org.esa.snap.graphbuilder.gpf.ui.UIValidation;
import org.esa.snap.graphbuilder.rcp.utils.DialogUtils;
import org.esa.snap.rcp.util.Dialogs;
import org.esa.snap.ui.AppContext;
import javax.swing.*;
import javax.swing.border.TitledBorder;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.ItemEvent;
import java.awt.event.ItemListener;
import java.io.File;
import java.io.FileFilter;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
/**
* User interface for classifiers
*/
public abstract class BaseClassifierOpUI extends BaseOperatorUI {
private final JRadioButton loadBtn = new JRadioButton("Load and apply classifier", false);
private final JRadioButton trainBtn = new JRadioButton("Train and apply classifier", true);
private final JComboBox<String> classifierNameComboBox = new JComboBox();
private final JButton deleteClassiferBtn = new JButton("X");
private final JTextField newClassifierNameField = new JTextField("newClassifier");
private final JRadioButton trainOnRasterBtn = new JRadioButton("Train on Raster", false);
private final JRadioButton trainOnVectorsBtn = new JRadioButton("Train on Vectors", true);
private final JTextField numTrainSamples = new JTextField("");
private final JCheckBox evaluateClassifier = new JCheckBox("");
private final JCheckBox evaluateFeaturePowerSet = new JCheckBox("");
private final JCheckBox doClassValQuantization = new JCheckBox();
private final JTextField minClassValue = new JTextField("");
private final JTextField classValStepSize = new JTextField("");
private final JTextField classLevels = new JTextField("");
private final JLabel maxClassValue = new JLabel("");
private final JRadioButton labelSourceVectorName = new JRadioButton("Vector node name", true);
private final JRadioButton labelSourceAttribute = new JRadioButton("Attribute value", false);
private final JList<String> trainingBands = new JList();
private final JList<String> trainingVectors = new JList();
private final JList<String> featureBandNames = new JList();
protected JPanel classifierPanel, rasterPanel, vectorPanel, featurePanel;
protected GridBagConstraints classifiergbc;
private final String classifierType;
public BaseClassifierOpUI(final String classifierType) {
this.classifierType = classifierType;
}
@Override
public JComponent CreateOpTab(String operatorName, Map<String, Object> parameterMap, AppContext appContext) {
initializeOperatorUI(operatorName, parameterMap);
final JComponent panel = createPanel();
doClassValQuantization.addActionListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
enableQuantization(doClassValQuantization.isSelected());
}
});
minClassValue.addActionListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
updateMaxClassValue();
}
});
classValStepSize.addActionListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
updateMaxClassValue();
}
});
classLevels.addActionListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
updateMaxClassValue();
}
});
evaluateClassifier.addActionListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
evaluateFeaturePowerSet.setEnabled(evaluateClassifier.isSelected());
}
});
loadBtn.addItemListener(new ItemListener() {
@Override
public void itemStateChanged(ItemEvent e) {
boolean doTraining = e.getStateChange() != ItemEvent.SELECTED;
enableTraining(doTraining);
enableTrainOnRaster(doTraining, trainOnRasterBtn.isSelected());
}
});
labelSourceAttribute.addItemListener(new ItemListener() {
@Override
public void itemStateChanged(ItemEvent e) {
if (e.getStateChange() == ItemEvent.SELECTED) {
final AttributeDialog dlg = new AttributeDialog("Labels from Attribute",
VectorUtils.getAttributesList(sourceProducts), null);
dlg.show();
if (dlg.IsOK()) {
labelSourceAttribute.setText(dlg.getValue());
}
}
}
});
populateClassifierNames();
classifierNameComboBox.setEditable(false);
classifierNameComboBox.setMaximumRowCount(5);
deleteClassiferBtn.addActionListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
requestDeleteClassifier();
}
});
trainingBands.setSelectionMode(ListSelectionModel.SINGLE_SELECTION);
trainingVectors.setSelectionMode(ListSelectionModel.MULTIPLE_INTERVAL_SELECTION);
trainOnRasterBtn.addItemListener(new ItemListener() {
@Override
public void itemStateChanged(ItemEvent e) {
enableTrainOnRaster(trainBtn.isSelected(), e.getStateChange() == ItemEvent.SELECTED);
}
});
initParameters();
return new JScrollPane(panel);
}
private Path getClassifierFolder() {
return SystemUtils.getAuxDataPath().
resolve(BaseClassifier.CLASSIFIER_ROOT_FOLDER).resolve(classifierType);
}
private void populateClassifierNames() {
final Path classifierDir = getClassifierFolder();
final File folder = new File(classifierDir.toString());
final File[] listOfFiles = folder.listFiles(new FileFilter() {
public boolean accept(File pathname) {
return pathname.isFile() && pathname.getName().endsWith(BaseClassifier.CLASSIFIER_FILE_EXTENSION);
}
});
if (listOfFiles != null && listOfFiles.length > 0) {
for (File file : listOfFiles) {
classifierNameComboBox.addItem(FileUtils.getFilenameWithoutExtension(file));
}
classifierNameComboBox.setSelectedIndex(0);
} else {
trainBtn.setSelected(true);
}
}
private void requestDeleteClassifier() {
String name = (String) classifierNameComboBox.getSelectedItem();
if (name != null) {
Dialogs.Answer answer = Dialogs.requestDecision("Delete Classifier",
"Are you sure you want to delete classifier " + name,
true, null);
if (answer.equals(Dialogs.Answer.YES)) {
final Path classifierDir = getClassifierFolder();
final File classiferFile = classifierDir.resolve(name + BaseClassifier.CLASSIFIER_FILE_EXTENSION).toFile();
if (classiferFile.exists()) {
if (deleteClassifier(classiferFile, name)) {
classifierNameComboBox.removeItem(name);
if (classifierNameComboBox.getItemCount() == 0) {
trainBtn.setSelected(true);
}
} else {
Dialogs.showError("Unable to delete classifier " + classiferFile);
}
} else {
Dialogs.showError("Unable to find classifier " + classiferFile);
}
}
}
}
private static boolean deleteClassifier(final File classifierFile, final String classifierName) {
boolean ok = classifierFile.delete();
// find other associated files
final File[] files = classifierFile.getParentFile().listFiles(new FileFilter() {
@Override
public boolean accept(File pathname) {
return FileUtils.getFilenameWithoutExtension(pathname).equals(classifierName);
}
});
if (files != null) {
for (File file : files) {
file.delete();
}
}
return ok;
}
protected abstract void setEnabled(final boolean enabled);
private void enableTrainOnRaster(final boolean doTraining, final boolean trainOnRaster) {
if (doTraining) {
if (trainOnRaster) {
OperatorUIUtils.initParamList(trainingBands, getTrainingBands());
} else {
OperatorUIUtils.initParamList(trainingVectors, getPolygons());
}
}
rasterPanel.setVisible(doTraining && trainOnRaster);
vectorPanel.setVisible(doTraining && !trainOnRaster);
featurePanel.setVisible(doTraining);
}
private void enableQuantization(final boolean enable) {
minClassValue.setEnabled(enable);
classValStepSize.setEnabled(enable);
classLevels.setEnabled(enable);
maxClassValue.setEnabled(enable);
}
private void setEnableDoClassValQuantization(final boolean enable) {
doClassValQuantization.setEnabled(enable);
minClassValue.setEnabled(enable && doClassValQuantization.isSelected());
classValStepSize.setEnabled(enable && doClassValQuantization.isSelected());
classLevels.setEnabled(enable && doClassValQuantization.isSelected());
maxClassValue.setEnabled(enable && doClassValQuantization.isSelected());
}
private void enableTraining(boolean doTraining) {
classifierNameComboBox.setEnabled(!doTraining);
deleteClassiferBtn.setEnabled(!doTraining);
newClassifierNameField.setEnabled(doTraining);
setEnableDoClassValQuantization(doTraining);
trainOnRasterBtn.setEnabled(doTraining);
trainOnVectorsBtn.setEnabled(doTraining);
numTrainSamples.setEnabled(doTraining);
trainingBands.setEnabled(doTraining);
if (!trainingBands.isEnabled()) {
trainingBands.clearSelection();
}
trainingVectors.setEnabled(doTraining);
if (!trainingVectors.isEnabled()) {
trainingVectors.clearSelection();
}
featureBandNames.setEnabled(doTraining);
setEnabled(doTraining);
}
private void updateMaxClassValue() {
final double minVal = Double.parseDouble(minClassValue.getText());
final double stepSize = Double.parseDouble(classValStepSize.getText());
final int levels = Integer.parseInt(classLevels.getText());
final double maxClassVal = BaseClassifier.getMaxValue(minVal, stepSize, levels);
maxClassValue.setText(String.valueOf(maxClassVal));
}
@Override
public void initParameters() {
String newClassifierName = (String) paramMap.get("savedClassifierName");
if (DialogUtils.contains(classifierNameComboBox, newClassifierName)) {
classifierNameComboBox.setSelectedItem(newClassifierName);
}
String numSamples = String.valueOf(paramMap.get("numTrainSamples"));
numTrainSamples.setText(numSamples);
Boolean eval = (Boolean) (paramMap.get("evaluateClassifier"));
if (eval != null) {
evaluateClassifier.setSelected(eval);
}
Boolean evalPS = (Boolean) (paramMap.get("evaluateFeaturePowerSet"));
if (evalPS != null) {
evaluateFeaturePowerSet.setSelected(evalPS);
}
Boolean doQuant = (Boolean) (paramMap.get("doClassValQuantization"));
if (doQuant != null) {
doClassValQuantization.setSelected(doQuant);
}
minClassValue.setText(String.valueOf(paramMap.get("minClassValue")));
classValStepSize.setText(String.valueOf(paramMap.get("classValStepSize")));
classLevels.setText(String.valueOf(paramMap.get("classLevels")));
final Double minVal = (Double) paramMap.get("minClassValue");
final Double stepSize = (Double) paramMap.get("classValStepSize");
final Integer levels = (Integer) paramMap.get("classLevels");
if(minVal != null && stepSize != null && levels != null) {
final double maxClassVal = BaseClassifier.getMaxValue(minVal, stepSize, levels);
maxClassValue.setText(String.valueOf(maxClassVal));
}
Boolean trainOnRastersVal = (Boolean) paramMap.get("trainOnRaster");
boolean trainOnRasters = trainOnRastersVal != null && trainOnRastersVal;
trainOnRasterBtn.setSelected(trainOnRasters);
String labelSource = (String) paramMap.get("labelSource");
if (labelSource == null || labelSource.equals(BaseClassifier.VectorNodeNameLabelSource)) {
labelSourceVectorName.setSelected(true);
}
boolean doTraining = true;
enableTraining(doTraining);
enableTrainOnRaster(doTraining, trainOnRasters);
paramMap.put("bandsOrVectors", null);
OperatorUIUtils.initParamList(featureBandNames, getFeatures());
}
@Override
public UIValidation validateParameters() {
if (!loadBtn.isSelected()) {
if (DialogUtils.contains(classifierNameComboBox, newClassifierNameField.getText())) {
// return new UIValidation(UIValidation.State.ERROR, "Name already in use. Please select a unique classifier name");
}
}
return new UIValidation(UIValidation.State.OK, "");
}
@Override
public void updateParameters() {
paramMap.put("numTrainSamples", Integer.parseInt(numTrainSamples.getText()));
paramMap.put("evaluateClassifier", evaluateClassifier.isSelected());
paramMap.put("evaluateFeaturePowerSet", evaluateFeaturePowerSet.isSelected());
paramMap.put("doClassValQuantization", doClassValQuantization.isSelected());
paramMap.put("minClassValue", Double.parseDouble(minClassValue.getText()));
paramMap.put("classValStepSize", Double.parseDouble(classValStepSize.getText()));
paramMap.put("classLevels", Integer.parseInt(classLevels.getText()));
paramMap.put("trainOnRaster", trainOnRasterBtn.isSelected());
String classifierName = loadBtn.isSelected() ?
(String) classifierNameComboBox.getSelectedItem() :
newClassifierNameField.getText();
paramMap.put("savedClassifierName", classifierName);
if (labelSourceAttribute.isSelected()) {
paramMap.put("labelSource", labelSourceAttribute.getText());
} else {
paramMap.put("labelSource", BaseClassifier.VectorNodeNameLabelSource);
}
OperatorUIUtils.updateParamList(trainingBands, paramMap, "trainingBands");
//dumpSelectedValues("trainingBands", trainingBands);
OperatorUIUtils.updateParamList(trainingVectors, paramMap, "trainingVectors");
//dumpSelectedValues("trainingVectors", trainingVectors);
OperatorUIUtils.updateParamList(featureBandNames, paramMap, "featureBands");
//dumpSelectedValues("features", featureBandNames);
}
private static void dumpSelectedValues(final String name, final JList<String> paramList) {
SystemUtils.LOG.info(name + " selected values:");
final List<String> selectedValues = paramList.getSelectedValuesList();
for (Object selectedValue : selectedValues) {
SystemUtils.LOG.info(' ' + (String) selectedValue);
}
}
protected JPanel createPanel() {
final JPanel contentPane = new JPanel();
contentPane.setLayout(new GridBagLayout());
GridBagConstraints gbc = DialogUtils.createGridBagConstraints();
classifierPanel = createClassifierPanel();
gbc.gridy++;
contentPane.add(classifierPanel, gbc);
rasterPanel = createRasterPanel();
gbc.gridy++;
contentPane.add(rasterPanel, gbc);
vectorPanel = createVectorPanel();
contentPane.add(vectorPanel, gbc);
featurePanel = createFeaturePanel();
gbc.gridy++;
contentPane.add(featurePanel, gbc);
DialogUtils.fillPanel(contentPane, gbc);
return contentPane;
}
private JPanel createClassifierPanel() {
final JPanel classifierPanel = new JPanel();
classifierPanel.setLayout(new GridBagLayout());
classifierPanel.setBorder(new TitledBorder("Classifier"));
classifiergbc = DialogUtils.createGridBagConstraints();
final ButtonGroup group1 = new ButtonGroup();
group1.add(trainBtn);
group1.add(loadBtn);
classifierPanel.add(trainBtn, classifiergbc);
classifiergbc.gridx = 1;
classifierPanel.add(newClassifierNameField, classifiergbc);
classifiergbc.gridx = 0;
classifiergbc.gridx = 0;
classifiergbc.gridy++;
classifierPanel.add(loadBtn, classifiergbc);
classifiergbc.gridx = 1;
classifierPanel.add(classifierNameComboBox, classifiergbc);
classifiergbc.gridx = 2;
classifierPanel.add(deleteClassiferBtn, classifiergbc);
final ButtonGroup group2 = new ButtonGroup();
group2.add(trainOnRasterBtn);
group2.add(trainOnVectorsBtn);
JPanel radioPanel = new JPanel(new FlowLayout());
radioPanel.add(trainOnRasterBtn);
radioPanel.add(trainOnVectorsBtn);
classifiergbc.gridy++;
classifiergbc.gridx = 1;
classifierPanel.add(radioPanel, classifiergbc);
classifiergbc.gridx = 0;
classifiergbc.gridy++;
DialogUtils.addComponent(classifierPanel, classifiergbc, "Evaluate classifier", evaluateClassifier);
classifiergbc.gridy++;
DialogUtils.addComponent(classifierPanel, classifiergbc, "Evaluate Feature Power Set", evaluateFeaturePowerSet);
classifiergbc.gridy++;
DialogUtils.addComponent(classifierPanel, classifiergbc, "Number of training samples", numTrainSamples);
DialogUtils.fillPanel(classifierPanel, classifiergbc);
return classifierPanel;
}
private JPanel createRasterPanel() {
final JPanel rasterPanel = new JPanel();
rasterPanel.setLayout(new GridBagLayout());
rasterPanel.setBorder(new TitledBorder("Raster Training"));
GridBagConstraints gbc = DialogUtils.createGridBagConstraints();
gbc.gridy++;
DialogUtils.addComponent(rasterPanel, gbc, "Quantize class value", doClassValQuantization);
gbc.gridy++;
DialogUtils.addComponent(rasterPanel, gbc, "Min class value", minClassValue);
gbc.gridy++;
DialogUtils.addComponent(rasterPanel, gbc, "Class value step size", classValStepSize);
gbc.gridy++;
DialogUtils.addComponent(rasterPanel, gbc, "Class levels", classLevels);
gbc.gridy++;
DialogUtils.addComponent(rasterPanel, gbc, "Max class value", maxClassValue);
gbc.gridy++;
DialogUtils.addComponent(rasterPanel, gbc, "Training band:", new JScrollPane(trainingBands));
DialogUtils.fillPanel(rasterPanel, gbc);
return rasterPanel;
}
private JPanel createVectorPanel() {
final JPanel vectorPanel = new JPanel();
vectorPanel.setLayout(new GridBagLayout());
vectorPanel.setBorder(new TitledBorder("Vector Training"));
GridBagConstraints gbc = DialogUtils.createGridBagConstraints();
gbc.gridy++;
DialogUtils.addComponent(vectorPanel, gbc, "Training vectors: ", new JScrollPane(trainingVectors));
gbc.gridy++;
gbc.gridx = 0;
vectorPanel.add(new JLabel("Labels:"), gbc);
final ButtonGroup group3 = new ButtonGroup();
group3.add(labelSourceVectorName);
group3.add(labelSourceAttribute);
JPanel radioPanel = new JPanel(new FlowLayout());
radioPanel.add(labelSourceVectorName);
radioPanel.add(labelSourceAttribute);
gbc.gridx = 1;
vectorPanel.add(radioPanel, gbc);
DialogUtils.fillPanel(vectorPanel, gbc);
return vectorPanel;
}
private JPanel createFeaturePanel() {
final JPanel featurePanel = new JPanel();
featurePanel.setBorder(new TitledBorder("Feature Selection"));
featurePanel.setLayout(new GridBagLayout());
GridBagConstraints gbc = DialogUtils.createGridBagConstraints();
DialogUtils.addComponent(featurePanel, gbc, "Feature bands: ", new JScrollPane(featureBandNames));
DialogUtils.fillPanel(featurePanel, gbc);
return featurePanel;
}
private String[] getPolygons() {
// Get polygons from the first product which is assumed to be maskProduct in BaseClassifier
final ArrayList<String> geometryNames = new ArrayList<>(5);
if (sourceProducts != null) {
if (sourceProducts.length > 1) {
for (String name : sourceProducts[0].getMaskGroup().getNodeNames()) {
geometryNames.add(name + "::" + sourceProducts[0].getName());
}
} else {
geometryNames.addAll(Arrays.asList(sourceProducts[0].getMaskGroup().getNodeNames()));
}
}
return geometryNames.toArray(new String[geometryNames.size()]);
}
private String[] getTrainingBands() {
final ArrayList<String> bandNames = new ArrayList<>(5);
if (sourceProducts != null) {
if (sourceProducts.length > 1) {
for (String name : sourceProducts[0].getBandNames()) {
bandNames.add(name + "::" + sourceProducts[0].getName());
}
} else {
bandNames.addAll(Arrays.asList(sourceProducts[0].getBandNames()));
}
}
return bandNames.toArray(new String[bandNames.size()]);
}
private String[] getFeatures() {
final ArrayList<String> featureNames = new ArrayList<>(5);
if (sourceProducts != null) {
for (Product prod : sourceProducts) {
for (String name : prod.getBandNames()) {
if (BaseClassifier.excludeBand(name))
continue;
if (sourceProducts.length > 1) {
featureNames.add(name + "::" + prod.getName());
} else {
featureNames.add(name);
}
}
}
}
return featureNames.toArray(new String[featureNames.size()]);
}
}