/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    MMPLearner.java
 *    Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
 *
 */
package mulan.classifier.neural;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Set;

import mulan.classifier.InvalidDataException;
import mulan.classifier.MultiLabelLearnerBase;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.neural.model.ActivationLinear;
import mulan.classifier.neural.model.Neuron;
import mulan.core.ArgumentNullException;
import mulan.core.WekaException;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.loss.RankingLossFunction;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;

/**
 * Implementation of Multiclass Multilabel Perceptrons learner. For more information refer
 * to technical paper describing the learner.
 *
 * <!-- technical-bibtex-start -->
 *
 * <!-- technical-bibtex-end -->
 *
 * @author Jozef Vilcek
 */
public class MMPLearner extends MultiLabelLearnerBase {

    /** Version UID for serialization */
    private static final long serialVersionUID = 2221778416856852684L;

    /** The bias value for perceptrons */
    private static final double PERCEP_BIAS = 1;

    /**
     * List of perceptrons representing the model of the learner. One for each label.
     * They are ordered in the same sequence as labels observed from training data.
     */
    private List<Neuron> perceptrons;

    // TODO: Can not use the current normalization filter, as MMP is an incremental algorithm
    //       and so the filter must be incremental too.
    //       Investigate first, if we want to support normalization. If so, reintroduce the
    //       flag below together with its setNormalizeAttributes/getNormalizeAttributes accessors.
    // /** Determines if feature attributes has to be normalized prior to learning */
    // private boolean normalizeAttributes = true;
    /** Currently unused; reserved for the normalization support described in the TODO above. */
    private NormalizationFilter normalizer;

    /** The number of training epochs to perform with training data during the model learning / building */
    private int epochs = 1;

    /** Indicates whether any nominal attributes from input data set have to be converted to binary */
    private boolean convertNomToBin = true;

    /** Filter used for conversion of nominal attributes to binary (if enabled) */
    private NominalToBinary nomToBinFilter;

    /** The measure to be used to judge the performance of ranking when learning the model */
    private final RankingLossFunction lossFunction;

    /** The type of a model update rule used to update the model when learning from training data */
    private final MMPUpdateRuleType mmpUpdateRule;

    /**
     * The flag indicating whether the initialization of the learner with the first learning
     * data samples already took place. This is because the {@link MMPLearner} is online
     * and updatable.
     */
    private boolean isInitialized = false;

    /** Seed for the pseudo-random generator, or <code>null</code> if no seed was specified */
    private final Long randomnessSeed;

    /**
     * Creates a new instance of {@link MMPLearner}.
     *
     * @param lossMeasure the loss measure to be used when judging
     * 	ranking performance in learning process
     * @param modelUpdateRule the type of the rule used to update the model
     * 	when learning from training data
     * @throws ArgumentNullException if any of the arguments is <code>null</code>
     */
    public MMPLearner(RankingLossFunction lossMeasure, MMPUpdateRuleType modelUpdateRule) {
        this(lossMeasure, modelUpdateRule, (Long) null);
    }

    /**
     * Creates a new instance of {@link MMPLearner}.
     *
     * @param lossMeasure the loss measure to be used when judging
     * 	ranking performance in learning process
     * @param modelUpdateRule the type of the rule used to update the model
     * 	when learning from training data
     * @param randomnessSeed the seed value for pseudo-random generator
     * @throws ArgumentNullException if <code>lossMeasure</code> or
     * 	<code>modelUpdateRule</code> is <code>null</code>
     */
    public MMPLearner(RankingLossFunction lossMeasure, MMPUpdateRuleType modelUpdateRule,
            long randomnessSeed) {
        this(lossMeasure, modelUpdateRule, Long.valueOf(randomnessSeed));
    }

