/* * RapidMiner * * Copyright (C) 2001-2011 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.meta; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map.Entry; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.operator.Model; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.learner.PredictionModel; /** */ @Deprecated public class HierarchicalModel extends PredictionModel implements MetaModel { public static class Node implements Serializable { private static final long serialVersionUID = 1L; private final String className; private final List<Node> children = new ArrayList<Node>(); private Node parent = null; private Model model = null; public Node(String className) { this.className = className; } public List<Node> getChildren() { return children; } public void addChild(Node child) { children.add(child); child.setParent(this); } public void setParent(Node parent) { this.parent = parent; } public Node getParent() { return this.parent; } public String getClassName() { return this.className; } public List<String> getChildrenClasses() { List<String> childrenClasses = new ArrayList<String>(); for (Node child : children) { childrenClasses.add(child.getClassName()); childrenClasses.addAll(child.getChildrenClasses()); } return childrenClasses; } public List<String> getLeaveClasses() { List<String> leaveClasses = new ArrayList<String>(); for (Node child : children) { leaveClasses.addAll(child.getLeaveClasses()); } if (children.size() == 0) { leaveClasses.add(className); } return leaveClasses; } public void setModel(Model model) { this.model = model; } public Model getModel() { return this.model; } } private static final long serialVersionUID = -5792943818860734082L; private final Node root; public HierarchicalModel(ExampleSet exampleSet, Node root) { super(exampleSet); this.root = root; } @Override public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException { ExampleSet eSet = (ExampleSet) exampleSet.clone(); int numberOfClasses = getLabel().getMapping().getValues().size(); HashMap<String, Integer> classIndexMap = new HashMap<String, Integer>(numberOfClasses); for (String currentClass : getLabel().getMapping().getValues()) { classIndexMap.put(currentClass, getLabel().getMapping().mapString(currentClass)); } double[][] confidenceMatrix = new double[eSet.size()][numberOfClasses]; for (int i = 0; i < confidenceMatrix.length; i++) { for (int j = 0; j < confidenceMatrix[i].length; j++) { confidenceMatrix[i][j] = 1; } } performPrediction(eSet, predictedLabel, root, confidenceMatrix, classIndexMap); int counter = 0; for (Example example : exampleSet) { double predictedValue = 0; double maxConfidence = 0; // double sumConfidence = 0; // for (int i = 0; i < confidenceMatrix[counter].length; i++) { // sumConfidence += confidenceMatrix[counter][i]; // } for (Entry<String, Integer> entry : classIndexMap.entrySet()) { // confidenceMatrix[counter][entry.getValue()] /= sumConfidence; example.setConfidence(entry.getKey(), confidenceMatrix[counter][entry.getValue()]); if (confidenceMatrix[counter][entry.getValue()] > maxConfidence) { maxConfidence = confidenceMatrix[counter][entry.getValue()]; predictedValue = entry.getValue(); } } example.setPredictedLabel(predictedValue); counter++; } return exampleSet; } public void performPrediction(ExampleSet eSet, Attribute predictedLabel, Node node, double[][] confidenceMatrix, HashMap<String, Integer> classIndexMap) throws OperatorException { if (node.getModel() != null && node.getChildren().size() > 0) { System.err.println("Predicting " + node.getClassName()); eSet = node.getModel().apply(eSet); int counter = 0; for (Example example : eSet) { for (Node child : node.getChildren()) { double confidence = example.getConfidence(child.getClassName()); for (String className : child.getLeaveClasses()) { confidenceMatrix[counter][classIndexMap.get(className)] *= confidence; } } counter++; } PredictionModel.removePredictedLabel(eSet); } for (Node child : node.getChildren()) { performPrediction(eSet, predictedLabel, child, confidenceMatrix, classIndexMap); } } @Override public List<String> getModelNames() { List<String> names = new LinkedList<String>(); return names; } @Override public List<Model> getModels() { return Arrays.asList(); } }