/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.features.selection;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorVersion;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.Tools;
/**
 * <p>
 * Removes (un-)correlated features according to the selected filter relation. The procedure is
 * quadratic in the number of attributes. In order to get more stable results, the original,
 * random, and reverse order of attributes is available.
 * </p>
 *
 * <p>
 * Please note that this operator might fail in some cases when the attributes should be filtered
 * out, e.g. it might not be able to remove all negatively correlated features. The reason for this
 * behaviour is that the complete m x m matrix of correlations (for m attributes) is not
 * re-evaluated: as soon as one attribute of a pair has already been marked for removal, that pair
 * is skipped. That means for three attributes a1, a2, and a3 that a2 might already have been ruled
 * out by its negative correlation with a1 and is then no longer able to rule out a3.
 * </p>
 *
 * <p>
 * The used correlation function is the Pearson correlation.
 * </p>
 *
 * @author Daniel Hakenjos, Ingo Mierswa
 */
public class RemoveCorrelatedFeatures extends AbstractFeatureSelection {

	/**
	 * The parameter name for "Use this correlation for the filter relation."
	 */
	public static final String PARAMETER_CORRELATION = "correlation";

	/**
	 * The parameter name for "Removes one of two features if their correlation fulfill this
	 * relation."
	 */
	public static final String PARAMETER_FILTER_RELATION = "filter_relation";

	/**
	 * The parameter name for "The algorithm takes this attribute order to calculate correlation
	 * and filter."
	 */
	public static final String PARAMETER_ATTRIBUTE_ORDER = "attribute_order";

	/**
	 * The parameter name for "Indicates if the absolute value of the correlations should be used
	 * for comparison."
	 */
	public static final String PARAMETER_USE_ABSOLUTE_CORRELATION = "use_absolute_correlation";

	/**
	 * Compatibility level at or below which the random attribute order is drawn with
	 * {@link Math#random()} instead of the operator's {@link RandomGenerator}.
	 */
	public static final OperatorVersion VERSION_DETERMINISTIC_RANDOM_NUMBERS = new OperatorVersion(5, 2, 1);

	/** Category values of {@link #PARAMETER_ATTRIBUTE_ORDER}; indices match the ORDER_* constants. */
	private static final String[] ORDER = new String[] { "original", "random", "reverse" };

	private static final int ORDER_ORIGINAL = 0;

	private static final int ORDER_RANDOM = 1;

	private static final int ORDER_REVERSE = 2;

	/** Category values of {@link #PARAMETER_FILTER_RELATION}; indices match the relation constants. */
	private static final String[] FILTER_RELATIONS = new String[] { "greater", "greater equals", "equals", "less equals",
			"less" };

	private static final int GREATER = 0;

	private static final int GREATER_EQUALS = 1;

	private static final int EQUALS = 2;

	private static final int LESS_EQUALS = 3;

	private static final int LESS = 4;

	/** The number of removed features (for logging as value, see constructor.) */
	private int removedFeatures = 0;

	public RemoveCorrelatedFeatures(OperatorDescription description) {
		super(description);
		addValue(new ValueDouble("features_removed", "Number of removed features") {

			@Override
			public double getDoubleValue() {
				return removedFeatures;
			}
		});
	}

