/*
* RapidMiner
*
* Copyright (C) 2001-2008 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.meta;
import java.util.Arrays;
import java.util.List;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.condition.InnerOperatorCondition;
import com.rapidminer.operator.condition.LastInnerOperatorCondition;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeNumber;
import com.rapidminer.parameter.ParameterTypeString;
/**
* Performs a feature selection guided by the AttributeWeights. Forward
* selection means that features with the highest weight-value are selected
* first (starting with an empty selection). Backward elimination means that
* features with the smallest weight value are eliminated first (starting with
* the full feature set).
*
* @author Daniel Hakenjos, Ingo Mierswa
* @version $Id: WeightOptimization.java,v 1.3 2006/04/05 08:57:26 ingomierswa
* Exp $
*/
public class WeightOptimization extends OperatorChain {

    /** The parameter name for "The parameter to set the weight value" */
    public static final String PARAMETER_PARAMETER = "parameter";

    /** The parameter name for "Forward selection or backward elimination." */
    public static final String PARAMETER_SELECTION_DIRECTION = "selection_direction";

    /** The parameter name for "The minimum difference between two weights." */
    public static final String PARAMETER_MIN_DIFF = "min_diff";

    /** The parameter name for "Number iterations without performance improvement." */
    public static final String PARAMETER_ITERATIONS_WITHOUT_IMPROVEMENT = "iterations_without_improvement";

    private static final Class<?>[] INPUT_CLASSES = { ExampleSet.class, AttributeWeights.class };

    private static final Class<?>[] OUTPUT_CLASSES = { ParameterSet.class, PerformanceVector.class, AttributeWeights.class };

    /** Display names of the selection directions; the array index is the parameter value. */
    private static final String[] DIRECTIONS = new String[] { "forward selection", "backward elimination" };

    /** Index of "forward selection" in {@link #DIRECTIONS}. */
    private static final int FORWARD_SELECTION = 0;

    /** Index of "backward elimination" in {@link #DIRECTIONS}. */
    private static final int BACKWARD_ELIMINATION = 1;

    /** The best parameter set found so far, or null before the first evaluation of a run. */
    private ParameterSet best;

    /** Absolute values of the input attribute weights, sorted ascending. */
    private double[] weights;

    /** Threshold bookkeeping of the current run; lastperf is the fitness of the last evaluation. */
    private double currentweight, lastweight, lastperf, bestweight;

    /** The operator whose parameter receives the weight threshold. */
    private Operator operator;

    /** The name of the parameter of {@link #operator} which receives the weight threshold. */
    private String parameter;

    /** The minimum difference between two consecutive weights; closer weights are skipped. */
    private double min_diff;

    /**
     * Creates the operator and registers the loggable values "performance",
     * "best_performance" and "weight".
     */
    public WeightOptimization(OperatorDescription description) {
        super(description);
        addValue(new ValueDouble("performance", "performance of the last evaluated weight") {
            public double getDoubleValue() {
                return lastperf;
            }
        });
        addValue(new ValueDouble("best_performance", "best performance") {
            public double getDoubleValue() {
                if (best != null)
                    return best.getPerformance().getMainCriterion().getAverage();
                else
                    return Double.NaN;
            }
        });
        addValue(new ValueDouble("weight", "currently used weight") {
            public double getDoubleValue() {
                return lastweight;
            }
        });
    }

