/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.ga.watchmaker.cd;
import java.io.IOException;
import java.util.List;
import com.google.common.collect.Lists;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.ga.watchmaker.cd.hadoop.CDMahoutEvaluator;
import org.apache.mahout.ga.watchmaker.cd.hadoop.DatasetSplit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.uncommons.watchmaker.framework.CandidateFactory;
import org.uncommons.watchmaker.framework.EvolutionEngine;
import org.uncommons.watchmaker.framework.EvolutionObserver;
import org.uncommons.watchmaker.framework.EvolutionaryOperator;
import org.uncommons.watchmaker.framework.FitnessEvaluator;
import org.uncommons.watchmaker.framework.PopulationData;
import org.uncommons.watchmaker.framework.SelectionStrategy;
import org.uncommons.watchmaker.framework.SequentialEvolutionEngine;
import org.uncommons.watchmaker.framework.operators.EvolutionPipeline;
import org.uncommons.watchmaker.framework.selection.RouletteWheelSelection;
import org.uncommons.watchmaker.framework.termination.GenerationCount;
/**
* Class Discovery Genetic Algorithm main class. Has the following parameters:
* <ul>
* <li>threshold<br>
* Condition activation threshold. See also {@link org.apache.mahout.ga.watchmaker.cd.CDRule CDRule}
* <li>nb cross point<br>
* Number of points used by the {@link org.apache.mahout.ga.watchmaker.cd.CDCrossover CrossOver} operator
* <li>mutation rate<br>
* mutation rate of the {@link org.apache.mahout.ga.watchmaker.cd.CDMutation Mutation} operator
* <li>mutation range<br>
* mutation range of the {@link org.apache.mahout.ga.watchmaker.cd.CDMutation Mutation} operator
* <li>mutation precision<br>
* mutation precision of the {@link org.apache.mahout.ga.watchmaker.cd.CDMutation Mutation} operator
* <li>population size
* <li>generations count<br>
* number of generations the genetic algorithm will be run for.
*
* </ul>
*/
public final class CDGA {

  private static final Logger log = LoggerFactory.getLogger(CDGA.class);

  /** Not meant to be instantiated; the class only exposes the {@link #main(String[])} entry point. */
  private CDGA() {
  }

  /**
   * Command-line entry point: parses the options, runs the genetic algorithm and logs the elapsed time.
   *
   * <p>Recognized options, in addition to the input option built by
   * {@code DefaultOptionCreator.inputOption()} and the help option:
   * <ul>
   * <li>{@code --label / -l} (required): label's index</li>
   * <li>{@code --threshold / -t}: condition activation threshold, default 0.5</li>
   * <li>{@code --crosspnts / -cp}: number of crossover points, default 1</li>
   * <li>{@code --mutrate / -m} (required): mutation rate</li>
   * <li>{@code --mutrange / -mr}: mutation range, default 0.1</li>
   * <li>{@code --mutprec / -mp}: mutation precision, default 2</li>
   * <li>{@code --popsize / -p} (required): population size</li>
   * <li>{@code --gencnt / -g} (required): generations count</li>
   * </ul>
   *
   * <p>If the options cannot be parsed, or a numeric option carries a value that is not a valid number,
   * the error is logged, the help is printed and the method returns without running the job.
   *
   * @param args command-line arguments
   * @throws IOException propagated from the job
   * @throws InterruptedException propagated from the job
   * @throws ClassNotFoundException propagated from the job
   */
  public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();

    Option inputOpt = DefaultOptionCreator.inputOption().create();

    Option labelOpt = obuilder.withLongName("label").withRequired(true).withShortName("l")
        .withArgument(abuilder.withName("index").withMinimum(1).withMaximum(1).create())
        .withDescription("label's index.").create();

    Option thresholdOpt = obuilder.withLongName("threshold").withRequired(false).withShortName("t").withArgument(
        abuilder.withName("threshold").withMinimum(1).withMaximum(1).create()).withDescription(
        "Condition activation threshold, default = 0.5.").create();

