/* * RapidMiner * * Copyright (C) 2001-2008 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.weka; import java.awt.BorderLayout; import java.awt.Component; import java.awt.FlowLayout; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.util.Iterator; import javax.swing.ButtonGroup; import javax.swing.JPanel; import javax.swing.JRadioButton; import javax.swing.JToolBar; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.operator.IOContainer; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.UserError; import com.rapidminer.operator.learner.PredictionModel; import com.rapidminer.tools.Tools; import com.rapidminer.tools.WekaInstancesAdaptor; import com.rapidminer.tools.WekaTools; import weka.classifiers.Classifier; import weka.classifiers.UpdateableClassifier; import weka.core.Drawable; import weka.core.Instance; import weka.core.Instances; import weka.gui.graphvisualizer.GraphVisualizer; import weka.gui.treevisualizer.PlaceNode2; import weka.gui.treevisualizer.TreeDisplayEvent; import weka.gui.treevisualizer.TreeDisplayListener; import weka.gui.treevisualizer.TreeVisualizer; /** * A Weka {@link weka.classifiers.Classifier} which can be used to classify * {@link Example}s. It is learned by the {@link GenericWekaLearner} and the * {@link GenericWekaMetaLearner}. * * @author Ingo Mierswa * @version $Id: WekaClassifier.java,v 1.8 2008/07/07 07:06:44 ingomierswa Exp $ */ public class WekaClassifier extends PredictionModel { private static final long serialVersionUID = -2684252543419537079L; /** The used weka classifier. */ private Classifier classifier; /** The name of the classifier. */ private String name; public WekaClassifier(ExampleSet exampleSet, String name, Classifier classifier) { super(exampleSet); this.name = name; this.classifier = classifier; } public Classifier getClassifier() { return this.classifier; } /** Returns true if the Weka classifier is updatable. */ public boolean isUpdatable() { return (classifier instanceof UpdateableClassifier); } /** Updates the model if the classifier is updatable. Otherwise, an * {@link UnsupportedOperationException} is thrown. */ public void updateModel(ExampleSet updateExampleSet) throws OperatorException { if (classifier instanceof UpdateableClassifier) { UpdateableClassifier updateableClassifier = (UpdateableClassifier)classifier; updateClassifier(updateableClassifier, updateExampleSet); } else { throw new UserError(null, 135, getClass().getName() + " (" + classifier.getClass() + ")"); } } private void updateClassifier(UpdateableClassifier classifier, ExampleSet exampleSet) throws OperatorException { log("Update Weka classifier."); log("Converting to Weka instances."); Instances instances = WekaTools.toWekaInstances(exampleSet, "UpdateInstances", WekaInstancesAdaptor.LEARNING); log("Actually updating Weka classifier."); try { for (int i = 0; i < instances.numInstances(); i++) { Instance instance = instances.instance(i++); classifier.updateClassifier(instance); } } catch (Exception e) { throw new UserError(null, 310, "updating Weka model", e.getMessage()); } } public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException { log("Applying Weka classifier."); log("Converting to Weka instances."); Instances instances = WekaTools.toWekaInstances(exampleSet, "ApplierInstances", WekaInstancesAdaptor.PREDICTING); log("Actually applying Weka classifier."); int i = 0; Iterator<Example> r = exampleSet.iterator(); while (r.hasNext()) { Example e = r.next(); Instance instance = instances.instance(i++); applyModelForInstance(instance, e, predictedLabel); } return exampleSet; } /** * Classifies ervery weka instance and sets the result as predicted label of * the current example. */ public void applyModelForInstance(Instance instance, Example e, Attribute predictedLabelAttribute) { double predictedLabel = Double.NaN; try { double wekaPrediction = classifier.classifyInstance(instance); if (predictedLabelAttribute.isNominal()) { double confidences[] = classifier.distributionForInstance(instance); for (int i = 0; i < confidences.length; i++) { String classification = instance.classAttribute().value(i); e.setConfidence(classification, confidences[i]); } String classification = instance.classAttribute().value((int) wekaPrediction); predictedLabel = predictedLabelAttribute.getMapping().mapString(classification); } else { predictedLabel = classifier.classifyInstance(instance); } } catch (Exception exc) { logError("Exception occured while classifying example:" + exc.getMessage() + " [" + exc.getClass() + "]"); } e.setValue(predictedLabelAttribute, predictedLabel); } public String getName() { return this.name; } public String toString() { return this.name + " (model for label " + getLabel() + ")" + Tools.getLineSeparator() + classifier.toString(); } public String toResultString() { return classifier.toString(); } private Component createTextAndGraphView(final Component textView, final Component graphView) { final JPanel mainPanel = new JPanel(); mainPanel.setLayout(new BorderLayout()); final JRadioButton graphViewButton = new JRadioButton("Graph View", true); graphViewButton.setToolTipText("Changes to a graphical view of this model."); graphViewButton.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { if (graphViewButton.isSelected()) { mainPanel.remove(1); mainPanel.add(graphView, BorderLayout.CENTER); mainPanel.repaint(); } } }); final JRadioButton textViewButton = new JRadioButton("Text View", true); textViewButton.setToolTipText("Changes to a textual view of this model."); textViewButton.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { if (textViewButton.isSelected()) { mainPanel.remove(1); mainPanel.add(textView, BorderLayout.CENTER); mainPanel.repaint(); } } }); ButtonGroup group = new ButtonGroup(); group.add(textViewButton); group.add(graphViewButton); JPanel togglePanel = new JPanel(new FlowLayout(FlowLayout.LEFT)); togglePanel.add(textViewButton); togglePanel.add(graphViewButton); mainPanel.add(togglePanel, BorderLayout.NORTH); mainPanel.add(graphView, BorderLayout.CENTER); graphViewButton.setSelected(true); return mainPanel; } public Component getVisualizationComponent(IOContainer container) { if (classifier instanceof Drawable) { try { Drawable drawable = (Drawable) classifier; int graphType = drawable.graphType(); switch (graphType) { case Drawable.TREE: Component treeView = new TreeVisualizer(new TreeDisplayListener() { public void userCommand(TreeDisplayEvent e) {} }, drawable.graph(), new PlaceNode2()); return createTextAndGraphView(super.getVisualizationComponent(container), treeView); case Drawable.BayesNet: GraphVisualizer visualizer = new GraphVisualizer(); // remove graph tool bar from original location (NORTH) JToolBar graphTools = (JToolBar)visualizer.getComponent(0); visualizer.remove(graphTools); // remove progress bar from tool bar graphTools.remove(graphTools.getComponentCount() - 1); // add tool bar to new location (WEST) JPanel toolPanel = new JPanel(new BorderLayout()); toolPanel.add(graphTools, BorderLayout.NORTH); visualizer.add(toolPanel, BorderLayout.WEST); // init graph visualizer.readBIF(drawable.graph()); visualizer.layoutGraph(); return createTextAndGraphView(super.getVisualizationComponent(container), visualizer); case Drawable.NOT_DRAWABLE: default: return super.getVisualizationComponent(container); } } catch (Exception e) { return super.getVisualizationComponent(container); } } else { return super.getVisualizationComponent(container); } } public boolean equals(Object o) { if (!super.equals(o)) return false; WekaClassifier other = (WekaClassifier) o; if (!other.classifier.equals(this.classifier)) return false; return true; } public int hashCode() { return this.classifier.hashCode(); } }