/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.features.aggregation;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import com.rapidminer.datatable.SimpleDataTable;
import com.rapidminer.datatable.SimpleDataTableRow;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.generator.AlgebraicOrGenerator;
import com.rapidminer.generator.FeatureGenerator;
import com.rapidminer.generator.MinMaxGenerator;
import com.rapidminer.operator.OperatorChain;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorVersion;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetPassThroughRule;
import com.rapidminer.operator.ports.metadata.SetRelation;
import com.rapidminer.operator.ports.metadata.SubprocessTransformRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeFile;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.conditions.EqualTypeCondition;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.Tools;
/**
 * Performs an evolutionary feature aggregation. Each base feature is only allowed to be used as
 * base feature, in one merged feature, or it may not be used at all.
 *
 * <p>
 * The operator repeatedly applies crossover, mutation, evaluation (by executing the inner
 * subprocess) and selection to a population of {@link AggregationIndividual}s until the maximum
 * number of generations is reached. The best individual of the final population is then delivered
 * as example set together with its performance vector.
 * </p>
 *
 * @author Ingo Mierswa
 */
public class EvolutionaryFeatureAggregation extends OperatorChain {

    public static final String PARAMETER_POPULATION_CRITERIA_DATA_FILE = "population_criteria_data_file";

    public static final String PARAMETER_AGGREGATION_FUNCTION = "aggregation_function";

    public static final String PARAMETER_POPULATION_SIZE = "population_size";

    public static final String PARAMETER_MAXIMUM_NUMBER_OF_GENERATIONS = "maximum_number_of_generations";

    public static final String PARAMETER_SELECTION_TYPE = "selection_type";

    public static final String PARAMETER_TOURNAMENT_FRACTION = "tournament_fraction";

    public static final String PARAMETER_CROSSOVER_TYPE = "crossover_type";

    public static final String PARAMETER_P_CROSSOVER = "p_crossover";

    /** The names for the selection types. */
    private static final String[] SELECTION_TYPES = { "tournament", "non-dominated" };

    /** Indicates tournament selection. */
    private static final int SELECTION_TOURNAMENT = 0;

    /** Indicates NSGA-II selection. */
    private static final int SELECTION_MO = 1;

    /** The names for the aggregation functions. */
    private static final String[] AGGREGATION_FUNCTIONS = { "maximum", "algebraic_or" };

    /** Indicates the maximum aggregation function. */
    private static final int AGGREGATION_MAX = 0;

    /** Indicates the algebraic OR aggregation function. */
    private static final int AGGREGATION_ALGEBRAIC = 1;

    /** Compatibility level up to which one additional generation is performed. */
    private static final OperatorVersion CHANGE_7_3_1_NUMBER_OF_GENERATIONS = new OperatorVersion(7, 3, 1);

    /** The original attributes. */
    private Attribute[] allAttributes;

    /** The used feature generator. */
    private FeatureGenerator generator = new MinMaxGenerator(MinMaxGenerator.MAX);

    /** The current generation. */
    private int generation = 0;

    /** The maximum generation. */
    private int maxGeneration = 100;

    private final InputPort exampleSetInput = getInputPorts().createPort("example set in", ExampleSet.class);
    private final OutputPort innerExampleSetSource = getSubprocess(0).getInnerSources().createPort("example set source");
    private final InputPort innerPerformanceSink = getSubprocess(0).getInnerSinks().createPort("performance vector sink",
            PerformanceVector.class);
    private final OutputPort exampleSetOutput = getOutputPorts().createPort("example set out");
    private final OutputPort performanceOutput = getOutputPorts().createPort("performance vector out");

    /**
     * Creates a new evolutionary feature aggregation algorithm and registers its meta data
     * transformation rules.
     *
     * @param description
     *            the operator description passed on to {@link OperatorChain}
     */
    public EvolutionaryFeatureAggregation(OperatorDescription description) {
        super(description, "Performance Evaluation");

        // the same meta data rule applies to the outer output and to the inner source
        getTransformer().addRule(createNumericalRelaxingRule(exampleSetInput, exampleSetOutput));
        getTransformer().addRule(createNumericalRelaxingRule(exampleSetInput, innerExampleSetSource));
        getTransformer().addRule(new SubprocessTransformRule(getSubprocess(0)));
        getTransformer().addPassThroughRule(innerPerformanceSink, performanceOutput);
    }

