/* * Copyright [2013-2015] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dvarsel.wrapper; import ml.shifu.shifu.core.dvarsel.CandidatePerf; import ml.shifu.shifu.core.dvarsel.CandidatePopulation; import ml.shifu.shifu.core.dvarsel.CandidateSeed; import ml.shifu.shifu.core.dvarsel.VarSelWorkerResult; import ml.shifu.shifu.exception.ShifuErrorCode; import ml.shifu.shifu.exception.ShifuException; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.ListUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.*; import java.util.Map.Entry; public class CandidateGenerator { private static final Logger LOG = LoggerFactory.getLogger(CandidateGenerator.class); public static final String WORKER_SAMPLE_RATE = "worker_sample_rate"; public static final String POPULATION_MULTIPLY_CNT = "population_multiply_cnt"; public static final String POPULATION_LIVE_SIZE = "population_live_size"; public static final String EXPECT_VARIABLE_CNT = "expect_variable_cnt"; public static final String HYBRID_PERCENT = "hybrid_percent"; public static final String MUTATION_PERCENT = "mutation_percent"; private final int iteratorSeedCount; private final int expectVariableCount; private final int expectIterationCount; private final List<Integer> variables; private int inheritPercent; private int crossPercent; private int seedId = 1; private Random rd = new Random(System.currentTimeMillis()); public CandidateGenerator(Map<String, Object> params, List<Integer> variables) { this.expectIterationCount = (Integer) params.get(POPULATION_MULTIPLY_CNT); this.iteratorSeedCount = (Integer) params.get(POPULATION_LIVE_SIZE); if (this.iteratorSeedCount < 1) { LOG.error("Iterator seed count should be larger than 1."); throw new ShifuException(ShifuErrorCode.ERROR_SHIFU_CONFIG, "Iterator seed count should be larger than 1."); } this.expectVariableCount = (Integer) params.get(EXPECT_VARIABLE_CNT); if (this.expectVariableCount < 1) { LOG.error("Expect variable count should be larger than 1."); throw new ShifuException(ShifuErrorCode.ERROR_SHIFU_CONFIG, "Expect variable count should be larger than 1."); } this.variables = variables; this.crossPercent = (Integer) params.get(HYBRID_PERCENT); if (this.crossPercent < 0 || this.crossPercent > 100) { LOG.error("Cross percent should be larger than 0 and less than 100"); throw new ShifuException(ShifuErrorCode.ERROR_SHIFU_CONFIG, "Cross percent should be larger than 0 and less than 100."); } int mutationPercent = (Integer) params.get(MUTATION_PERCENT); if (mutationPercent < 0 || mutationPercent > 100) { LOG.error("Mutation percent should be larger than 0 and less 100"); throw new ShifuException(ShifuErrorCode.ERROR_SHIFU_CONFIG, "Mutation percent should be larger than 0 and less than 100."); } this.inheritPercent = 100 - crossPercent - mutationPercent; if (this.inheritPercent < 0 || this.inheritPercent > 100) { LOG.error("Cross percent add mutation percent should be larger than 0 and less than 100"); throw new ShifuException(ShifuErrorCode.ERROR_SHIFU_CONFIG, "Cross percent add mutation percent should be larger than 0 and less than 100."); } } public int getExpectIterationCount() { return expectIterationCount; } public CandidatePopulation initSeeds() { CandidatePopulation seeds = new CandidatePopulation(iteratorSeedCount); for (int seedIndex = 0; seedIndex < iteratorSeedCount; seedIndex++) { List<Integer> variableList = new ArrayList<Integer>(expectVariableCount); for (int varIndex = 0; varIndex < expectVariableCount; varIndex++) { variableList.add(randomVariable(rd, variableList)); } seeds.addCandidateSeed(new CandidateSeed(this.genSeedId(), variableList)); } return seeds; } private Integer randomVariable(Random random, List<Integer> variableList) { Integer variable = variables.get((int)(random.nextDouble() * (variables.size() - 1))); if (variableList.contains(variable)) { variable = randomVariable(random, variableList); } return variable; } public CandidatePopulation nextGeneration(Iterable<VarSelWorkerResult> workerResults, CandidatePopulation seeds) { if ( hasNoneResults(workerResults) ) { return seeds; } List<CandidatePerf> perfs = getIndividual(workerResults); Collections.sort(perfs, new Comparator<CandidatePerf>() { @Override public int compare(CandidatePerf cpa, CandidatePerf cpb) { return cpa.getVerror() < cpb.getVerror() ? -1 : 1; } }); for (int i = 0; i < 5; i++) { LOG.info("The error rate is {}, the best-{} seed: {} ", perfs.get(i).getVerror(), i, seeds.getSeedById(perfs.get(i).getId())); } LOG.info("Worst seed: {}", perfs.get(perfs.size() - 1).toString()); List<CandidatePerf> bestPerfs = perfs.subList(0, getLastBestIndex(perfs) + 1); List<CandidatePerf> ordinaryPerfs = perfs.subList(getLastBestIndex(perfs) + 1, getFistWorstIndex(perfs)); List<CandidatePerf> worstPerfs = perfs.subList(getFistWorstIndex(perfs), perfs.size()); List<CandidateSeed> bestSeeds = filter(seeds, bestPerfs); List<CandidateSeed> ordinarySeeds = filter(seeds, ordinaryPerfs); List<CandidateSeed> worstSeeds = filter(seeds, worstPerfs); CandidatePopulation result = new CandidatePopulation(iteratorSeedCount); result.addCandidateSeedList(inherit(bestSeeds)); result.addCandidateSeedList(hybrid(ordinarySeeds)); result.addCandidateSeedList(mutate(worstSeeds)); LOG.debug("new generation:" + result); return result; } private boolean hasNoneResults(Iterable<VarSelWorkerResult> workerResults) { for ( VarSelWorkerResult result : workerResults ) { if ( result.getSeedPerfList().size() > 0 ) { return false; } } return true; } private int getLastBestIndex(List<CandidatePerf> perfs) { return perfs.size() * inheritPercent / 100; } private int getFistWorstIndex(List<CandidatePerf> perfs) { return perfs.size() * (100 - crossPercent) / 100; } private List<CandidateSeed> inherit(List<CandidateSeed> bestSeeds) { return bestSeeds; } private List<CandidateSeed> hybrid(List<CandidateSeed> ordinarySeedList) { List<CandidateSeed> result = new ArrayList<CandidateSeed>(ordinarySeedList.size()); int childCnt = 0; while ( childCnt < ordinarySeedList.size() ) { CandidateSeed father = ordinarySeedList.get(rd.nextInt(ordinarySeedList.size())); CandidateSeed mather = ordinarySeedList.get(rd.nextInt(ordinarySeedList.size())); CandidateSeed child = hybrid(father, mather); if ( child != null ) { result.add(child); childCnt++; } } return result; } private CandidateSeed hybrid(CandidateSeed father, CandidateSeed mather) { Set<Integer> geneSet = new HashSet<Integer>(); geneSet.addAll(father.getColumnIdList()); geneSet.addAll(mather.getColumnIdList()); List<Integer> wholeGeneList = new ArrayList<Integer>(geneSet); List<Integer> indexList = new ArrayList<Integer>(wholeGeneList.size()); for ( int i = 0; i < wholeGeneList.size(); i ++ ) { indexList.add(i); } Collections.shuffle(indexList); List<Integer> childGeneList = new ArrayList<Integer>(father.getColumnIdList().size()); for ( int i = 0; i < father.getColumnIdList().size(); i ++ ) { childGeneList.add(wholeGeneList.get(indexList.get(i))); } return new CandidateSeed(this.genSeedId(), childGeneList); } private List<CandidateSeed> mutate(List<CandidateSeed> worstSeeds) { List<CandidateSeed> result = new ArrayList<CandidateSeed>(worstSeeds.size()); for ( CandidateSeed seed : worstSeeds ) { result.add(doMutation(seed)); } return result; } @SuppressWarnings("unchecked") private CandidateSeed doMutation(CandidateSeed seed) { List<Integer> geneList = new ArrayList<Integer>(); List<Integer> unselectedGeneList = ListUtils.subtract(variables, seed.getColumnIdList()); Collections.shuffle(unselectedGeneList); int replaceCnt = 0; for ( int i = 0; i < seed.getColumnIdList().size(); i ++ ) { if ( rd.nextDouble() < 0.05 ) { replaceCnt ++; } else { geneList.add(seed.getColumnIdList().get(i)); } } if ( replaceCnt > 0 ) { geneList.addAll(unselectedGeneList.subList(0, replaceCnt)); } return new CandidateSeed(this.genSeedId(), geneList); } private List<CandidateSeed> filter(CandidatePopulation seeds, final List<CandidatePerf> perfList) { List<CandidateSeed> result = new ArrayList<CandidateSeed>(perfList.size()); for ( CandidatePerf perf : perfList ) { result.add(seeds.getSeedById(perf.getId())); } return result; } private List<CandidatePerf> getIndividual(Iterable<VarSelWorkerResult> workerResults) { Map<Integer, List<Double>> errorMap = new HashMap<Integer, List<Double>>(); for (VarSelWorkerResult workerResult : workerResults) { List<CandidatePerf> seedPerfList = workerResult.getSeedPerfList(); for (CandidatePerf perf : seedPerfList) { if (!errorMap.containsKey(perf.getId())) { errorMap.put(perf.getId(), new ArrayList<Double>()); } errorMap.get(perf.getId()).add(perf.getVerror()); } } List<CandidatePerf> perfs = new ArrayList<CandidatePerf>(errorMap.size()); for (Entry<Integer, List<Double>> entry : errorMap.entrySet()) { double vError = mean(entry.getValue()); perfs.add(new CandidatePerf(entry.getKey(), vError)); } return perfs; } private double mean(List<Double> values) { if ( CollectionUtils.isEmpty(values)) { return 999.0; } double result = 0; for (Double value : values) { result += value; } return result / values.size(); } private int genSeedId() { return (seedId ++); } }