/*******************************************************************************
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package smile.demo.projection;
import java.awt.Dimension;
import java.awt.GridLayout;
import javax.swing.JComponent;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;
import smile.math.Math;
import smile.plot.Palette;
import smile.plot.PlotCanvas;
import smile.projection.GHA;
import smile.projection.PCA;
/**
*
* @author Haifeng Li
*/
@SuppressWarnings("serial")
public class GHADemo extends ProjectionDemo {
public GHADemo() {
}
@Override
public JComponent learn() {
JPanel pane = new JPanel(new GridLayout(2, 2));
double[][] data = Math.clone(dataset[datasetIndex].toArray(new double[dataset[datasetIndex].size()][]));
String[] names = dataset[datasetIndex].toArray(new String[dataset[datasetIndex].size()]);
if (names[0] == null) {
names = null;
}
long clock = System.currentTimeMillis();
PCA pca = new PCA(data, true);
System.out.format("Learn PCA from %d samples in %dms\n", data.length, System.currentTimeMillis() - clock);
pca.setProjection(2);
double[][] y = pca.project(data);
PlotCanvas plot = new PlotCanvas(Math.colMin(y), Math.colMax(y));
if (names != null) {
plot.points(y, names);
} else if (dataset[datasetIndex].response() != null) {
int[] labels = dataset[datasetIndex].toArray(new int[dataset[datasetIndex].size()]);
for (int i = 0; i < y.length; i++) {
plot.point(pointLegend, Palette.COLORS[labels[i]], y[i]);
}
} else {
plot.points(y, pointLegend);
}
plot.setTitle("PCA");
pane.add(plot);
pca.setProjection(3);
y = pca.project(data);
plot = new PlotCanvas(Math.colMin(y), Math.colMax(y));
if (names != null) {
plot.points(y, names);
} else if (dataset[datasetIndex].response() != null) {
int[] labels = dataset[datasetIndex].toArray(new int[dataset[datasetIndex].size()]);
for (int i = 0; i < y.length; i++) {
plot.point(pointLegend, Palette.COLORS[labels[i]], y[i]);
}
} else {
plot.points(y, pointLegend);
}
plot.setTitle("PCA");
pane.add(plot);
clock = System.currentTimeMillis();
GHA gha = new GHA(data[0].length, 2, 0.00001);
for (int iter = 1; iter <= 500; iter++) {
double error = 0.0;
for (int i = 0; i < data.length; i++) {
error += gha.learn(data[i]);
}
error /= data.length;
if (iter % 100 == 0) {
System.out.format("Iter %3d, Error = %.5g\n", iter, error);
}
}
System.out.format("Learn GHA from %d samples in %dms\n", data.length, System.currentTimeMillis() - clock);
y = gha.project(data);
plot = new PlotCanvas(Math.colMin(y), Math.colMax(y));
if (names != null) {
plot.points(y, names);
} else if (dataset[datasetIndex].response() != null) {
int[] labels = dataset[datasetIndex].toArray(new int[dataset[datasetIndex].size()]);
for (int i = 0; i < y.length; i++) {
plot.point(pointLegend, Palette.COLORS[labels[i]], y[i]);
}
} else {
plot.points(y, pointLegend);
}
plot.setTitle("GHA");
pane.add(plot);
clock = System.currentTimeMillis();
gha = new GHA(data[0].length, 3, 0.00001);
for (int iter = 1; iter <= 500; iter++) {
double error = 0.0;
for (int i = 0; i < data.length; i++) {
error += gha.learn(data[i]);
}
error /= data.length;
if (iter % 100 == 0) {
System.out.format("Iter %3d, Error = %.5g\n", iter, error);
}
}
System.out.format("Learn GHA from %d samples in %dms\n", data.length, System.currentTimeMillis() - clock);
y = gha.project(data);
plot = new PlotCanvas(Math.colMin(y), Math.colMax(y));
if (names != null) {
plot.points(y, names);
} else if (dataset[datasetIndex].response() != null) {
int[] labels = dataset[datasetIndex].toArray(new int[dataset[datasetIndex].size()]);
for (int i = 0; i < y.length; i++) {
plot.point(pointLegend, Palette.COLORS[labels[i]], y[i]);
}
} else {
plot.points(y, pointLegend);
}
plot.setTitle("GHA");
pane.add(plot);
return pane;
}
@Override
public String toString() {
return "Generalized Hebbian Algorithm";
}
public static void main(String argv[]) {
GHADemo demo = new GHADemo();
JFrame f = new JFrame("Generalized Hebbian Algorithm");
f.setSize(new Dimension(1000, 1000));
f.setLocationRelativeTo(null);
f.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
f.getContentPane().add(demo);
f.setVisible(true);
}
}