    /**
     * Creates a pass-through rule with {@link SetRelation#SUBSET} that additionally sets the value
     * set relation of every numerical attribute to {@link SetRelation#UNKNOWN}.
     *
     * @param source
     *            the port the example set meta data is read from
     * @param target
     *            the port the modified meta data is passed to
     * @return the new rule
     */
    private static ExampleSetPassThroughRule createNumericalRelaxingRule(InputPort source, OutputPort target) {
        return new ExampleSetPassThroughRule(source, target, SetRelation.SUBSET) {

            @Override
            public ExampleSetMetaData modifyExampleSet(ExampleSetMetaData metaData) {
                for (AttributeMetaData amd : metaData.getAllAttributes()) {
                    if (amd.isNumerical()) {
                        amd.setValueSetRelation(SetRelation.UNKNOWN);
                    }
                }
                return metaData;
            }
        };
    }

    /**
     * Runs the evolutionary optimization and delivers the example set created from the best
     * individual of the final population together with its performance vector.
     *
     * @throws OperatorException
     *             if a parameter cannot be read, the inner subprocess fails, the criteria data
     *             file cannot be written, an unknown selection type is configured, or no individual
     *             is left in the final population
     */
    @Override
    public void doWork() throws OperatorException {
        // init
        ExampleSet exampleSet = exampleSetInput.getData(ExampleSet.class);
        int popSize = getParameterAsInt(PARAMETER_POPULATION_SIZE);

        this.generation = 0;
        this.maxGeneration = getParameterAsInt(PARAMETER_MAXIMUM_NUMBER_OF_GENERATIONS);
        if (getCompatibilityLevel().isAtMost(CHANGE_7_3_1_NUMBER_OF_GENERATIONS)) {
            // compatibility levels up to 7.3.1 perform one additional generation
            this.maxGeneration++;
        }

        int functionType = getParameterAsInt(PARAMETER_AGGREGATION_FUNCTION);
        switch (functionType) {
            case AGGREGATION_MAX:
                this.generator = new MinMaxGenerator(MinMaxGenerator.MAX);
                break;
            case AGGREGATION_ALGEBRAIC:
                this.generator = new AlgebraicOrGenerator();
                break;
            default:
                // unknown index: keep the generator that is currently set
                break;
        }

        RandomGenerator random = RandomGenerator.getRandomGenerator(this);

        this.allAttributes = new Attribute[exampleSet.getAttributes().size()];
        int index = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            allAttributes[index++] = attribute;
        }

        // crossover
        AggregationCrossover crossover = new AggregationCrossover(getParameterAsInt(PARAMETER_CROSSOVER_TYPE),
                getParameterAsDouble(PARAMETER_P_CROSSOVER), random);

        // mutation
        AggregationMutation mutation = new AggregationMutation(random);

        // selection
        int selectionType = getParameterAsInt(PARAMETER_SELECTION_TYPE);
        AggregationSelection selection;
        switch (selectionType) {
            case SELECTION_TOURNAMENT:
                selection = new AggregationTournamentSelection(popSize, getParameterAsDouble(PARAMETER_TOURNAMENT_FRACTION),
                        random);
                break;
            case SELECTION_MO:
                selection = new AggregationNonDominatedSortingSelection(popSize);
                break;
            default:
                // fail with a clear message instead of a NullPointerException inside the loop
                throw new OperatorException("Unknown selection type: " + selectionType);
        }

        // initial population
        List<AggregationIndividual> population = createInitialPopulation(popSize, exampleSet.getAttributes().size(), random);

        getProgress().setTotal(maxGeneration);

