/*
* RapidMiner
*
* Copyright (C) 2001-2008 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.functions.kernel.hyperhyper;
import java.awt.Component;
import java.util.Iterator;
import com.rapidminer.datatable.DataTable;
import com.rapidminer.datatable.SimpleDataTable;
import com.rapidminer.datatable.SimpleDataTableRow;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.gui.tools.ExtendedJScrollPane;
import com.rapidminer.gui.tools.JRadioSelectionPanel;
import com.rapidminer.gui.viewer.DataTableViewerTable;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.Tools;
/**
* The model for the HyperHyper implementation.
*
* @author Regina Fritsch
* @version $Id: HyperModel.java,v 1.4 2008/07/04 10:27:17 stiefelolm Exp $
*/
public class HyperModel extends PredictionModel {
private static final long serialVersionUID = -453402008180607969L;
private String[] coefficientNames;
private double[] x1;
private double[] x2;
private double bias;
private double[] w;
public HyperModel(ExampleSet trainingExampleSet, double bias, double[] w, double[] x1, double[] x2) {
super(trainingExampleSet);
this.coefficientNames = com.rapidminer.example.Tools.getRegularAttributeNames(trainingExampleSet);
this.bias = bias;
this.w = w;
this.x1 = x1;
this.x2 = x2;
}
public int getNumberOfAttributes() {
return x1.length;
}
public String toString() {
StringBuffer result = new StringBuffer();
result.append("Support Vector 1:" + Tools.getLineSeparator());
for (int i = 0; i < this.coefficientNames.length; i++) {
result.append(coefficientNames[i] + " = " + Tools.formatNumber(this.x1[i]) + Tools.getLineSeparator());
}
result.append(Tools.getLineSeparator() + "Support Vector 2:" + Tools.getLineSeparator());
for (int i = 0; i < this.coefficientNames.length; i++) {
result.append(this.coefficientNames[i] + " = " + Tools.formatNumber(this.x2[i]) + Tools.getLineSeparator());
}
result.append(Tools.getLineSeparator() + "Bias (offset): " + Tools.formatNumber(this.bias) + Tools.getLineSeparators(2));
result.append("Coefficients:" + Tools.getLineSeparator());
for (int j = 0; j < w.length; j++) {
result.append("w(" + this.coefficientNames[j] + ") = " + Tools.formatNumber(this.w[j]) + Tools.getLineSeparator());
}
return result.toString();
}
public String getName() {
return ("HyperHyper Model");
}
public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
if (exampleSet.getAttributes().size() != getNumberOfAttributes())
throw new UserError(null, 133, getNumberOfAttributes(), exampleSet.getAttributes().size());
Iterator<Example> reader = exampleSet.iterator();
while (reader.hasNext()) {
Example activeExample = reader.next();
double sum = 0;
int i = 0;
for (Attribute attribute : exampleSet.getAttributes()) {
sum += activeExample.getValue(attribute) * this.w[i];
i++;
}
double result = sum + this.bias; // <w * x> + b
int prediction = 0;
if (result > 0) {
prediction = getLabel().getMapping().getPositiveIndex();
} else {
prediction = getLabel().getMapping().getNegativeIndex();
}
activeExample.setValue(predictedLabel, prediction);
activeExample.setConfidence(predictedLabel.getMapping().getPositiveString(), 1.0d / (1.0d + java.lang.Math.exp(-result)));
activeExample.setConfidence(predictedLabel.getMapping().getNegativeString(), 1.0d / (1.0d + java.lang.Math.exp(result)));
}
return exampleSet;
}
private DataTable createWeightsTable() {
SimpleDataTable weightTable = new SimpleDataTable("Hyper Weights", new String[] { "Attribute", "Weight" } );
for (int j = 0; j < w.length; j++) {
int nameIndex = weightTable.mapString(0, this.coefficientNames[j]);
weightTable.add(new SimpleDataTableRow(new double[] { nameIndex, w[j]}));
}
return weightTable;
}
/** Returns a html label with a table view or a plotter for statistic view. */
public Component getVisualizationComponent(IOContainer container) {
final JRadioSelectionPanel mainPanel = new JRadioSelectionPanel();
// text view
Component textView = super.getVisualizationComponent(container);
mainPanel.addComponent("Text View", textView, "Changes to a textual view of this model.");
// weight table
DataTable weightDataTable = createWeightsTable();
DataTableViewerTable weightTableViewer = new DataTableViewerTable(true);
weightTableViewer.setDataTable(weightDataTable);
Component weightTableView = new ExtendedJScrollPane(weightTableViewer);
mainPanel.addComponent("Weight Table View", weightTableView, "Changes to a weight table view of this model.");
return mainPanel;
}
}