/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* MultiLabelKNN.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier.lazy;
import mulan.classifier.MultiLabelLearnerBase;
import mulan.data.MultiLabelInstances;
import weka.core.DistanceFunction;
import weka.core.EuclideanDistance;
import weka.core.Instances;
import weka.core.neighboursearch.LinearNNSearch;
/**
* Superclass of all KNN based multi-label algorithms
*
* @author Eleftherios Spyromitros-Xioufis ( espyromi@csd.auth.gr )
*
*/
@SuppressWarnings("serial")
public abstract class MultiLabelKNN extends MultiLabelLearnerBase {
/** Whether the neighbors should be distance-weighted. */
protected int distanceWeighting;
/** no weighting. */
public static final int WEIGHT_NONE = 1;
/** weight by 1/distance. */
public static final int WEIGHT_INVERSE = 2;
/** weight by 1-distance. */
public static final int WEIGHT_SIMILARITY = 4;
// TODO weight each neighbor's vote according to the inverse square of its distance
/**
* Number of neighbors used in the k-nearest neighbor algorithm
*/
protected int numOfNeighbors;
/**
* Class implementing the brute force search algorithm for nearest neighbor search. Default
* value is true.
*/
protected LinearNNSearch lnn = null;
/**
* Implementing Euclidean distance (or similarity) function.
*/
protected DistanceFunction dfunc = null;
public void setDfunc(DistanceFunction dfunc) {
this.dfunc = dfunc;
}
/**
* The training instances
*/
protected Instances train;
/**
* The default constructor
*/
public MultiLabelKNN() {
this.numOfNeighbors = 10;
dfunc = new EuclideanDistance();
}
/**
* Initializes the number of neighbors
*
* @param numOfNeighbors the number of neighbors
*/
public MultiLabelKNN(int numOfNeighbors) {
this.numOfNeighbors = numOfNeighbors;
dfunc = new EuclideanDistance();
}
protected void buildInternal(MultiLabelInstances trainSet) throws Exception {
train = new Instances(trainSet.getDataSet());
// label attributes don't influence distance estimation
String labelIndicesString = "";
for (int i = 0; i < numLabels - 1; i++) {
labelIndicesString += (labelIndices[i] + 1) + ",";
}
labelIndicesString += (labelIndices[numLabels - 1] + 1);
dfunc.setAttributeIndices(labelIndicesString);
dfunc.setInvertSelection(true);
lnn = new LinearNNSearch();
lnn.setDistanceFunction(dfunc);
lnn.setInstances(train);
lnn.setMeasurePerformance(false);
}
@Override
public boolean isUpdatable() {
return true;
}
/**
* @param distanceWeighting the distanceWeighting to set
*/
public void setDistanceWeighting(int distanceWeighting) {
this.distanceWeighting = distanceWeighting;
}
}