/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * ThresholdFunction.java * Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece */ package mulan.classifier.neural; import java.io.Serializable; import java.util.Arrays; import weka.core.Utils; import weka.core.matrix.Matrix; /** * Implementation of a threshold function. * * @author Jozef Vilcek */ public class ThresholdFunction implements Serializable { /** Default serial version UID for serialization*/ private static final long serialVersionUID = 5347411552628371402L; private double[] parameters; /** * Creates a new instance of {@link ThresholdFunction} and * builds the function based on input parameters. * * @param idealLabels the ideal output for each input patterns, which a model should output * @param modelOutLabels the real output of a model for each input pattern * @throws IllegalArgumentException if dimensions of input arrays does not match * @see ThresholdFunction#build(double[][], double[][]) */ public ThresholdFunction(final double[][] idealLabels, final double[][] modelOutLabels) { this.build(idealLabels, modelOutLabels); } /** * Computes a threshold value, based on learned parameters, for given labels confidences. * * @param labelsConfidences the labels confidences * @return the threshold value * @throws IllegalArgumentException if the dimension of labels confidences does not match * the dimension of learned parameters of threshold function. */ public double computeThreshold(final double[] labelsConfidences) { int expectedDim = parameters.length - 1; if (labelsConfidences.length != expectedDim) { throw new IllegalArgumentException("The array of label confidences has wrong dimension." + "The function expect parameters of length : " + expectedDim); } double threshold = 0; for (int index = 0; index < expectedDim; index++) { threshold += labelsConfidences[index] * parameters[index]; } threshold += parameters[expectedDim]; return threshold; } /** * Build a threshold function for based on input data. * The threshold function is build for a particular model. * * @param idealLabels the ideal output for each input patterns, which a model should output. * First index is expected to be number of examples and second is the label index. * @param modelOutLabels the real output of a model for each input pattern. * First index is expected to be number of examples and second is the label index. * @throws IllegalArgumentException if dimensions of input arrays does not match */ public void build(final double[][] idealLabels, final double[][] modelOutLabels) { if (idealLabels == null || modelOutLabels == null) { throw new IllegalArgumentException("Non of the input parameters can be null."); } int numExamples = idealLabels.length; int numLabels = idealLabels[0].length; if (modelOutLabels.length != numExamples || modelOutLabels[0].length != numLabels) { throw new IllegalArgumentException("Matrix dimensions of input parameters does not agree."); } double[] thresholds = new double[numExamples]; double[] isLabelModelOuts = new double[numLabels]; double[] isNotLabelModelOuts = new double[numLabels]; for (int example = 0; example < numExamples; example++) { Arrays.fill(isLabelModelOuts, Double.MAX_VALUE); Arrays.fill(isNotLabelModelOuts, -Double.MAX_VALUE); for (int label = 0; label < numLabels; label++) { if (idealLabels[example][label] == 1) { isLabelModelOuts[label] = modelOutLabels[example][label]; } else { isNotLabelModelOuts[label] = modelOutLabels[example][label]; } } double isLabelMin = isLabelModelOuts[Utils.minIndex(isLabelModelOuts)]; double isNotLabelMax = isNotLabelModelOuts[Utils.maxIndex(isNotLabelModelOuts)]; // check if we have unique minimum ... // if not take center of the segment ... if it is a segment if (isLabelMin != isNotLabelMax) { // check marginal cases -> all labels are in or none of them if (isLabelMin == Double.MAX_VALUE) { thresholds[example] = isNotLabelMax + 0.1; } else if (isNotLabelMax == -Double.MAX_VALUE) { thresholds[example] = isLabelMin - 0.1; } else { // center of a segment thresholds[example] = (isLabelMin + isNotLabelMax) / 2; } } else { // when minimum is unique thresholds[example] = isLabelMin; } } Matrix modelMatrix = new Matrix(numExamples, numLabels + 1, 1.0); modelMatrix.setMatrix(0, numExamples - 1, 0, numLabels - 1, new Matrix(modelOutLabels)); Matrix weights = modelMatrix.solve(new Matrix(thresholds, thresholds.length)); double[][] weightsArray = weights.transpose().getArray(); parameters = Arrays.copyOf(weightsArray[0], weightsArray[0].length); } /** * Returns parameters learned by the threshold function in last build. * Based on these parameters the functions is computing thresholds for * label confidences.<br/> * Support for unit tests ... * * @return parameters */ protected double[] getFunctionParameters() { return Arrays.copyOf(parameters, parameters.length); } }