	/**
	 * Compares all attribute pairs in the configured order and removes the later attribute of every
	 * pair whose correlation fulfills the configured relation. Pairs involving an attribute already
	 * marked for removal are skipped.
	 *
	 * @param exampleSet
	 *            the example set to filter; its attributes are removed in place
	 * @return the same example set instance, with the marked attributes removed
	 * @throws OperatorException
	 *             if a parameter cannot be read or the process is stopped
	 */
	@Override
	public ExampleSet apply(ExampleSet exampleSet) throws OperatorException {
		getProgress().setTotal(100);
		exampleSet.recalculateAllAttributeStatistics();

		int numberOfAttributes = exampleSet.getAttributes().size();
		double[] means = new double[numberOfAttributes];
		double[] deviations = new double[numberOfAttributes];
		getProgress().setCompleted(3);

		int index = 0;
		for (Attribute attribute : exampleSet.getAttributes()) {
			means[index] = exampleSet.getStatistics(attribute, Statistics.AVERAGE);
			deviations[index] = Math.sqrt(exampleSet.getStatistics(attribute, Statistics.VARIANCE));
			index++;
		}

		double[][] columns = extractColumns(exampleSet, numberOfAttributes);
		int[] attributeIndex = createAttributeOrder(numberOfAttributes);

		boolean absolute = getParameterAsBoolean(PARAMETER_USE_ABSOLUTE_CORRELATION);
		int relation = getParameterAsInt(PARAMETER_FILTER_RELATION);
		double threshold = getParameterAsDouble(PARAMETER_CORRELATION);
		if (absolute && threshold < 0.0d) {
			// absolute correlations are never negative, so a negative threshold is mirrored
			threshold = Math.abs(threshold);
			logWarning("Correlation value is lower zero. Setting to absolute: " + threshold);
		}

		Attribute[] allAttributes = exampleSet.getAttributes().createRegularAttributeArray();
		boolean[] removeFeature = new boolean[numberOfAttributes];
		getProgress().setCompleted(5);
		for (int i = 0; i < numberOfAttributes - 1; i++) {
			if (i % 10 == 0) {
				getProgress().setCompleted(5 + (int) (i * 95L / (numberOfAttributes - 1)));
			}
			int first = attributeIndex[i];
			if (removeFeature[first]) {
				continue;
			}
			for (int j = i + 1; j < numberOfAttributes; j++) {
				int second = attributeIndex[j];
				if (removeFeature[second]) {
					continue;
				}
				double correlation = getCorrelation(columns, means, deviations, first, second);
				if (absolute) {
					correlation = Math.abs(correlation);
				}
				if (fulfillRelation(correlation, threshold, relation)) {
					// the attribute visited later in the chosen order is the one dropped
					removeFeature[second] = true;
					String firstName = allAttributes[first].getName();
					String secondName = allAttributes[second].getName();
					log("Removed Attribute : " + secondName + " --> correlation(" + firstName + "," + secondName + ")="
							+ correlation);
				}
			}
		}

		this.removedFeatures = removeMarkedAttributes(exampleSet, removeFeature);
		log("Removed " + this.removedFeatures + " features." + Tools.getLineSeparator() + "ExampleSet has now "
				+ exampleSet.getAttributes().size() + " features.");
		return exampleSet;
	}

	/**
	 * Copies the values of the attributes returned by {@code exampleSet.getAttributes()} into a
	 * matrix indexed as {@code [attribute][example]}, so the pairwise correlation loop reads each
	 * attribute's values from one contiguous array.
	 */
	private static double[][] extractColumns(ExampleSet exampleSet, int numberOfAttributes) {
		double[][] columns = new double[numberOfAttributes][exampleSet.size()];
		int attributeCounter = 0;
		for (Attribute attribute : exampleSet.getAttributes()) {
			double[] column = columns[attributeCounter++];
			int exampleCounter = 0;
			for (Example example : exampleSet) {
				column[exampleCounter++] = example.getValue(attribute);
			}
		}
		return columns;
	}

	/**
	 * Builds the order in which attributes are compared, according to
	 * {@link #PARAMETER_ATTRIBUTE_ORDER}. Unknown order values fall back to the original order.
	 *
	 * @return a permutation of {@code 0..numberOfAttributes-1}; entry {@code i} is the attribute
	 *         index visited at position {@code i}
	 */
	private int[] createAttributeOrder(int numberOfAttributes) throws OperatorException {
		int[] attributeIndex = new int[numberOfAttributes];
		int order = getParameterAsInt(PARAMETER_ATTRIBUTE_ORDER);
		if (order == ORDER_RANDOM) {
			List<Integer> remaining = new ArrayList<Integer>(numberOfAttributes);
			for (int i = 0; i < numberOfAttributes; i++) {
				remaining.add(i);
			}
			boolean legacyRandom = getCompatibilityLevel().isAtMost(VERSION_DETERMINISTIC_RANDOM_NUMBERS);
			// Obtain the generator once instead of once per draw. NOTE(review): with a local random
			// seed, getRandomGenerator(this) appears to return a freshly seeded generator on each
			// call, which previously restarted the same sequence for every draw.
			RandomGenerator randomGenerator = legacyRandom ? null : RandomGenerator.getRandomGenerator(this);
			for (int i = 0; i < numberOfAttributes; i++) {
				int pick;
				if (legacyRandom) {
					pick = (int) Math.floor(Math.random() * remaining.size());
				} else {
					pick = randomGenerator.nextInt(remaining.size());
				}
				// remove(int) removes by position and returns the stored attribute index
				attributeIndex[i] = remaining.remove(pick);
			}
		} else if (order == ORDER_REVERSE) {
			for (int i = 0; i < numberOfAttributes; i++) {
				attributeIndex[i] = numberOfAttributes - 1 - i;
			}
		} else {
			// ORDER_ORIGINAL (and any unexpected value): identity order
			for (int i = 0; i < numberOfAttributes; i++) {
				attributeIndex[i] = i;
			}
		}
		return attributeIndex;
	}

