/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2008 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.lazy;

import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.clustering.IdUtils;
import com.rapidminer.operator.similarity.SimilarityMeasure;
import com.rapidminer.operator.similarity.attributebased.ExampleBasedSimilarityMeasure;
import com.rapidminer.tools.WeightedObject;

/**
 * A simple implementation of a knn model. For each test example the k most
 * similar training examples are retrieved and combined by (optionally
 * similarity-weighted) majority vote for nominal labels, or by a (optionally
 * similarity-weighted) average for numerical labels. When no neighbor can
 * contribute, the majority class (nominal) or label average (numerical) of
 * the training set is predicted.
 *
 * @author Michael Wurst, Ingo Mierswa
 * @version $Id: KNNModel.java,v 1.6 2008/05/09 19:23:24 ingomierswa Exp $
 */
public class KNNModel extends PredictionModel {

    private static final long serialVersionUID = -6292869962412072573L;

    /** The training examples searched for neighbors at prediction time (lazy learner). */
    private ExampleSet trainingSet;

    /** Optional example weight attribute of the training set; null if none. */
    private Attribute weight;

    /** Maximum number of neighbors to consider. */
    private int k;

    /** If true, votes are weighted by the similarity to the neighbor. */
    private boolean weightedVote;

    /** The similarity (or distance) measure used to rank neighbors. */
    private SimilarityMeasure similarity;

    /** Fallback prediction: label mode for nominal labels, label average otherwise. */
    private double majorityPrediction;

    /**
     * Create a knn model.
     *
     * @param trainingSet the example set
     * @param similarity the similarity measure
     * @param k max. number of neighbors
     * @param weightedVote weight votes with similarity
     */
    public KNNModel(ExampleSet trainingSet, SimilarityMeasure similarity, int k, boolean weightedVote) {
        super(trainingSet);
        this.weight = trainingSet.getAttributes().getWeight();
        this.trainingSet = trainingSet;
        this.k = k;
        this.weightedVote = weightedVote;
        this.similarity = similarity;
        // precompute the default prediction used when no neighbor vote is possible
        this.trainingSet.recalculateAttributeStatistics(trainingSet.getAttributes().getLabel());
        if (trainingSet.getAttributes().getLabel().isNominal()) {
            majorityPrediction = trainingSet.getStatistics(trainingSet.getAttributes().getLabel(), Statistics.MODE);
        } else {
            majorityPrediction = trainingSet.getStatistics(trainingSet.getAttributes().getLabel(), Statistics.AVERAGE);
        }
    }

    /**
     * Predicts the label of every example in the given test set from its k
     * nearest training examples.
     *
     * @param testSet the examples to predict
     * @param predictedLabel the attribute receiving the crisp prediction
     * @return the test set with predictions (and confidences for nominal labels) set
     */
    public ExampleSet performPrediction(ExampleSet testSet, Attribute predictedLabel) {
        for (Example example : testSet) {
            List<WeightedObject<Example>> neighbors = findNearestNeighbors(example);
            if (getLabel().isNominal()) {
                classify(example, neighbors, predictedLabel);
            } else {
                regress(example, neighbors);
            }
        }
        return testSet;
    }

    /**
     * Returns the (at most) k training examples most similar to the given
     * example. Weights are oriented so that a larger weight always means
     * "closer": distances are negated, similarities are kept as-is. Training
     * examples with undefined similarity are kept but ranked last via
     * {@link Double#NEGATIVE_INFINITY}, so they are only used when fewer than
     * k defined neighbors exist.
     */
    private List<WeightedObject<Example>> findNearestNeighbors(Example example) {
        String id = IdUtils.getIdFromExample(example);
        List<WeightedObject<Example>> allNeighbors = new LinkedList<WeightedObject<Example>>();
        for (Example trainingExample : trainingSet) {
            double similarityValue = Double.NaN;
            if (similarity instanceof ExampleBasedSimilarityMeasure) {
                similarityValue = ((ExampleBasedSimilarityMeasure) similarity).similarity(example, trainingExample);
            } else {
                // id-based measure: both examples must carry an id and the pair must be defined
                String trainingId = IdUtils.getIdFromExample(trainingExample);
                if ((id != null) && (trainingId != null) && similarity.isSimilarityDefined(id, trainingId)) {
                    similarityValue = similarity.similarity(id, trainingId);
                }
            }
            if (Double.isNaN(similarityValue)) {
                allNeighbors.add(new WeightedObject<Example>(trainingExample, Double.NEGATIVE_INFINITY));
            } else if (similarity.isDistance()) {
                // negate so that sorting by weight ascending puts the closest examples last
                allNeighbors.add(new WeightedObject<Example>(trainingExample, -similarityValue));
            } else {
                allNeighbors.add(new WeightedObject<Example>(trainingExample, similarityValue));
            }
        }
        Collections.sort(allNeighbors);
        int actualK = Math.min(k, allNeighbors.size());
        // the last actualK entries are the largest weights, i.e. the nearest neighbors
        return allNeighbors.subList(allNeighbors.size() - actualK, allNeighbors.size());
    }

