/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* BPMLL.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier.neural;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;
import mulan.classifier.InvalidDataException;
import mulan.classifier.MultiLabelLearnerBase;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.neural.model.ActivationTANH;
import mulan.classifier.neural.model.BasicNeuralNet;
import mulan.classifier.neural.model.NeuralNet;
import mulan.core.WekaException;
import mulan.data.DataUtils;
import mulan.data.InvalidDataFormatException;
import mulan.data.MultiLabelInstances;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
/**
* The implementation of the Back-Propagation Multi-Label Learning (BPMLL) learner.
* The learned model is stored in a {@link NeuralNet} neural network. The model of the
* learner is built by {@link BPMLLAlgorithm} from the given training data set.
*
* <!-- technical-bibtex-start -->
*
* <!-- technical-bibtex-end -->
*
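* A minimal usage sketch (file names and parameter values are placeholders;
* exception handling is omitted):
* <pre>
* MultiLabelInstances trainData = new MultiLabelInstances("train.arff", "labels.xml");
* BPMLL learner = new BPMLL(1); // seeded for reproducible weight initialization
* learner.setTrainingEpochs(50);
* learner.build(trainData);
* MultiLabelOutput prediction = learner.makePrediction(trainData.getDataSet().instance(0));
* </pre>
*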
* @author Jozef Vilcek
* @see BPMLLAlgorithm
*/
public class BPMLL extends MultiLabelLearnerBase {
private static final long serialVersionUID = 2153814250172139021L;
private static final double NET_BIAS = 1;
private static final double ERROR_SMALL_CHANGE = 0.000001;
// filter used to convert nominal input attributes into binary-numeric
private NominalToBinary nominalToBinaryFilter;
// algorithm parameters
private int epochs = 100;
private final Long randomnessSeed;
private double weightsDecayCost = 0.00001;
private double learningRate = 0.05;
private int[] hiddenLayersTopology;
// members related to normalization of attributes
private boolean normalizeAttributes = true;
private NormalizationFilter normalizer;
private NeuralNet model;
private ThresholdFunction thresholdF;
/**
* Creates a new instance of {@link BPMLL} learner.
*/
public BPMLL() {
randomnessSeed = null;
}
/**
* Creates a new instance of {@link BPMLL} learner.
* @param randomnessSeed the seed value for pseudo-random generator
*/
public BPMLL(long randomnessSeed) {
this.randomnessSeed = randomnessSeed;
}
/**
* Sets the topology of hidden layers for the neural network.
* The length of the passed array defines the number of hidden layers.
* The value at a particular index of the array defines the number of neurons in that layer.
* If <code>null</code> is specified, a default topology with a single hidden layer
* (sized to 20% of the input dimension) will be used.
* <br/>
* The network is created when the learner is being built.
* The sizes of the input and output layers are determined from the training data.
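* <br/>
* For example, a hypothetical topology with two hidden layers of 30 and 15 neurons:
* <pre>
* learner.setHiddenLayers(new int[]{30, 15});
* </pre>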
*
* @param hiddenLayers the hidden layers topology; each element is the number of neurons in one hidden layer
* @throws IllegalArgumentException if any value in the array is less than or equal to zero
*/
public void setHiddenLayers(int[] hiddenLayers) {
if (hiddenLayers != null) {
for (int value : hiddenLayers) {
if (value <= 0) {
throw new IllegalArgumentException("Invalid hidden layer topology definition. " +
"Number of neurons in hidden layer must be larger than zero.");
}
}
}
hiddenLayersTopology = hiddenLayers;
}
/**
* Gets an array defining the topology of the hidden layers of the underlying neural model.
*
* @return a copy of the hidden layers topology array, or <code>null</code> if none has been set
*/
public int[] getHiddenLayers() {
return hiddenLayersTopology == null ? hiddenLayersTopology : Arrays.copyOf(hiddenLayersTopology, hiddenLayersTopology.length);
}
/**
* Sets the learning rate. Must be greater than 0 and no more than 1.<br/>
* Default value is 0.05.
*
* @param learningRate the learning rate
* @throws IllegalArgumentException if passed value is invalid
*/
public void setLearningRate(double learningRate) {
if (learningRate <= 0 || learningRate > 1) {
throw new IllegalArgumentException("The learning rate must be greater than 0 and no more than 1. " +
"Entered value is : " + learningRate);
}
this.learningRate = learningRate;
}
/**
* Gets the learning rate. The default value is 0.05.
* @return learning rate
*/
public double getLearningRate() {
return learningRate;
}
/**
* Sets the regularization cost term for weights decay.
* Must be greater than 0 and no more than 1.<br/>
* Default value is 0.00001.
*
* @param weightsDecayCost the weights decay cost term
* @throws IllegalArgumentException if passed value is invalid
*/
public void setWeightsDecayRegularization(double weightsDecayCost) {
if (weightsDecayCost <= 0 || weightsDecayCost > 1) {
throw new IllegalArgumentException("The weights decay regularization cost " +
"term must be greater than 0 and no more than 1. " +
"The passed value is : " + weightsDecayCost);
}
this.weightsDecayCost = weightsDecayCost;
}
/**
* Gets a value of the regularization cost term for weights decay.
* @return regularization cost
*/
public double getWeightsDecayRegularization() {
return weightsDecayCost;
}
/**
* Sets the number of training epochs. Must be greater than 0.<br/>
* Default value is 100.
*
* @param epochs the number of training epochs
* @throws IllegalArgumentException if passed value is invalid
*/
public void setTrainingEpochs(int epochs) {
if (epochs <= 0) {
throw new IllegalArgumentException("The number of training epochs must be greater than zero. " +
"Entered value is : " + epochs);
}
this.epochs = epochs;
}
/**
* Gets number of training epochs.
* Default value is 100.
* @return training epochs
*/
public int getTrainingEpochs() {
return epochs;
}
/**
* Sets whether attributes of instances data (except label attributes) should
* be normalized prior to building the learner. Normalization is performed
* on numeric attributes to the range [-1, 1].<br/>
* When making a prediction, attributes of the passed input instance are also
* normalized first.<br/>
* Default is true (normalization of attributes takes place).
*
* @param normalize whether normalization of attributes should be performed
*/
public void setNormalizeAttributes(boolean normalize) {
normalizeAttributes = normalize;
}
/**
* Gets whether normalization of numeric attributes takes place.
* Default value is true.
* @return true if normalization of attributes takes place
*/
public boolean getNormalizeAttributes() {
return normalizeAttributes;
}
@Override
protected void buildInternal(final MultiLabelInstances instances) throws Exception {
// delete filter if available from previous build, a new one will be created if necessary
nominalToBinaryFilter = null;
MultiLabelInstances trainInstances = instances.clone();
List<DataPair> trainData = prepareData(trainInstances);
int inputsDim = trainData.get(0).getInput().length;
model = buildNeuralNetwork(inputsDim);
BPMLLAlgorithm learnAlg = new BPMLLAlgorithm(model, weightsDecayCost);
int numInstances = trainData.size();
double prevError = Double.MAX_VALUE;
for (int epoch = 0; epoch < epochs; epoch++) {
// shuffle the training data before each epoch (fixed seed keeps the order reproducible)
Collections.shuffle(trainData, new Random(1));
// accumulate the model error over one pass through the training data
double error = 0;
int processedInstances = 0;
for (int index = 0; index < numInstances; index++) {
DataPair trainPair = trainData.get(index);
double result = learnAlg.learn(trainPair.getInput(), trainPair.getOutput(), learningRate);
if (!Double.isNaN(result)) {
error += result;
processedInstances++;
}
}
if (getDebug()) {
if (epoch % 10 == 0) {
debug("Training epoch : " + epoch + " Model error : " + error / processedInstances);
}
}
// stop early when the per-epoch error no longer decreases by a relative margin
double errorDiff = prevError - error;
if (errorDiff <= ERROR_SMALL_CHANGE * prevError) {
if (getDebug()) {
debug("Global training error does not decrease enough. Training terminated.");
}
break;
}
prevError = error;
}
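// fit the threshold function to the trained model's outputs on the training data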
thresholdF = buildThresholdFunction(trainData);
}
@Override
public TechnicalInformation getTechnicalInformation() {
TechnicalInformation technicalInfo = new TechnicalInformation(Type.ARTICLE);
technicalInfo.setValue(Field.AUTHOR, "Zhang, M.L., Zhou, Z.H.");
technicalInfo.setValue(Field.YEAR, "2006");
technicalInfo.setValue(Field.TITLE, "Multi-label neural networks with applications to functional genomics and text categorization");
technicalInfo.setValue(Field.JOURNAL, "IEEE Transactions on Knowledge and Data Engineering");
technicalInfo.setValue(Field.VOLUME, "18");
technicalInfo.setValue(Field.PAGES, "1338-1351");
return technicalInfo;
}
private ThresholdFunction buildThresholdFunction(List<DataPair> trainData) {
int numExamples = trainData.size();
double[][] idealLabels = new double[numExamples][numLabels];
double[][] modelConfidences = new double[numExamples][numLabels];
for (int example = 0; example < numExamples; example++) {
DataPair dataPair = trainData.get(example);
idealLabels[example] = dataPair.getOutput();
modelConfidences[example] = model.feedForward(dataPair.getInput());
}
return new ThresholdFunction(idealLabels, modelConfidences);
}
private NeuralNet buildNeuralNetwork(int inputsDim) {
int[] networkTopology;
if (hiddenLayersTopology == null) {
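// default topology: a single hidden layer sized to 20% of the input dimension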
int hiddenUnits = Math.round(0.2f * inputsDim);
hiddenLayersTopology = new int[]{hiddenUnits};
networkTopology = new int[]{inputsDim, hiddenUnits, numLabels};
} else {
networkTopology = new int[hiddenLayersTopology.length + 2];
networkTopology[0] = inputsDim;
System.arraycopy(hiddenLayersTopology, 0, networkTopology, 1, hiddenLayersTopology.length);
networkTopology[networkTopology.length - 1] = numLabels;
}
NeuralNet aModel = new BasicNeuralNet(networkTopology, NET_BIAS, ActivationTANH.class, randomnessSeed == null ? null : new Random(randomnessSeed));
return aModel;
}
/**
* Prepares {@link MultiLabelInstances} data for the learning algorithm.
* <br/>
* The data are checked for correct format and label attributes
* are converted to bipolar values. Finally, {@link Instance} instances are
* converted to {@link DataPair} instances, which are used by the algorithm.
*
* @param mlData the training data set
* @return the training data as a list of {@link DataPair} instances
*/
private List<DataPair> prepareData(MultiLabelInstances mlData) {
Instances data = mlData.getDataSet();
data = checkAttributesFormat(data, mlData.getFeatureAttributes());
if (data == null) {
throw new InvalidDataException("Attributes are not in correct format. " +
"Input attributes (all but the label attributes) must be nominal or numeric.");
} else {
try {
mlData = mlData.reintegrateModifiedDataSet(data);
this.labelIndices = mlData.getLabelIndices();
} catch (InvalidDataFormatException e) {
throw new InvalidDataException("Failed to create a multilabel data set from modified instances.");
}
if (normalizeAttributes) {
normalizer = new NormalizationFilter(mlData, true, -0.8, 0.8);
}
return DataPair.createDataPairs(mlData, true);
}
}
/**
* Checks {@link Instances} data if attributes (all but the label attributes)
* are numeric or nominal. Nominal attributes are transformed to binary by use of
* {@link NominalToBinary} filter.
*
* @param dataSet instances data to be checked
* @param inputAttributes input/feature attributes whose format needs to be checked
* @return data set if it passed checks; otherwise <code>null</code>
*/
private Instances checkAttributesFormat(Instances dataSet, Set<Attribute> inputAttributes) {
StringBuilder nominalAttrRange = new StringBuilder();
String rangeDelimiter = ",";
for (Attribute attribute : inputAttributes) {
if (!attribute.isNumeric()) {
if (attribute.isNominal()) {
nominalAttrRange.append(attribute.index() + 1).append(rangeDelimiter);
} else {
// fail check if any other attribute type than nominal or numeric is used
return null;
}
}
}
// convert any nominal attributes to binary
if (nominalAttrRange.length() > 0) {
nominalAttrRange.deleteCharAt(nominalAttrRange.lastIndexOf(rangeDelimiter));
try {
nominalToBinaryFilter = new NominalToBinary();
nominalToBinaryFilter.setAttributeIndices(nominalAttrRange.toString());
nominalToBinaryFilter.setInputFormat(dataSet);
dataSet = Filter.useFilter(dataSet, nominalToBinaryFilter);
} catch (Exception exception) {
nominalToBinaryFilter = null;
if (getDebug()) {
debug("Failed to apply NominalToBinary filter to the input instances data. " +
"Error message: " + exception.getMessage());
}
throw new WekaException("Failed to apply NominalToBinary filter to the input instances data.", exception);
}
}
return dataSet;
}
@Override
public MultiLabelOutput makePredictionInternal(Instance instance) throws InvalidDataException {
Instance inputInstance = null;
if (nominalToBinaryFilter != null) {
try {
nominalToBinaryFilter.input(instance);
inputInstance = nominalToBinaryFilter.output();
inputInstance.setDataset(null);
} catch (Exception ex) {
throw new InvalidDataException("The input instance for prediction is invalid. " +
"Instance is not consistent with the data the model was built for.");
}
} else {
inputInstance = DataUtils.createInstance(instance, instance.weight(), instance.toDoubleArray());
}
int numAttributes = inputInstance.numAttributes();
if (numAttributes < model.getNetInputSize()) {
throw new InvalidDataException("Input instance do not have enough attributes " +
"to be processed by the model. Instance is not consistent with the data the model was built for.");
}
// if the instance has more attributes than the model input, we assume the label
// attributes (true outputs) are present, so we skip them
List<Integer> someLabelIndices = new ArrayList<Integer>();
boolean labelsAreThere = false;
if (numAttributes > model.getNetInputSize()) {
for (int index : this.labelIndices) {
someLabelIndices.add(index);
}
labelsAreThere = true;
}
if (normalizeAttributes) {
normalizer.normalize(inputInstance);
}
int inputDim = model.getNetInputSize();
double[] inputPattern = new double[inputDim];
int indexCounter = 0;
for (int attrIndex = 0; attrIndex < numAttributes; attrIndex++) {
if (labelsAreThere && someLabelIndices.contains(attrIndex)) {
continue;
}
inputPattern[indexCounter] = inputInstance.value(attrIndex);
indexCounter++;
}
double[] labelConfidences = model.feedForward(inputPattern);
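// the threshold learned from the training data separates relevant from irrelevant labels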
double threshold = thresholdF.computeThreshold(labelConfidences);
boolean[] labelPredictions = new boolean[numLabels];
Arrays.fill(labelPredictions, false);
for (int labelIndex = 0; labelIndex < numLabels; labelIndex++) {
if (labelConfidences[labelIndex] > threshold) {
labelPredictions[labelIndex] = true;
}
// rescale the bipolar network output (-1..1) to a confidence in 0..1
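// e.g. a tanh output of 0.6 maps to a confidence of (0.6 + 1) / 2 = 0.8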
labelConfidences[labelIndex] = (labelConfidences[labelIndex] + 1) / 2;
}
MultiLabelOutput mlo = new MultiLabelOutput(labelPredictions, labelConfidences);
return mlo;
}
}