        // start optimization loop
        while (!solutionGoodEnough()) {
            getProgress().setCompleted(generation++);
            crossover.crossover(population);
            mutation.mutate(population);
            evaluate(population, exampleSet);
            selection.performSelection(population);
            inApplyLoop();
        }

        // Make sure every remaining individual carries a performance before the performances are
        // read below. This is a no-op for individuals that were already evaluated.
        evaluate(population, exampleSet);

        // evaluate(...) drops individuals without attributes, so the population may have run empty
        if (population.isEmpty()) {
            throw new OperatorException(
                    "Evolutionary feature aggregation: the final population does not contain any individual.");
        }

        // write criteria data of the final population into a file
        if (isParameterSet(PARAMETER_POPULATION_CRITERIA_DATA_FILE)) {
            File outFile = getParameterAsFile(PARAMETER_POPULATION_CRITERIA_DATA_FILE, true);
            writePopulationCriteria(population, outFile);
        }

        getProgress().complete();

        // return result: the first individual with the highest performance wins
        AggregationIndividual bestEver = null;
        PerformanceVector bestPerformance = null;
        for (AggregationIndividual current : population) {
            PerformanceVector currentPerf = current.getPerformance();
            if (bestPerformance == null || currentPerf.compareTo(bestPerformance) > 0) {
                bestPerformance = currentPerf;
                bestEver = current;
            }
        }

