/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2011 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.bayes;

import Jama.Matrix;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.SimplePredictionModel;
import com.rapidminer.tools.Tools;

/**
 * This is the model for discriminant analysis based learning schemes.
 *
 * @author Sebastian Land
 */
public class DiscriminantModel extends SimplePredictionModel {

    private static final long serialVersionUID = 3793343069512113817L;

    private double alpha;
    private String[] labels;
    private Matrix[] meanVectors;
    private Matrix[] inverseCovariances;
    private double[] aprioriProbabilities;

    /** Precomputed constant part of each class' discriminant function. */
    private double[] constClassValues;

    public DiscriminantModel(ExampleSet exampleSet, String[] labels, Matrix[] meanVectors, Matrix[] inverseCovariances, double[] aprioriProbabilities, double alpha) {
        super(exampleSet);
        this.alpha = alpha;
        this.labels = labels;
        this.meanVectors = meanVectors;
        this.inverseCovariances = inverseCovariances;
        this.aprioriProbabilities = aprioriProbabilities;

        // precompute the example-independent term of each class' discriminant function:
        // -0.5 * mu_k * Sigma_k^-1 * mu_k^T + log(p_k)
        this.constClassValues = new double[labels.length];
        for (int i = 0; i < labels.length; i++) {
            constClassValues[i] = -0.5d * meanVectors[i].times(inverseCovariances[i]).times(meanVectors[i].transpose()).get(0, 0) + Math.log(aprioriProbabilities[i]);
        }
    }

    @Override
    public double predict(Example example) throws OperatorException {
        // collect the numerical attribute values of the example as a row vector
        int numberOfAttributes = meanVectors[0].getColumnDimension();
        double[] vector = new double[numberOfAttributes];
        int i = 0;
        for (Attribute attribute : example.getAttributes()) {
            if (attribute.isNumerical()) {
                vector[i] = example.getValue(attribute);
                i++;
            }
        }
        Matrix xVector = new Matrix(vector, 1);

        // evaluate the discriminant function of every class:
        // x * Sigma_k^-1 * mu_k^T + constClassValues[k]
        double[] labelFunction = new double[labels.length];
        for (int labelIndex = 0; labelIndex < labels.length; labelIndex++) {
            labelFunction[labelIndex] = xVector.times(inverseCovariances[labelIndex]).times(meanVectors[labelIndex].transpose()).get(0, 0) + constClassValues[labelIndex];
        }

        // predict the class whose discriminant function yields the largest value
        double maximalValue = Double.NEGATIVE_INFINITY;
        int bestValue = 0;
        for (int labelIndex = 0; labelIndex < labels.length; labelIndex++) {
            if (labelFunction[labelIndex] >= maximalValue) {
                bestValue = labelIndex;
                maximalValue = labelFunction[labelIndex];
            }
        }
        return bestValue;
    }

    @Override
    public String getName() {
        if (alpha == 0d) {
            return "Quadratic Discriminant Model";
        } else if (alpha == 1d) {
            return "Linear Discriminant Model";
        } else {
            return "Regularized Discriminant Model";
        }
    }

    @Override
    public String toString() {
        StringBuffer buffer = new StringBuffer();
        buffer.append("Apriori probabilities:\n");
        for (int i = 0; i < labels.length; i++) {
            buffer.append(labels[i] + "\t");
            buffer.append(Tools.formatNumber(aprioriProbabilities[i], 4) + "\n");
        }
        return buffer.toString();
    }
}
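
/*
 * A minimal, standalone sketch of the scoring rule used in predict() above, kept outside the
 * model class for illustration only. It is not part of the original RapidMiner sources: the
 * class name DiscriminantScoreSketch and all data values (means, priors, the example vector)
 * are hypothetical, and it only reuses the Jama.Matrix import already present in this file.
 * With a shared identity inverse covariance this reduces to a linear discriminant; running
 * main() prints the index of the class with the largest discriminant value.
 */
class DiscriminantScoreSketch {

    public static void main(String[] args) {
        // two classes, two numerical attributes; toy values for illustration
        Matrix[] means = { new Matrix(new double[] { 0.0, 0.0 }, 1), new Matrix(new double[] { 2.0, 2.0 }, 1) };
        Matrix identity = Matrix.identity(2, 2); // assumed shared inverse covariance
        Matrix[] inverseCovariances = { identity, identity };
        double[] priors = { 0.5, 0.5 };

        // example to classify, as a 1 x n row vector (same layout as in predict())
        Matrix xVector = new Matrix(new double[] { 1.8, 1.6 }, 1);

        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < means.length; k++) {
            // constant term: -0.5 * mu_k * Sigma_k^-1 * mu_k^T + log(p_k)
            double constK = -0.5 * means[k].times(inverseCovariances[k]).times(means[k].transpose()).get(0, 0) + Math.log(priors[k]);
            // discriminant value: x * Sigma_k^-1 * mu_k^T + constK
            double score = xVector.times(inverseCovariances[k]).times(means[k].transpose()).get(0, 0) + constK;
            if (score >= bestScore) {
                bestScore = score;
                best = k;
            }
        }
        System.out.println("predicted class index: " + best); // prints 1 for this toy data
    }
}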