/* * RapidMiner * * Copyright (C) 2001-2011 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.tools.math.distribution.kernel; import java.util.TreeSet; import com.rapidminer.tools.math.distribution.NormalDistribution; /** * An updatable estimated kernel density distribution. Update strategy greedily * creates and merges kernels and assigns value to these kernels. * * @author Tobias Malbrecht */ public class GreedyKernelDistribution extends KernelDistribution { public static final long serialVersionUID = -3298190542815818L; private static final double DEFAULT_MINIMUM_BANDWIDTH = 0.1; private static final int DEFAULT_NUMBER_OF_KERNELS = 10; private int numberOfKernels; private double minBandwidth; private TreeSet<NormalKernel> kernels; public GreedyKernelDistribution() { this(DEFAULT_MINIMUM_BANDWIDTH, DEFAULT_NUMBER_OF_KERNELS); } public GreedyKernelDistribution(double minBandwidth, int numberOfKernels) { this.numberOfKernels = numberOfKernels; this.minBandwidth = minBandwidth; kernels = new TreeSet<NormalKernel>(); } public void update(double value, double weight) { if (!Double.isNaN(value) && !Double.isNaN(weight)) { boolean kernelUpdated = false; double bestAssignmentDistance = Double.POSITIVE_INFINITY; double bestMergeDistance = Double.POSITIVE_INFINITY; NormalKernel bestAssignmentKernel = null; NormalKernel lastKernel = null; NormalKernel bestMergeKernel1 = null; NormalKernel bestMergeKernel2 = null; for (NormalKernel kernel : kernels) { double assignmentDistance = Math.abs(value - kernel.getMean()); if (assignmentDistance == 0) { kernel.update(value, weight); kernelUpdated = true; break; } if (assignmentDistance < bestAssignmentDistance) { bestAssignmentDistance = assignmentDistance; bestAssignmentKernel = kernel; } if (lastKernel != null) { double mergeDistance = Math.abs(lastKernel.getMean() - kernel.getMean()); if (mergeDistance < bestMergeDistance) { bestMergeDistance = mergeDistance; bestMergeKernel1 = lastKernel; bestMergeKernel2 = kernel; } } lastKernel = kernel; } if (!kernelUpdated) { if (kernels.size() < numberOfKernels) { NormalKernel kernel = new NormalKernel(minBandwidth); kernel.update(value, weight); kernels.add(kernel); } else { if (bestAssignmentDistance < bestMergeDistance) { bestAssignmentKernel.update(value, weight); } else { bestMergeKernel1.update(bestMergeKernel2); kernels.remove(bestMergeKernel2); NormalKernel kernel = new NormalKernel(minBandwidth); kernel.update(value, weight); kernels.add(kernel); } } } } } public void update(double value) { update(value, 1.0d); } @Override public String getAttributeName() { return null; } @Override public int getNumberOfParameters() { return 0; } @Override public String getParameterName(int index) { return null; } @Override public double getParameterValue(int index) { return Double.NaN; } @Override public double getUpperBound() { double maxMean = Double.NEGATIVE_INFINITY; double maxStandardDeviation = DEFAULT_BANDWIDTH; for (NormalKernel kernel : kernels) { double mean = kernel.getMean(); double standardDeviation = kernel.getStandardDeviation(); if (mean > maxMean) { maxMean = mean; } if (standardDeviation > maxStandardDeviation) { maxStandardDeviation = standardDeviation; } } return NormalDistribution.getUpperBound(maxMean, maxStandardDeviation); } @Override public double getLowerBound() { double minMean = Double.POSITIVE_INFINITY; double maxStandardDeviation = DEFAULT_BANDWIDTH; for (NormalKernel kernel : kernels) { double mean = kernel.getMean(); double standardDeviation = kernel.getStandardDeviation(); if (mean < minMean) { minMean = mean; } if (standardDeviation > maxStandardDeviation) { maxStandardDeviation = standardDeviation; } } return NormalDistribution.getLowerBound(minMean, maxStandardDeviation); } public double getTotalWeight() { double totalWeight = 0; for (NormalKernel kernel : kernels) { totalWeight += kernel.getTotalWeight(); } return totalWeight; } @Override public double getProbability(double value) { double probability = 0; double totalWeight = 0; for (NormalKernel kernel : kernels) { probability += kernel.getTotalWeight() * kernel.getProbability(value); totalWeight += kernel.getTotalWeight(); } return probability / totalWeight; } }