    Option crosspntsOpt = obuilder.withLongName("crosspnts").withRequired(false).withShortName("cp").withArgument(
        abuilder.withName("points").withMinimum(1).withMaximum(1).create()).withDescription(
        "Number of crossover points to use, default = 1.").create();

    // The argument placeholder used to be "true", which made the generated usage text misleading.
    Option mutrateOpt = obuilder.withLongName("mutrate").withRequired(true).withShortName("m").withArgument(
        abuilder.withName("rate").withMinimum(1).withMaximum(1).create())
        .withDescription("Mutation rate (float).").create();

    Option mutrangeOpt = obuilder.withLongName("mutrange").withRequired(false).withShortName("mr").withArgument(
        abuilder.withName("range").withMinimum(1).withMaximum(1).create())
        .withDescription("Mutation range, default = 0.1 (10%).").create();

    Option mutprecOpt = obuilder.withLongName("mutprec").withRequired(false).withShortName("mp").withArgument(
        abuilder.withName("precision").withMinimum(1).withMaximum(1).create())
        .withDescription("Mutation precision, default = 2.").create();

    Option popsizeOpt = obuilder.withLongName("popsize").withRequired(true).withShortName("p").withArgument(
        abuilder.withName("size").withMinimum(1).withMaximum(1).create()).withDescription("Population size.").create();

    Option gencntOpt = obuilder.withLongName("gencnt").withRequired(true).withShortName("g").withArgument(
        abuilder.withName("count").withMinimum(1).withMaximum(1).create())
        .withDescription("Generations count.").create();

    Option helpOpt = DefaultOptionCreator.helpOption();

    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(helpOpt).withOption(labelOpt)
        .withOption(thresholdOpt).withOption(crosspntsOpt).withOption(mutrateOpt).withOption(mutrangeOpt)
        .withOption(mutprecOpt).withOption(popsizeOpt).withOption(gencntOpt).create();

    Parser parser = new Parser();
    parser.setGroup(group);

    // Declared outside the try block so that only option handling is covered by the catch clauses;
    // a failure inside the job itself must not be reported as a command-line error.
    String dataset;
    int target;
    double threshold;
    int crosspnts;
    double mutrate;
    double mutrange;
    int mutprec;
    int popSize;
    int genCount;

    try {
      CommandLine cmdLine = parser.parse(args);

      if (cmdLine.hasOption(helpOpt)) {
        CommandLineUtil.printHelp(group);
        return;
      }

      dataset = cmdLine.getValue(inputOpt).toString();
      target = Integer.parseInt(cmdLine.getValue(labelOpt).toString());
      threshold =
          cmdLine.hasOption(thresholdOpt) ? Double.parseDouble(cmdLine.getValue(thresholdOpt).toString()) : 0.5;
      crosspnts =
          cmdLine.hasOption(crosspntsOpt) ? Integer.parseInt(cmdLine.getValue(crosspntsOpt).toString()) : 1;
      mutrate = Double.parseDouble(cmdLine.getValue(mutrateOpt).toString());
      mutrange =
          cmdLine.hasOption(mutrangeOpt) ? Double.parseDouble(cmdLine.getValue(mutrangeOpt).toString()) : 0.1;
      mutprec = cmdLine.hasOption(mutprecOpt) ? Integer.parseInt(cmdLine.getValue(mutprecOpt).toString()) : 2;
      popSize = Integer.parseInt(cmdLine.getValue(popsizeOpt).toString());
      genCount = Integer.parseInt(cmdLine.getValue(gencntOpt).toString());
    } catch (OptionException e) {
      log.error("Error while parsing options", e);
      CommandLineUtil.printHelp(group);
      return;
    } catch (NumberFormatException e) {
      // A non-numeric value for a numeric option is a usage error, handled like any other option error.
      log.error("Error while parsing options: a numeric option has an invalid value", e);
      CommandLineUtil.printHelp(group);
      return;
    }

