/*******************************************************************************
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package smile.demo.classification;
import java.awt.Dimension;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JTextField;
import smile.classification.LogisticRegression;
/**
*
* @author Haifeng Li
*/
@SuppressWarnings("serial")
public class LogisticRegressionDemo extends ClassificationDemo {
private double lambda = 0.1;
private JTextField lambdaField;
/**
* Constructor.
*/
public LogisticRegressionDemo() {
lambdaField = new JTextField(Double.toString(lambda), 5);
optionPane.add(new JLabel("\u03BB:"));
optionPane.add(lambdaField);
}
@Override
public double[][] learn(double[] x, double[] y) {
try {
lambda = Double.parseDouble(lambdaField.getText().trim());
if (lambda < 0.0) {
JOptionPane.showMessageDialog(this, "Invalid \u03BB: " + lambda, "Error", JOptionPane.ERROR_MESSAGE);
return null;
}
} catch (Exception ex) {
JOptionPane.showMessageDialog(this, "Invalid \u03BB: " + lambdaField.getText(), "Error", JOptionPane.ERROR_MESSAGE);
return null;
}
double[][] data = dataset[datasetIndex].toArray(new double[dataset[datasetIndex].size()][]);
int[] label = dataset[datasetIndex].toArray(new int[dataset[datasetIndex].size()]);
LogisticRegression logit = new LogisticRegression(data, label, lambda);
for (int i = 0; i < label.length; i++) {
label[i] = logit.predict(data[i]);
}
double trainError = error(label, label);
System.out.format("training error = %.2f%%\n", 100*trainError);
double[][] z = new double[y.length][x.length];
for (int i = 0; i < y.length; i++) {
for (int j = 0; j < x.length; j++) {
double[] p = {x[j], y[i]};
z[i][j] = logit.predict(p);
}
}
return z;
}
@Override
public String toString() {
return "Logistic Regression";
}
public static void main(String argv[]) {
ClassificationDemo demo = new LogisticRegressionDemo();
JFrame f = new JFrame("Logistic Regression");
f.setSize(new Dimension(1000, 1000));
f.setLocationRelativeTo(null);
f.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
f.getContentPane().add(demo);
f.setVisible(true);
}
}