    /**
     * Shared constructor logic: validates the arguments and stores the configuration.
     * The boxed seed allows "no seed" to be expressed as <code>null</code>.
     */
    private MMPLearner(RankingLossFunction lossMeasure, MMPUpdateRuleType modelUpdateRule,
            Long randomnessSeed) {
        if (lossMeasure == null) {
            throw new ArgumentNullException("lossMeasure");
        }
        if (modelUpdateRule == null) {
            throw new ArgumentNullException("modelUpdateRule");
        }
        this.mmpUpdateRule = modelUpdateRule;
        this.lossFunction = lossMeasure;
        this.randomnessSeed = randomnessSeed;
    }

    /**
     * Sets the number of training epochs. Must be greater than 0.<br/>
     * Default value is 1.
     *
     * @param epochs the number of training epochs
     * @throws IllegalArgumentException if passed value is invalid
     */
    public void setTrainingEpochs(int epochs) {
        if (epochs <= 0) {
            throw new IllegalArgumentException("The number of training epochs must be greater than zero. "
                    + "Entered value is : " + epochs);
        }
        this.epochs = epochs;
    }

    /**
     * Gets number of training epochs.
     * Default value is 1.
     *
     * @return training epochs
     */
    public int getTrainingEpochs() {
        return epochs;
    }

    /**
     * Sets whether nominal attributes from input data set have to be converted to binary
     * prior to learning (and respectively making a prediction).
     *
     * @param convert flag indicating whether conversion should take place
     */
    public void setConvertNominalToBinary(boolean convert) {
        convertNomToBin = convert;
    }

    /**
     * Gets a value indicating whether conversion of nominal attributes from input data
     * set to binary takes place prior to learning (and respectively making a prediction).
     *
     * @return value indicating whether conversion takes place
     */
    public boolean getConvertNominalToBinary() {
        return convertNomToBin;
    }

    @Override
    public boolean isUpdatable() {
        return true;
    }