    /**
     * Walks through the sorted absolute attribute weights, sets each candidate
     * as threshold on the target operator parameter, runs the inner operators
     * and keeps the threshold that delivered the best performance vector.
     * The search stops when the end of the weight array is reached or when the
     * configured number of iterations passed without an improvement.
     *
     * @return the best parameter set, its performance vector and the input
     *         attribute weights in which every weight below the best
     *         threshold has been set to zero
     * @throws OperatorException
     *             if the attribute weights are empty, if the target parameter
     *             is invalid or if the inner operators deliver no
     *             PerformanceVector
     */
    public IOObject[] apply() throws OperatorException {
        IOContainer input = getInput();
        // The example set is looked up but not used here; the former clone() result was
        // discarded anyway. NOTE(review): presumably this lookup fails early when no
        // example set is available -- confirm against IOContainer.
        input.get(ExampleSet.class);
        AttributeWeights attweights = input.get(AttributeWeights.class);

        Object[] names = attweights.getAttributeNames().toArray();
        if (names.length == 0) {
            // without this check the first array access below fails with an index error
            throw new OperatorException("Cannot optimize weights: the given AttributeWeights contain no attributes!");
        }
        weights = new double[names.length];
        for (int i = 0; i < names.length; i++) {
            weights[i] = Math.abs(attweights.getWeight((String) names[i]));
        }
        Arrays.sort(weights);

        int direction = getParameterAsInt(PARAMETER_SELECTION_DIRECTION);
        int maxIterationsWithoutImprovement = getParameterAsInt(PARAMETER_ITERATIONS_WITHOUT_IMPROVEMENT);
        getParametersToOptimize();
        operator.getParameters().setParameter("weight_relation", "greater equals");

        // forward selection starts at the largest weight, backward elimination at the smallest
        int weightindex = (direction == BACKWARD_ELIMINATION) ? 0 : weights.length - 1;
        lastweight = Double.POSITIVE_INFINITY;
        lastperf = Double.NaN;
        currentweight = weights[weightindex];
        bestweight = currentweight;
        best = null;

        int iter = 0;
        int iterationsWithoutImprovement = 0;
        while (true) {
            iter++;
            log("Iteration: " + iter);
            log("Using weight");
            // set the weight
            operator.getParameters().setParameter(parameter, Double.toString(currentweight));
            log(operator + "." + parameter + " = " + currentweight);
            log("Number attributes: " + (weights.length - weightindex));

            // apply a fresh copy of the input to the inner operators
            IOContainer container = input.copy();
            for (int i = 0; i < getNumberOfOperators(); i++) {
                container = getOperator(i).apply(container);
            }

            // get the PerformanceVector
            if (!container.contains(PerformanceVector.class)) {
                // PerformanceVector should be available --> see
                // checkIO(IOContainer);
                throw new OperatorException("Cannot find PerformanceVector!");
            }
            PerformanceVector performance = container.get(PerformanceVector.class);
            lastperf = performance.getMainCriterion().getFitness();
            log("Performance: " + performance.toResultString());

            if ((best == null) || (performance.compareTo(best.getPerformance()) > 0)) {
                String bestValue = Double.toString(currentweight);
                bestweight = currentweight;
                best = new ParameterSet(new Operator[] { operator }, new String[] { parameter }, new String[] { bestValue }, performance);
                iterationsWithoutImprovement = 0;
            } else {
                iterationsWithoutImprovement++;
            }
            if (iterationsWithoutImprovement >= maxIterationsWithoutImprovement) {
                break;
            }

            // no further weight available in the chosen direction
            if (isLastWeightIndex(direction, weightindex)) {
                inApplyLoop();
                break;
            }

            // Move to the next weight. Weights closer than min_diff to their predecessor
            // are skipped, but never beyond the end of the array in EITHER direction
            // (the former check only covered forward selection).
            do {
                weightindex = nextWeightIndex(direction, weightindex);
                lastweight = currentweight;
                currentweight = weights[weightindex];
            } while ((Math.abs(currentweight - lastweight) < min_diff) && !isLastWeightIndex(direction, weightindex));
            inApplyLoop();
        }

        // NOTE(review): the thresholds were built from ABSOLUTE weights, but the signed
        // weight is compared here, so negative weights are always set to zero when the
        // best threshold is positive. Kept as is -- confirm whether this is intended.
        for (int i = 0; i < names.length; i++) {
            double w = attweights.getWeight((String) names[i]);
            if (w < bestweight) {
                attweights.setWeight((String) names[i], 0.0d);
            }
        }
        input.remove(AttributeWeights.class);
        return new IOObject[] { best, best.getPerformance(), attweights };
    }