        exampleSetOutput.deliver(bestEver.createExampleSet(exampleSet, allAttributes, generator));
        performanceOutput.deliver(bestPerformance);
    }

    /**
     * Writes one row per individual to the given file. A row holds the fitness of each criterion
     * of the individual's performance vector and is labeled with the individual's position in the
     * population followed by the formatted fitness values.
     *
     * @param population
     *            the non-empty, completely evaluated population
     * @param outFile
     *            the file to write to
     * @throws UserError
     *             (code 303) if the file cannot be written
     */
    private void writePopulationCriteria(List<AggregationIndividual> population, File outFile) throws UserError {
        // the criteria names of the first individual define the columns of the table
        PerformanceVector prototype = population.get(0).getPerformance();
        SimpleDataTable finalStatistics = new SimpleDataTable("Population", prototype.getCriteriaNames());

        for (int i = 0; i < population.size(); i++) {
            StringBuilder id = new StringBuilder(i + " (");
            PerformanceVector current = population.get(i).getPerformance();
            double[] data = new double[current.getSize()];
            for (int d = 0; d < data.length; d++) {
                data[d] = current.getCriterion(d).getFitness();
                if (d != 0) {
                    id.append(", ");
                }
                id.append(Tools.formatNumber(data[d]));
            }
            id.append(")");
            finalStatistics.add(new SimpleDataTableRow(data, id.toString()));
        }

        try (PrintWriter out = new PrintWriter(new FileWriter(outFile))) {
            finalStatistics.write(out);
        } catch (IOException e) {
            throw new UserError(this, e, 303, new Object[] { outFile, e.getMessage() });
        }
    }

    // ================================================================================

    /**
     * Creates the start population. Each individual is an int array of length
     * {@code individualSize} whose entries are 0 if {@code random.nextBoolean()} returned
     * {@code true} and 1 otherwise.
     *
     * @param popSize
     *            the number of individuals to create
     * @param individualSize
     *            the number of entries per individual
     * @param random
     *            the source of randomness
     * @return the newly created population
     */
    private List<AggregationIndividual> createInitialPopulation(int popSize, int individualSize, Random random) {
        List<AggregationIndividual> population = new ArrayList<>();
        for (int i = 0; i < popSize; i++) {
            int[] individual = new int[individualSize];
            for (int a = 0; a < individual.length; a++) {
                individual[a] = random.nextBoolean() ? 0 : 1;
            }
            population.add(new AggregationIndividual(individual));
        }
        return population;
    }

    /** Returns true if the maximum number of generations was reached. */
    private boolean solutionGoodEnough() {
        return generation >= maxGeneration;
    }

    /**
     * Creates example sets from all individuals without a performance and invokes the inner
     * operators in order to estimate their performance. Individuals whose example set does not
     * contain any attribute are removed from the population.
     *
     * @param population
     *            the population to evaluate; may shrink during this call
     * @param originalExampleSet
     *            the example set the individuals are applied to
     * @throws OperatorException
     *             if an example set cannot be created or the inner subprocess fails
     */
    public void evaluate(List<AggregationIndividual> population, ExampleSet originalExampleSet) throws OperatorException {
        Iterator<AggregationIndividual> i = population.iterator();
        while (i.hasNext()) {
            AggregationIndividual individual = i.next();
            if (individual.getPerformance() != null) {
                // already evaluated
                continue;
            }
            ExampleSet exampleSet = individual.createExampleSet(originalExampleSet, allAttributes, generator);
            if (exampleSet.getAttributes().size() == 0) {
                i.remove();
            } else {
                innerExampleSetSource.deliver(exampleSet);
                getSubprocess(0).execute();
                PerformanceVector performanceVector = innerPerformanceSink.getData(PerformanceVector.class);
                individual.setPerformance(performanceVector);
            }
        }
    }

    /**
     * Returns the parameters of the super class followed by the parameters of this operator and
     * the random generator parameters.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();

        ParameterType type = new ParameterTypeCategory(PARAMETER_AGGREGATION_FUNCTION,
                "The aggregation function which is used for feature aggregations.", AGGREGATION_FUNCTIONS, AGGREGATION_MAX);
        type.setExpert(false);
        types.add(type);

        type = new ParameterTypeInt(PARAMETER_POPULATION_SIZE, "Number of individuals per generation.", 1, Integer.MAX_VALUE,
                10);
        type.setExpert(false);
        types.add(type);

        type = new ParameterTypeInt(PARAMETER_MAXIMUM_NUMBER_OF_GENERATIONS,
                "Number of generations after which to terminate the algorithm.", 1, Integer.MAX_VALUE, 100);
        type.setExpert(false);
        types.add(type);

        type = new ParameterTypeCategory(PARAMETER_SELECTION_TYPE, "The type of selection.", SELECTION_TYPES,
                SELECTION_TOURNAMENT);
        types.add(type);

        // the tournament fraction depends on tournament selection being chosen
        type = new ParameterTypeDouble(PARAMETER_TOURNAMENT_FRACTION,
                "The fraction of the population which will participate in each tournament.", 0.0d, 1.0d, 0.2d);
        type.registerDependencyCondition(
                new EqualTypeCondition(this, PARAMETER_SELECTION_TYPE, SELECTION_TYPES, false, SELECTION_TOURNAMENT));
        types.add(type);

        types.add(new ParameterTypeCategory(PARAMETER_CROSSOVER_TYPE, "The type of crossover.",
                AggregationCrossover.CROSSOVER_TYPES, AggregationCrossover.CROSSOVER_UNIFORM));
        types.add(new ParameterTypeDouble(PARAMETER_P_CROSSOVER,
                "Probability for an individual to be selected for crossover.", 0.0d, 1.0d, 0.9d));
        types.add(new ParameterTypeFile(PARAMETER_POPULATION_CRITERIA_DATA_FILE,
                "The path to the file in which the criteria data of the final population should be saved.", "crit", true));

        types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return types;
    }

    /**
     * Returns the incompatible version changes of the super class with
     * {@link #CHANGE_7_3_1_NUMBER_OF_GENERATIONS} appended.
     */
    @Override
    public OperatorVersion[] getIncompatibleVersionChanges() {
        OperatorVersion[] incompatibleVersionChanges = super.getIncompatibleVersionChanges();
        OperatorVersion[] newIncompatibleVersionChanges = Arrays.copyOf(incompatibleVersionChanges,
                incompatibleVersionChanges.length + 1);
        newIncompatibleVersionChanges[newIncompatibleVersionChanges.length - 1] = CHANGE_7_3_1_NUMBER_OF_GENERATIONS;
        return newIncompatibleVersionChanges;
    }
}