    /**
     * Learns from the passed training data. On the first call the perceptrons are
     * created; subsequent calls update the existing model.
     *
     * @param trainingSet the training data; it is cloned before being prepared
     * @throws InvalidDataException if the model is not initialized yet and the
     * 	training data contain no instances
     */
    @Override
    protected void buildInternal(MultiLabelInstances trainingSet) throws Exception {
        MultiLabelInstances workingCopy = trainingSet.clone();
        List<DataPair> trainData = prepareData(workingCopy);

        if (trainData.isEmpty()) {
            if (isInitialized) {
                // an empty update carries no information, the model stays as it is
                return;
            }
            // the input dimension of the model is derived from the first example
            throw new InvalidDataException("The training data set does not contain any instances, "
                    + "the model can not be initialized.");
        }

        if (!isInitialized) {
            int numFeatures = trainData.get(0).getInput().length;
            perceptrons = initializeModel(numFeatures, numLabels);
            isInitialized = true;
        }

        ModelUpdateRule modelUpdateRule = getModelUpdateRule(lossFunction);
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (DataPair dataItem : trainData) {
                modelUpdateRule.process(dataItem, null);
            }
        }
    }

    /**
     * Computes a ranking of labels for the passed instance from the outputs
     * of the perceptrons (one perceptron per label).
     *
     * @param instance the instance to make a prediction for
     * @return the output holding the ranking of labels
     * @throws InvalidDataException if the instance is not consistent with
     * 	the data the model was built for
     */
    @Override
    public MultiLabelOutput makePredictionInternal(Instance instance) throws InvalidDataException {
        double[] input = getFeatureVector(instance);

        // collect the output of each perceptron; ranking is derived from these values
        double[] labelConfidences = new double[numLabels];
        for (int index = 0; index < numLabels; index++) {
            labelConfidences[index] = perceptrons.get(index).processInput(input);
        }

        return new MultiLabelOutput(MultiLabelOutput.ranksFromValues(labelConfidences));
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInfo = new TechnicalInformation(Type.ARTICLE);
        technicalInfo.setValue(Field.AUTHOR, "Koby Crammer, Yoram Singer");
        technicalInfo.setValue(Field.YEAR, "2003");
        technicalInfo.setValue(Field.TITLE, "A Family of Additive Online Algorithms for Category Ranking.");
        technicalInfo.setValue(Field.JOURNAL, "Journal of Machine Learning Research");
        technicalInfo.setValue(Field.VOLUME, "3(6)");
        // plain ASCII hyphen: the original range separator was lost to a character encoding error
        technicalInfo.setValue(Field.PAGES, "1025-1058");
        return technicalInfo;
    }

    /**
     * Creates one linear perceptron per label.
     *
     * @param numFeatures the input dimension of each perceptron
     * @param numLabels the number of perceptrons to create
     * @return the list of newly created perceptrons
     */
    private List<Neuron> initializeModel(int numFeatures, int numLabels) {
        // when no seed was specified, null is handed over to the neurons
        Random random = randomnessSeed == null ? null : new Random(randomnessSeed);
        List<Neuron> newPerceptrons = new ArrayList<Neuron>(numLabels);
        for (int i = 0; i < numLabels; i++) {
            newPerceptrons.add(new Neuron(new ActivationLinear(), numFeatures, PERCEP_BIAS, random));
        }
        return newPerceptrons;
    }

    /**
     * Creates the model update rule matching the configured {@link MMPUpdateRuleType}.
     *
     * @param lossMeasure the loss measure handed over to the update rule
     * @return the update rule operating on the perceptrons of this learner
     * @throws IllegalArgumentException if the configured update rule type is not supported
     */
    private ModelUpdateRule getModelUpdateRule(RankingLossFunction lossMeasure) {
        switch (mmpUpdateRule) {
            case UniformUpdate:
                return new MMPUniformUpdateRule(perceptrons, lossMeasure);
            case MaxUpdate:
                return new MMPMaxUpdateRule(perceptrons, lossMeasure);
            case RandomizedUpdate:
                return new MMPRandomizedUpdateRule(perceptrons, lossMeasure);
            default:
                throw new IllegalArgumentException(String.format(
                        "The specified model update rule '%s' is not supported.", mmpUpdateRule));
        }
    }

    /**
     * Prepares {@link MultiLabelInstances} data set for learning:<br/>
     * - feature attributes are checked for correct format (nominal or numeric)<br/>
     * - nominal feature attributes are converted to binary (if enabled)<br/>
     * - instances are converted to {@link DataPair} instances (convenience for manipulation)
     *
     * @param mlData the data set to be prepared
     * @return the list of data pairs created from the prepared data set
     * @throws InvalidDataException if a feature attribute is neither nominal nor numeric
     * @throws WekaException if the nominal to binary filter can not be created or applied
     */
    private List<DataPair> prepareData(MultiLabelInstances mlData) {
        Set<Attribute> featureAttr = mlData.getFeatureAttributes();
        String nominalAttrRange = ensureAttributesFormat(featureAttr);
        Instances dataSet = mlData.getDataSet();

        // if configured, perform conversion of nominal attributes to binary
        if (convertNomToBin && nominalAttrRange.length() > 0) {
            // the filter definition is created only once, from the first data seen
            if (!isInitialized) {
                try {
                    nomToBinFilter = new NominalToBinary();
                    nomToBinFilter.setAttributeIndices(nominalAttrRange);
                    nomToBinFilter.setInputFormat(dataSet);
                } catch (Exception exception) {
                    nomToBinFilter = null;
                    if (getDebug()) {
                        debug("Failed to create NominalToBinary filter for the input instances data. "
                                + "Error message: " + exception.getMessage());
                    }
                    throw new WekaException("Failed to create NominalToBinary filter for the input instances data.", exception);
                }
            }

            // apply nominal -> binary filter to the data
            try {
                dataSet = Filter.useFilter(dataSet, nomToBinFilter);
                mlData = mlData.reintegrateModifiedDataSet(dataSet);
                // the conversion may shift attribute positions, so refresh the label indices
                this.labelIndices = mlData.getLabelIndices();
            } catch (Exception exception) {
                if (getDebug()) {
                    debug("Failed to apply NominalToBinary filter to the input instances data. "
                            + "Error message: " + exception.getMessage());
                }
                throw new WekaException("Failed to apply NominalToBinary filter to the input instances data.", exception);
            }
        }

        return DataPair.createDataPairs(mlData, false);
    }

    /**
     * Ensures that all attributes are nominal or numeric. In case they are not,
     * exception is thrown.
     *
     * @param attributes attributes to be checked
     * @return the string with comma separated (1-based) indices of nominal attributes,
     * 	which can be used for nominal to binary transformation of attributes;
     * 	empty string if there are no nominal attributes
     * @throws InvalidDataException if any attribute is neither nominal nor numeric
     */
    private String ensureAttributesFormat(Set<Attribute> attributes) {
        // TODO: where should the check take place ... should be general and
        //       use declaratively "capabilities" similar to weka
        final String rangeDelimiter = ",";
        StringBuilder nominalAttrRange = new StringBuilder();
        for (Attribute attribute : attributes) {
            if (attribute.isNumeric()) {
                continue;
            }
            if (!attribute.isNominal()) {
                // fail the check if any other attribute type than nominal or numeric is used
                throw new InvalidDataException("Unsupported type of attribute '" + attribute.name()
                        + "'. Only nominal or numeric feature attributes are supported.");
            }
            if (nominalAttrRange.length() > 0) {
                nominalAttrRange.append(rangeDelimiter);
            }
            // attribute ranges of weka filters are 1-based
            nominalAttrRange.append(attribute.index() + 1);
        }
        return nominalAttrRange.toString();
    }

    /**
     * Converts the passed instance to the input vector of the model. If the nominal to
     * binary filter is in use, the instance is passed through it first. If the instance
     * has more attributes than the model input, label attributes are assumed to be
     * present and are skipped.
     *
     * @param inputInstance the instance to be converted
     * @return the feature values of the instance, in attribute order
     * @throws InvalidDataException if the instance is not consistent with the data
     * 	the model was built for
     */
    private double[] getFeatureVector(Instance inputInstance) {
        if (convertNomToBin && nomToBinFilter != null) {
            try {
                nomToBinFilter.input(inputInstance);
                inputInstance = nomToBinFilter.output();
                inputInstance.setDataset(null);
            } catch (Exception ex) {
                if (getDebug()) {
                    debug("Failed to apply NominalToBinary filter to the input instance. "
                            + "Error message: " + ex.getMessage());
                }
                throw new InvalidDataException("The input instance for prediction is invalid. "
                        + "Instance is not consistent with the data the model was built for.");
            }
        }

        // check if number of attributes is at least equal to model input
        int numAttributes = inputInstance.numAttributes();
        // weights of a perceptron hold one extra element in addition to the inputs
        int modelInputDim = perceptrons.get(0).getWeights().length - 1;
        if (numAttributes < modelInputDim) {
            throw new InvalidDataException("Input instance do not have enough attributes "
                    + "to be processed by the model. Instance is not consistent with the data the model was built for.");
        }

        // if instance has more attributes than model input, we assume that true outputs
        // are there, so we mark them to be skipped
        boolean[] isLabel = new boolean[numAttributes];
        int numFeatures = numAttributes;
        if (numAttributes > modelInputDim) {
            for (int labelIndex : this.labelIndices) {
                if (labelIndex >= 0 && labelIndex < numAttributes && !isLabel[labelIndex]) {
                    isLabel[labelIndex] = true;
                    numFeatures--;
                }
            }
        }

        // guard against overrunning (or only partially filling) the input pattern
        if (numFeatures != modelInputDim) {
            throw new InvalidDataException("Input instance has " + numFeatures
                    + " feature attributes, but the model expects " + modelInputDim
                    + ". Instance is not consistent with the data the model was built for.");
        }

        double[] inputPattern = new double[modelInputDim];
        int indexCounter = 0;
        for (int attrIndex = 0; attrIndex < numAttributes; attrIndex++) {
            if (isLabel[attrIndex]) {
                continue;
            }
            inputPattern[indexCounter++] = inputInstance.value(attrIndex);
        }

        return inputPattern;
    }
}