	/**
	 * Removes every attribute whose position in {@code exampleSet.getAttributes()} is flagged in
	 * {@code removeFeature}.
	 *
	 * @return the number of removed attributes
	 */
	private static int removeMarkedAttributes(ExampleSet exampleSet, boolean[] removeFeature) {
		int removed = 0;
		int index = 0;
		Iterator<Attribute> iterator = exampleSet.getAttributes().iterator();
		while (iterator.hasNext()) {
			iterator.next();
			if (removeFeature[index]) {
				iterator.remove();
				removed++;
			}
			index++;
		}
		return removed;
	}

	/**
	 * Checks whether {@code correlation} stands in the given relation to {@code threshold}. A NaN
	 * correlation fulfills no relation.
	 *
	 * @param relation
	 *            one of GREATER, GREATER_EQUALS, EQUALS, LESS_EQUALS, LESS; any other value yields
	 *            {@code false}
	 */
	private static boolean fulfillRelation(double correlation, double threshold, int relation) {
		switch (relation) {
			case GREATER:
				return correlation > threshold;
			case GREATER_EQUALS:
				return correlation >= threshold;
			case EQUALS:
				return correlation == threshold;
			case LESS_EQUALS:
				return correlation <= threshold;
			case LESS:
				return correlation < threshold;
			default:
				return false;
		}
	}

	/**
	 * Calculates the correlation between two features as covariance (with an {@code n - 1}
	 * divisor) divided by the product of the precomputed standard deviations. If that product is
	 * zero, the covariance itself is returned. NaN values (e.g. missing values, if represented as
	 * NaN) propagate into a NaN result.
	 *
	 * @param columns
	 *            attribute values indexed as {@code [attribute][example]}
	 * @param means
	 *            mean per attribute
	 * @param deviations
	 *            standard deviation per attribute
	 * @param att1
	 *            index of feature 1
	 * @param att2
	 *            index of feature 2
	 * @return the correlation, expected in [-1.0, 1.0] when the deviations are consistent with the
	 *         {@code n - 1} covariance; {@link Double#NaN} when fewer than two examples exist
	 */
	private static double getCorrelation(double[][] columns, double[] means, double[] deviations, int att1, int att2) {
		double[] values1 = columns[att1];
		double[] values2 = columns[att2];
		int n = values1.length;
		if (n < 2) {
			// correlation is undefined; NaN makes every relation fail so nothing is removed
			return Double.NaN;
		}
		double mean1 = means[att1];
		double mean2 = means[att2];

		double covariance = 0.0d;
		for (int j = 0; j < n; j++) {
			covariance += (values1[j] - mean1) * (values2[j] - mean2);
		}
		covariance = covariance / (n - 1);

		double deviationProduct = deviations[att1] * deviations[att2];
		if (deviationProduct == 0.0d) {
			return covariance;
		}
		return covariance / deviationProduct;
	}

	@Override
	public List<ParameterType> getParameterTypes() {
		List<ParameterType> types = super.getParameterTypes();
		ParameterType type = new ParameterTypeDouble(PARAMETER_CORRELATION, "Use this correlation for the filter relation.",
				-1.0d, 1.0d, 0.95d);
		type.setExpert(false);
		types.add(type);
		type = new ParameterTypeCategory(PARAMETER_FILTER_RELATION,
				"Removes one of two features if their correlation fulfill this relation.", FILTER_RELATIONS, 0);
		types.add(type);
		type = new ParameterTypeCategory(PARAMETER_ATTRIBUTE_ORDER,
				"The algorithm takes this attribute order to calculate correlation and filter.", ORDER, 0);
		types.add(type);
		type = new ParameterTypeBoolean(PARAMETER_USE_ABSOLUTE_CORRELATION,
				"Indicates if the absolute value of the correlations should be used for comparison.", true);
		types.add(type);
		types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
		return types;
	}

	/**
	 * Returns the super class's incompatible version changes extended by
	 * {@link #VERSION_DETERMINISTIC_RANDOM_NUMBERS}.
	 */
	@Override
	public OperatorVersion[] getIncompatibleVersionChanges() {
		OperatorVersion[] incompatibleVersionChanges = super.getIncompatibleVersionChanges();
		OperatorVersion[] extended = Arrays.copyOf(incompatibleVersionChanges, incompatibleVersionChanges.length + 1);
		extended[incompatibleVersionChanges.length] = VERSION_DETERMINISTIC_RANDOM_NUMBERS;
		return extended;
	}
}