    /** Returns true if the index is the last one visited in the given direction. */
    private boolean isLastWeightIndex(int direction, int index) {
        if (direction == FORWARD_SELECTION) {
            return index == 0;
        }
        return index == weights.length - 1;
    }

    /** Returns the index following the given one: downwards for forward selection, upwards otherwise. */
    private int nextWeightIndex(int direction, int index) {
        return (direction == FORWARD_SELECTION) ? index - 1 : index + 1;
    }

    /** Requires that the last inner operator delivers a PerformanceVector. */
    public InnerOperatorCondition getInnerOperatorCondition() {
        return new LastInnerOperatorCondition(new Class[] { PerformanceVector.class });
    }

    /** Returns Integer.MAX_VALUE, i.e. no upper limit is imposed here. */
    public int getMaxNumberOfInnerOperators() {
        return Integer.MAX_VALUE;
    }

    /** Returns 1: at least one inner operator is needed. */
    public int getMinNumberOfInnerOperators() {
        return 1;
    }

    /** Returns the input classes: ExampleSet and AttributeWeights. */
    public Class<?>[] getInputClasses() {
        return INPUT_CLASSES;
    }

    /** Returns the output classes: ParameterSet, PerformanceVector and AttributeWeights. */
    public Class<?>[] getOutputClasses() {
        return OUTPUT_CLASSES;
    }

    /**
     * Reads the minimum weight difference and resolves the target of the
     * optimization from the parameter {@link #PARAMETER_PARAMETER}. The value
     * is split at dots; the first part is used as the operator name and the
     * second part as the name of its parameter. A third part is accepted but
     * not evaluated.
     *
     * @throws OperatorException
     *             if the value does not consist of two or three parts, if the
     *             operator or its parameter is unknown, or if the parameter
     *             type is not a ParameterTypeNumber
     */
    public void getParametersToOptimize() throws OperatorException {
        min_diff = getParameterAsDouble(PARAMETER_MIN_DIFF);

        String keyvalue = getParameterAsString(PARAMETER_PARAMETER);
        String[] parts = keyvalue.split("\\.");
        if ((parts.length < 2) || (parts.length > 3)) {
            throw new UserError(this, 907, keyvalue);
        }
        String operatorName = parts[0];
        String parameterName = parts[1];

        operator = getProcess().getOperator(operatorName);
        if (operator == null) {
            throw new UserError(this, 109, operatorName);
        }

        ParameterType targetType = operator.getParameters().getParameterType(parameterName);
        this.parameter = parameterName;
        if (targetType == null) {
            throw new UserError(this, 906, operatorName + "." + parameterName);
        }
        if (!(targetType instanceof ParameterTypeNumber)) {
            throw new UserError(this, 909, operatorName + "." + parameterName);
        }
    }

    /**
     * Adds the four non-expert parameters of this operator (target parameter,
     * selection direction, minimum weight difference and iterations without
     * improvement) to the parameter types of the super class.
     */
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();

        ParameterType type = new ParameterTypeString(PARAMETER_PARAMETER, "The parameter to set the weight value");
        type.setExpert(false);
        types.add(type);

        type = new ParameterTypeCategory(PARAMETER_SELECTION_DIRECTION, "Forward selection or backward elimination.", DIRECTIONS, 0);
        type.setExpert(false);
        types.add(type);

        type = new ParameterTypeDouble(PARAMETER_MIN_DIFF, "The minimum difference between two weights.", 0.0d, Double.POSITIVE_INFINITY, 1.0e-10);
        type.setExpert(false);
        types.add(type);

        type = new ParameterTypeInt(PARAMETER_ITERATIONS_WITHOUT_IMPROVEMENT, "Number iterations without performance improvement.", 1, Integer.MAX_VALUE, 1);
        type.setExpert(false);
        types.add(type);

        return types;
    }
}