/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* MultiLabelLearnerBase.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier;
import java.io.Serializable;
import java.util.Date;
import mulan.core.ArgumentNullException;
import mulan.data.MultiLabelInstances;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializedObject;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
/**
 * Abstract foundation shared by all multi-label learner types.
 * Supplies a default implementation of the {@link MultiLabelLearner} interface:
 * the public {@link #build(MultiLabelInstances)} and {@link #makePrediction(Instance)}
 * entry points perform argument and state checks and then delegate to the
 * learner specific {@link #buildInternal(MultiLabelInstances)} and
 * {@link #makePredictionInternal(Instance)} methods.
 *
 * @author Robert Friberg
 * @author Jozef Vilcek
 * @author Grigorios Tsoumakas
 */
public abstract class MultiLabelLearnerBase
        implements TechnicalInformationHandler, MultiLabelLearner, Serializable {

    /**
     * Tracks whether a model is available. Reset at the start of every
     * {@link #build(MultiLabelInstances)} call and raised only after
     * {@link #buildInternal(MultiLabelInstances)} returns without throwing.
     */
    private boolean isModelInitialized = false;

    /**
     * The number of labels the learner can handle.
     * The value is taken from the training data when the learner is built.
     */
    protected int numLabels;

    /**
     * An array containing the indexes of the label attributes within the
     * {@link Instances} object of the training data in increasing order. The same
     * order will be followed in the arrays of predictions given by each learner
     * in the {@link MultiLabelOutput} object.
     */
    protected int[] labelIndices;

    /**
     * An array containing the indexes of the feature attributes within the
     * {@link Instances} object of the training data in increasing order.
     */
    protected int[] featureIndices;

    /** Whether {@link #debug(String)} emits its messages; off by default. */
    private boolean isDebug = false;

    /**
     * Gets the {@link TechnicalInformation} for the current learner type.
     *
     * @return technical information
     */
    public abstract TechnicalInformation getTechnicalInformation();

    /**
     * Tells whether this learner is updatable.
     *
     * @return always <code>false</code> in this base implementation
     */
    public boolean isUpdatable() {
        // By default learners are assumed not to be updatable.
        return false;
    }

    /**
     * Builds the learner model from the given training data. Records the number
     * of labels, the label indices and the feature indices of the training set
     * and then hands over to {@link #buildInternal(MultiLabelInstances)}. The model
     * is flagged as initialized only if that call completes without an exception.
     *
     * @param trainingSet the training data set, must not be <code>null</code>
     * @throws Exception if the learner model was not created successfully;
     *         an {@link ArgumentNullException} is thrown when
     *         <code>trainingSet</code> is <code>null</code>
     */
    public final void build(MultiLabelInstances trainingSet) throws Exception {
        if (null == trainingSet) {
            throw new ArgumentNullException("trainingSet");
        }

        // Invalidate any earlier model first, so a failing build leaves the
        // learner in the "not initialized" state.
        this.isModelInitialized = false;

        this.numLabels = trainingSet.getNumLabels();
        this.labelIndices = trainingSet.getLabelIndices();
        this.featureIndices = trainingSet.getFeatureIndices();

        buildInternal(trainingSet);

        this.isModelInitialized = true;
    }

    /**
     * Learner specific part of building the model from {@link MultiLabelInstances}
     * training data. Invoked by {@link #build(MultiLabelInstances)}, which applies
     * the behavior common to all learners beforehand.
     *
     * @param trainingSet the training data set
     * @throws Exception if learner model was not created successfully
     */
    protected abstract void buildInternal(MultiLabelInstances trainingSet) throws Exception;

    /**
     * Gets whether the learner's model is initialized by {@link #build(MultiLabelInstances)}.
     * {@link #makePrediction(weka.core.Instance)} consults this before predicting.
     *
     * @return <code>true</code> if the model has been initialized
     */
    protected boolean isModelInitialized() {
        return this.isModelInitialized;
    }

    /**
     * Makes a prediction for the given instance. Rejects a <code>null</code>
     * instance and an untrained model, otherwise delegates to
     * {@link #makePredictionInternal(Instance)}.
     *
     * @param instance the data instance to predict on, must not be <code>null</code>
     * @return the output produced by {@link #makePredictionInternal(Instance)}
     * @throws Exception if an error occurs while making the prediction;
     *         an {@link ArgumentNullException} is thrown when
     *         <code>instance</code> is <code>null</code>
     * @throws InvalidDataException declared for invalid instance data
     * @throws ModelInitializationException if the model has not been built yet
     */
    public final MultiLabelOutput makePrediction(Instance instance)
            throws Exception, InvalidDataException, ModelInitializationException {
        if (null == instance) {
            throw new ArgumentNullException("instance");
        }
        if (isModelInitialized()) {
            return makePredictionInternal(instance);
        }
        throw new ModelInitializationException("The model has not been trained.");
    }

    /**
     * Learner specific part of predicting on the specified data based on the trained
     * model. Invoked by {@link #makePrediction(weka.core.Instance)}, which guards
     * against a missing instance and an uninitialized model.
     *
     * @param instance the data instance to predict on
     * @return the prediction for the given instance
     * @throws Exception if an error occurs while making the prediction.
     * @throws InvalidDataException if specified instance data is invalid and can not be processed by the learner
     */
    protected abstract MultiLabelOutput makePredictionInternal(Instance instance)
            throws Exception, InvalidDataException;

    /**
     * Switches debugging mode on or off.
     *
     * @param debug <code>true</code> if debug output should be printed
     */
    public void setDebug(boolean debug) {
        this.isDebug = debug;
    }

    /**
     * Tells whether debugging is turned on.
     *
     * @return <code>true</code> if debugging output is on
     */
    public boolean getDebug() {
        return this.isDebug;
    }

    /**
     * Prints the debug message to the standard error stream, prefixed with the
     * current date, if debugging is enabled for the learner. Does nothing otherwise.
     *
     * @param msg the debug message
     */
    protected void debug(String msg) {
        if (getDebug()) {
            System.err.println(new Date() + ": " + msg);
        }
    }

    /**
     * Creates a copy of this learner by passing it through a
     * {@link SerializedObject} and reading the object back.
     *
     * @return the copy of this learner
     * @throws Exception if the copy can not be created
     */
    public MultiLabelLearner makeCopy() throws Exception {
        SerializedObject snapshot = new SerializedObject(this);
        return (MultiLabelLearner) snapshot.getObject();
    }
}