package mulan.classifier.transformation;

import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.InvalidDataException;
import mulan.classifier.MultiLabelOutput;
import mulan.data.MultiLabelInstances;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.RemovePercentage;

/**
 * <!-- globalinfo-start -->
 *
 * <pre>
 * Class implementing the EPS algorithm, which constructs an ensemble of
 * Pruned Sets models, via sampling
 * </pre>
 *
 * <!-- globalinfo-end -->
 *
 * <!-- technical-bibtex-start --> BibTeX:
 *
 * <!-- technical-bibtex-end -->
 *
 * <p>Each ensemble member is a {@link PrunedSets} learner wrapping its own copy
 * of the base classifier. At prediction time the members' bipartitions are
 * counted as votes per label, the vote fraction becomes the label confidence,
 * and the configured threshold is handed to {@link MultiLabelOutput} together
 * with those confidences.
 *
 * @author Emmanouela Stachtiari
 * @author Grigorios Tsoumakas
 * @version June 6, 2010
 */
public class EnsembleOfPrunedSets extends TransformationBasedMultiLabelLearner {

    /** Threshold passed to {@link MultiLabelOutput} along with the vote-based confidences. */
    protected double threshold;
    /** Number of models that constitute the ensemble. */
    protected int numOfModels;
    /** Percentage value handed to the {@link RemovePercentage} filter when sampling training data. */
    protected double percentage;
    /** The models in the ensemble. */
    protected PrunedSets[] ensemble;
    /** Random number generator used to shuffle the data before each sample is drawn. */
    protected Random rand;

    /**
     * Creates the ensemble and instantiates every member with its own copy of
     * the base learner.
     *
     * @param aPercentage percentage of data to sample
     * @param aNumOfModels the number of models in the ensemble
     * @param aThreshold the threshold for producing bipartitions
     * @param aP pruned sets parameter p
     * @param aStrategy pruned sets strategy
     * @param aB pruned sets parameter b
     * @param baselearner the base learner
     * @throws IllegalStateException if an ensemble member cannot be created,
     *         e.g. because the base learner cannot be copied
     */
    public EnsembleOfPrunedSets(double aPercentage, int aNumOfModels, double aThreshold, int aP,
            PrunedSets.Strategy aStrategy, int aB, Classifier baselearner) {
        super(baselearner);
        numOfModels = aNumOfModels;
        threshold = aThreshold;
        percentage = aPercentage;
        ensemble = new PrunedSets[numOfModels];
        for (int i = 0; i < numOfModels; i++) {
            try {
                ensemble[i] = new PrunedSets(AbstractClassifier.makeCopy(baselearner), aP, aStrategy, aB);
            } catch (Exception ex) {
                // A swallowed failure here would leave a null slot in the ensemble and
                // surface later as a NullPointerException with no hint of the cause,
                // so report it and fail immediately instead.
                String message = "Could not create ensemble member " + i + " of " + numOfModels;
                Logger.getLogger(EnsembleOfPrunedSets.class.getName()).log(Level.SEVERE, message, ex);
                throw new IllegalStateException(message, ex);
            }
        }
        // Fixed seed, so the sequence of shuffles is the same for every new instance.
        rand = new Random(1);
    }

    /**
     * Trains every ensemble member on its own sample of the training set.
     * For each member the working copy of the data is shuffled and then passed
     * through a {@link RemovePercentage} filter with inverted selection.
     *
     * @param trainingSet the multi-label training data
     * @throws Exception if filtering the data or building a member fails
     */
    @Override
    protected void buildInternal(MultiLabelInstances trainingSet) throws Exception {
        // Work on a copy so that shuffling does not disturb the caller's data set.
        Instances dataSet = new Instances(trainingSet.getDataSet());

        for (int i = 0; i < numOfModels; i++) {
            dataSet.randomize(rand);

            RemovePercentage sampler = new RemovePercentage();
            // Weka filters expect their options to be set before setInputFormat is called.
            sampler.setPercentage(percentage);
            sampler.setInvertSelection(true);
            sampler.setInputFormat(dataSet);

            Instances trainDataSet = Filter.useFilter(dataSet, sampler);
            MultiLabelInstances train =
                    new MultiLabelInstances(trainDataSet, trainingSet.getLabelsMetaData());
            ensemble[i].build(train);
        }
    }

    /**
     * Returns an instance of a TechnicalInformation object, containing detailed
     * information about the technical background of this class, e.g., paper
     * reference or book this class is based on.
     *
     * @return the technical information about this class
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(Type.CONFERENCE);
        result.setValue(Field.AUTHOR, "Read, Jesse");
        result.setValue(Field.TITLE, "Multi-label Classification using Ensembles of Pruned Sets");
        result.setValue(Field.PAGES, "995-1000");
        result.setValue(Field.BOOKTITLE, "ICDM'08: Eighth IEEE International Conference on Data Mining");
        result.setValue(Field.YEAR, "2008");
        return result;
    }

    /**
     * Predicts the labels of an instance by voting over the ensemble members.
     * Every member contributes one vote for each label its bipartition marks as
     * relevant; a label's confidence is its vote count divided by the number of
     * models.
     *
     * @param instance the instance to classify
     * @return the output built from the per-label confidences and the threshold
     * @throws Exception if a member fails to produce a prediction
     * @throws InvalidDataException declared by the original signature
     */
    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception, InvalidDataException {
        int[] sumVotes = new int[numLabels];

        for (int i = 0; i < numOfModels; i++) {
            boolean[] bip = ensemble[i].makePrediction(instance).getBipartition();
            for (int j = 0; j < sumVotes.length; j++) {
                if (bip[j]) {
                    sumVotes[j]++;
                }
            }
        }

        double[] confidence = new double[numLabels];
        for (int j = 0; j < sumVotes.length; j++) {
            confidence[j] = (double) sumVotes[j] / (double) numOfModels;
        }

        return new MultiLabelOutput(confidence, threshold);
    }
}