    /**
     * Converts a neighbor's stored weight into a vote weight for weighted
     * voting. For distance measures the stored weight is the negated distance
     * d, which is mapped to the similarity 1/(1+d) (written as
     * 1 - d/(1+d) to match the original formulation); for genuine similarity
     * measures the stored weight is used directly. Neighbors with undefined
     * similarity (stored as -Infinity) get vote weight 0 so they cannot
     * poison the sums with NaN (fixes the -Infinity/NaN propagation of the
     * previous implementation; this conversion is now also used consistently
     * for regression, which previously applied the distance formula even to
     * non-distance measures).
     */
    private double similarityWeight(WeightedObject<Example> weightedNeighbor) {
        double storedWeight = weightedNeighbor.getWeight();
        if (Double.isNaN(storedWeight) || Double.isInfinite(storedWeight)) {
            return 0.0d;
        }
        if (!similarity.isDistance()) {
            return storedWeight;
        }
        // storedWeight == -distance, hence 1 - storedWeight == 1 + distance
        return 1.0d - (-storedWeight / (1.0d - storedWeight));
    }

    /**
     * Nominal case: (weighted) majority vote over the neighbors. Sets one
     * confidence per label value and the crisp prediction; falls back to the
     * training-set majority class when no class received any vote mass.
     */
    private void classify(Example example, List<WeightedObject<Example>> neighbors, Attribute predictedLabel) {
        Map<String, Double> counter = new HashMap<String, Double>();
        double totalSum = 0.0d;
        for (WeightedObject<Example> weightedNeighbor : neighbors) {
            Example neighbor = weightedNeighbor.getObject();
            String labelValue = getLabel().getMapping().mapIndex((int) neighbor.getLabel());
            double exampleWeight = (weight != null) ? neighbor.getValue(weight) : 1.0d;
            double voteWeight = weightedVote ? similarityWeight(weightedNeighbor) : 1.0d;
            double currentWeight = exampleWeight * voteWeight;
            Double previousSum = counter.get(labelValue);
            counter.put(labelValue, (previousSum == null ? 0.0d : previousSum.doubleValue()) + currentWeight);
            totalSum += currentWeight;
        }

        // calculate confidences and best class
        String bestClass = null;
        double best = Double.NEGATIVE_INFINITY;
        for (String labelValue : getLabel().getMapping().getValues()) {
            Double sumObject = counter.get(labelValue);
            if (sumObject == null) {
                example.setConfidence(labelValue, 0.0d);
            } else {
                // guard against 0/0 = NaN confidences when all vote weights are zero
                example.setConfidence(labelValue, totalSum > 0.0d ? sumObject.doubleValue() / totalSum : 0.0d);
                if (sumObject.doubleValue() > best) {
                    best = sumObject.doubleValue();
                    bestClass = labelValue;
                }
            }
        }

        // set crisp prediction
        if (bestClass != null) {
            example.setPredictedLabel(predictedLabel.getMapping().mapString(bestClass));
        } else {
            example.setPredictedLabel(majorityPrediction);
        }
    }

    /**
     * Numerical case: (weighted) average of the neighbor labels. Falls back
     * to the training-set label average when the total vote weight is zero.
     */
    private void regress(Example example, List<WeightedObject<Example>> neighbors) {
        double labelSum = 0.0d;
        double totalSum = 0.0d;
        for (WeightedObject<Example> weightedNeighbor : neighbors) {
            Example neighbor = weightedNeighbor.getObject();
            double exampleWeight = (weight != null) ? neighbor.getValue(weight) : 1.0d;
            double voteWeight = weightedVote ? similarityWeight(weightedNeighbor) : 1.0d;
            double currentWeight = exampleWeight * voteWeight;
            labelSum += currentWeight * neighbor.getLabel();
            totalSum += currentWeight;
        }
        if (totalSum > 0.0d) {
            example.setPredictedLabel(labelSum / totalSum);
        } else {
            example.setPredictedLabel(majorityPrediction);
        }
    }
}