    long start = System.currentTimeMillis();

    runJob(dataset, target, threshold, crosspnts, mutrate, mutrange, mutprec, popSize, genCount);

    long end = System.currentTimeMillis();

    printElapsedTime(end - start);
  }

  /**
   * Evolves a rule over the training part of the dataset, then logs the fitness of the evolved rule
   * over both the training and the testing parts.
   *
   * @param dataset path of the dataset, as given on the command line
   * @param target label's index
   * @param threshold condition activation threshold, handed to {@code CDFactory}
   * @param crosspnts number of crossover points, handed to {@code CDCrossover}
   * @param mutrate mutation rate, handed to {@code CDMutation}
   * @param mutrange mutation range, handed to {@code CDMutation}
   * @param mutprec mutation precision, handed to {@code CDMutation}
   * @param popSize population size handed to the evolution engine
   * @param genCount number of generations after which the evolution terminates
   */
  private static void runJob(String dataset,
                             int target,
                             double threshold,
                             int crosspnts,
                             double mutrate,
                             double mutrange,
                             int mutprec,
                             int popSize,
                             int genCount) throws IOException, InterruptedException, ClassNotFoundException {
    Path inpath = new Path(dataset);
    CDMahoutEvaluator.initializeDataSet(inpath);

    // Candidate Factory
    CandidateFactory<CDRule> factory = new CDFactory(threshold);

    // Evolution Scheme: crossover followed by mutation
    List<EvolutionaryOperator<CDRule>> operators = Lists.newArrayList();
    operators.add(new CDCrossover(crosspnts));
    operators.add(new CDMutation(mutrate, mutrange, mutprec));
    EvolutionPipeline<CDRule> pipeline = new EvolutionPipeline<CDRule>(operators);

    // 75 % of the dataset is dedicated to training
    DatasetSplit split = new DatasetSplit(0.75);

    // Fitness Evaluator (defaults to training)
    FitnessEvaluator<? super CDRule> evaluator = new CDFitnessEvaluator(dataset, target, split);

    // Selection Strategy
    SelectionStrategy<? super CDRule> selection = new RouletteWheelSelection();

    EvolutionEngine<CDRule> engine =
        new SequentialEvolutionEngine<CDRule>(factory, pipeline, evaluator, selection, RandomUtils.getRandom());

    // Report progress once per population update
    engine.addEvolutionObserver(new EvolutionObserver<CDRule>() {
      @Override
      public void populationUpdate(PopulationData<? extends CDRule> data) {
        log.info("Generation {}", data.getGenerationNumber());
      }
    });

    // evolve the rules over the training set
    // (the second argument, 1, is the elite count in Watchmaker's evolve(...) signature)
    Rule solution = engine.evolve(popSize, 1, new GenerationCount(genCount));

    Path output = new Path("output");

    // fitness over the training set
    CDFitness bestTrainFit = CDMahoutEvaluator.evaluate(solution, target, inpath, output, split);

    // fitness over the testing set: the same split object is switched from training to testing
    split.setTraining(false);
    CDFitness bestTestFit = CDMahoutEvaluator.evaluate(solution, target, inpath, output, split);

    // report the fitness of the solution over both sets
    log.info("Best solution fitness (train set) : {}", bestTrainFit);
    log.info("Best solution fitness (test set) : {}", bestTestFit);
  }

  /**
   * Logs a duration broken down into hours, minutes, seconds and milliseconds.
   *
   * @param milli duration in milliseconds
   */
  private static void printElapsedTime(long milli) {
    long seconds = milli / 1000;
    milli %= 1000;

    long minutes = seconds / 60;
    seconds %= 60;

    long hours = minutes / 60;
    minutes %= 60;

    log.info("Elapsed time (Hours:minutes:seconds:milli) : {}:{}:{}:{}", new Object[] {hours, minutes, seconds, milli});
  }
}