/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * BPMLLAlgorithm.java
 * Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
 */
package mulan.classifier.neural;

import java.util.ArrayList;
import java.util.List;

import mulan.classifier.neural.model.ActivationFunction;
import mulan.classifier.neural.model.NeuralNet;
import mulan.classifier.neural.model.Neuron;

/**
 * The implementation of the Back-Propagation Multi-Label Learning (BPMLL) algorithm
 * for neural networks. The algorithm uses weights decay regularization to avoid overfitting.
 * <br/><br/>
 * For more information see:
 * <br/>
 * Zhang, M.L., Zhou, Z.H.: Multi-label neural networks with applications to functional genomics
 * and text categorization. IEEE Transactions on Knowledge and Data Engineering 18 (2006) 1338-1351
 *
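 * A minimal usage sketch (illustrative only; it assumes {@code net} is an already
 * initialized {@link NeuralNet} whose input and output dimensions match the
 * {@code inputPattern} and {@code expectedLabels} arrays, and the decay and
 * learning-rate values are placeholders):
 * <pre>
 * BPMLLAlgorithm alg = new BPMLLAlgorithm(net, 0.00001);
 * double error = alg.learn(inputPattern, expectedLabels, 0.05);
 * if (Double.isNaN(error)) {
 *     // the example had all or none of the labels assigned and was skipped
 * }
 * </pre>
 *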
 * @author Jozef Vilcek
 * @see NeuralNet
 */
public class BPMLLAlgorithm {

    private final NeuralNet neuralNet;
    private final double weightsDecayCost;

    /**
     * Creates a {@link BPMLLAlgorithm} instance.
     *
     * @param neuralNet the neural network model to learn
     * @param weightsDecayCost the weights decay cost term used for regularization.
     *                         The value must be greater than 0 and no more than 1.
     */
    public BPMLLAlgorithm(NeuralNet neuralNet, double weightsDecayCost) {
        if (neuralNet == null) {
            throw new IllegalArgumentException("The passed neural network model is null.");
        }
        if (weightsDecayCost <= 0 || weightsDecayCost > 1) {
            throw new IllegalArgumentException("The weights decay regularization cost term must be greater "
                    + "than 0 and no more than 1. The passed value is: " + weightsDecayCost);
        }
        this.neuralNet = neuralNet;
        this.weightsDecayCost = weightsDecayCost;
    }

    /**
     * Returns the neural network which is learned/updated by the algorithm.
     *
     * @return the neural network
     */
    public NeuralNet getNetwork() {
        return neuralNet;
    }

    /**
     * Returns the value of the weights decay cost term used for regularization.
     *
     * @return the weights decay cost term
     */
    public double getWeightsDecayCost() {
        return weightsDecayCost;
    }

    /**
     * Performs one learning step with the given input pattern and expected output values.
     * The function returns the error for the passed input pattern.<br/>
     * The input is ignored by the algorithm (it cannot be processed) if the input example
     * has either all or none of the labels assigned. In this case, the function returns
     * {@link Double#NaN}.
     *
     * @param inputPattern the input pattern for the network
     * @param expectedLabels the ideal, expected values the network should output as a
     *                       response for the given input. If the i-th label class belongs
     *                       to the input pattern instance, the i-th value is +1; otherwise
     *                       the value is -1.
     * @param learningRate the learning rate used to update the neural network weights
     * @return the error of the network response for the passed input
     *         or {@link Double#NaN} if the passed input cannot be processed
     */
    public double learn(double[] inputPattern, double[] expectedLabels, double learningRate) {
        if (inputPattern == null || inputPattern.length != neuralNet.getNetInputSize()) {
            throw new IllegalArgumentException("Specified input pattern vector is null "
                    + "or does not match the input dimension of the underlying neural network model.");
        }
        if (expectedLabels == null || expectedLabels.length != neuralNet.getNetOutputSize()) {
            throw new IllegalArgumentException("Specified expected labels vector is null "
                    + "or does not match the output dimension of the underlying neural network model.");
        }

        // 1. PROPAGATE SIGNAL
        double[] networkOutputs = neuralNet.feedForward(inputPattern);
        double[] outputErrors = computeErrorsForNeurons(networkOutputs, expectedLabels);
        if (outputErrors == null) {
            return Double.NaN;
        }

        double weightsSquareSum = 0;
        // 2. UPDATE WEIGHTS - error back-propagation
        int layersCount = neuralNet.getLayersCount();
        for (int layerIndex = layersCount - 1; layerIndex > 0; layerIndex--) {
            // 2a. COMPUTE ERROR TERMS
            List<Neuron> layer = neuralNet.getLayerUnits(layerIndex);
            if (layerIndex == layersCount - 1) {
                computeOutputLayerErrorTerms(layer, outputErrors);
            } else {
                List<Neuron> nextLayer = neuralNet.getLayerUnits(layerIndex + 1);
                computeHiddenLayerErrorTerms(layer, nextLayer);
            }
            // 2b. GET OUTPUTS OF THE PRECEDING LAYER (the "next" layer from the
            // back-propagation perspective), which serve as inputs to this layer
            List<Neuron> previousLayer = neuralNet.getLayerUnits(layerIndex - 1);
            int previousLayerSize = previousLayer.size();
            double[] previousLayerOut = new double[previousLayerSize];
            for (int n = 0; n < previousLayerSize; n++) {
                previousLayerOut[n] = previousLayer.get(n).getOutput();
            }
            // compute the sum of squared weights for weights decay regularization
            for (Neuron neuron : layer) {
                double[] weights = neuron.getWeights();
                for (double weight : weights) {
                    weightsSquareSum += weight * weight;
                }
            }
            // 2c. UPDATE WEIGHTS OF THE LAYER
            updateWeights(layer, previousLayerOut, learningRate);
        }

        double globalError = 0;
        for (double error : outputErrors) {
            globalError += Math.abs(error);
        }
        globalError += weightsDecayCost * 0.5 * weightsSquareSum;

        return globalError;
    }
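
    /*
     * Both learn(...) and getNetworkError(...) report the same scalar error: the sum
     * of absolute per-output-neuron errors computed by computeErrorsForNeurons(...)
     * plus the weights decay penalty,
     *
     *     E = SUM_i |E_i| + (weightsDecayCost / 2) * SUM_w w^2,
     *
     * where the second sum runs over all weights (including bias weights) of all
     * non-input layers.
     */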

    /**
     * Returns the error of the neural network for the given input. This is the value
     * of the error function computed from the network output for the given input and
     * the expected, ideal output for that input.<br/>
     * The input is ignored by the algorithm (it cannot be processed) if the input example
     * has either all or none of the labels assigned. In this case, the function returns
     * {@link Double#NaN}.
     *
     * @param inputPattern the input pattern to be processed
     * @param expectedLabels the ideal, expected values the network should output as a
     *                       response for the given input. If the i-th label class belongs
     *                       to the input pattern instance, the i-th value is +1; otherwise
     *                       the value is -1.
     * @return the error of the network response for the passed input
     *         or {@link Double#NaN} if the passed input cannot be processed
     */
    public double getNetworkError(double[] inputPattern, double[] expectedLabels) {
        double[] networkOutputs = neuralNet.feedForward(inputPattern);
        double[] outputErrors = computeErrorsForNeurons(networkOutputs, expectedLabels);
        if (outputErrors == null) {
            return Double.NaN;
        }

        // compute the sum of squared weights for weights decay regularization
        double weightsSquareSum = 0;
        int layersCount = neuralNet.getLayersCount();
        for (int layerIndex = 1; layerIndex < layersCount; layerIndex++) {
            List<Neuron> layer = neuralNet.getLayerUnits(layerIndex);
            for (Neuron neuron : layer) {
                double[] weights = neuron.getWeights();
                for (double weight : weights) {
                    weightsSquareSum += weight * weight;
                }
            }
        }

        double globalError = 0;
        for (double error : outputErrors) {
            globalError += Math.abs(error);
        }
        globalError += weightsDecayCost * 0.5 * weightsSquareSum;

        return globalError;
    }

    private void updateWeights(List<Neuron> layer, double[] layerInputs, double learningRate) {
        // gradient step with weights decay:
        // w(t+1) = w(t) + a*dw(t) - c*w(t), where dw(t) = e(t)*in(t),
        // a is the learning rate and c is the weights decay cost
        int layerSize = layer.size();
        for (int n = 0; n < layerSize; n++) {
            Neuron neuron = layer.get(n);
            double[] weights = neuron.getWeights();
            double error = neuron.getError();
            int inputsCount = layerInputs.length;
            double currentDelta = 0;
            for (int i = 0; i < inputsCount; i++) {
                currentDelta = learningRate * error * layerInputs[i];
                weights[i] += currentDelta - weightsDecayCost * weights[i];
            }
            // update the bias weight (the bias input is assumed to be +1)
            currentDelta = learningRate * error * neuron.getBiasInput();
            weights[inputsCount] += currentDelta - weightsDecayCost * weights[inputsCount];
        }
    }

    private void computeOutputLayerErrorTerms(List<Neuron> outLayer, double[] outputErrors) {
        int neuronsInLayer = outLayer.size();
        for (int n = 0; n < neuronsInLayer; n++) {
            Neuron neuron = outLayer.get(n);
            ActivationFunction layerFunction = neuron.getActivationFunction();
            double errorTerm = outputErrors[n] * layerFunction.derivative(neuron.getNeuronInput());
            neuron.setError(errorTerm);
        }
    }

    private void computeHiddenLayerErrorTerms(List<Neuron> layer, List<Neuron> nextLayer) {
        int neuronsInLayer = layer.size();
        int nextLayerNeuronsCount = nextLayer.size();
        for (int n = 0; n < neuronsInLayer; n++) {
            Neuron neuron = layer.get(n);
            double sum = 0;
            for (int k = 0; k < nextLayerNeuronsCount; k++) {
                Neuron nextNeuron = nextLayer.get(k);
                double[] nextNeuronWeights = nextNeuron.getWeights();
                sum += nextNeuron.getError() * nextNeuronWeights[n];
            }
            ActivationFunction neuronFunction = neuron.getActivationFunction();
            double errorTerm = sum * neuronFunction.derivative(neuron.getNeuronInput());
            neuron.setError(errorTerm);
        }
    }

    /**
     * Computes the error for each output neuron separately, according to the formula:
     * <pre>
     *        +- (1/(|Yi|*|Yi'|)) * SUM{exp(-(Ci - Cl))}    ... if the i-th class is in Yi
     *        |                     (l runs over Yi')           (is a label)
     * Ei = --+
     *        |
     *        +- (-1/(|Yi|*|Yi'|)) * SUM{exp(-(Ck - Ci))}   ... if the i-th class is in Yi'
     *                               (k runs over Yi)           (is not a label)
     * </pre>
     * Note that these are not the error terms used in the network weight updates.
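     * <p>
     * A small worked illustration (values assumed): for network outputs
     * {@code C = {0.8, -0.3}} and {@code expectedLabels = {+1, -1}}, the sets are
     * {@code Yi = {0}} and {@code Yi' = {1}}, so {@code |Yi|*|Yi'| = 1},
     * {@code E0 = exp(-(0.8 - (-0.3))) = exp(-1.1)} and {@code E1 = -exp(-1.1)}.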
     *
     * @param networkOutputs the output of the network, which represents the network's
     *                       belief for the labels assignment
     * @param expectedLabels the ideal, expected output for the labels assignment which
     *                       the network should produce
     * @return the error for each neuron, or null if it cannot be computed
     *         (either Yi or Yi' is an empty set)
     */
    private double[] computeErrorsForNeurons(double[] networkOutputs, double[] expectedLabels) {
        List<Integer> isLabel = new ArrayList<Integer>();
        List<Integer> isNotLabel = new ArrayList<Integer>();
        int labelsCount = expectedLabels.length;
        for (int index = 0; index < labelsCount; index++) {
            if (expectedLabels[index] == 1) {
                isLabel.add(index);
            } else {
                isNotLabel.add(index);
            }
        }

        // compute error terms for the output neurons
        double[] neuronsErrors = null;
        if (isLabel.size() != 0 && isNotLabel.size() != 0) {
            neuronsErrors = new double[labelsCount];
            for (int index = 0; index < labelsCount; index++) {
                double error = 0;
                if (isLabel.contains(index)) {
                    for (int isNotLabelIndex : isNotLabel) {
                        error += Math.exp(-(networkOutputs[index] - networkOutputs[isNotLabelIndex]));
                    }
                } else {
                    for (int isLabelIndex : isLabel) {
                        error -= Math.exp(-(networkOutputs[isLabelIndex] - networkOutputs[index]));
                    }
                }
                error *= 1.0 / (isLabel.size() * isNotLabel.size());
                neuronsErrors[index] = error;
            }
        }
        return neuronsErrors;
    }
}