/* * Copyright [2013-2016] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.gs; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.SortedMap; import java.util.TreeMap; import ml.shifu.shifu.core.processor.TrainModelProcessor; import ml.shifu.shifu.util.Environment; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Grid search supports for kinds of parameters set by List: [..., ..., ...]. * * <p> * {@link #getParams(int)} list real params can be set in train part. * * <p> * In {@link TrainModelProcessor} there is logic that to process all kind of hyper params. * * <p> * Only distributed model supports grid search functions. * * @author Zhang David (pengzhang@paypal.com) */ public class GridSearch { protected static final Logger LOG = LoggerFactory.getLogger(GridSearch.class); /** * Raw parameter from ModelConfig#train#params */ private final Map<String, Object> rawParams; /** * List all kinds of hyper parameters which can be used in training step. List is sorted by name + order in each * parameter. */ private List<Map<String, Object>> flattenParams = new ArrayList<Map<String, Object>>(); /** * How many hyper parameters */ private int hyperParamCount; /** * How many hyper parameter composite, size of {@link #flattenParams} */ private int flattenParamsCount; public GridSearch(Map<String, Object> rawParams) { assert rawParams != null; this.rawParams = rawParams; parseParams(this.rawParams); } @SuppressWarnings("rawtypes") private void parseParams(Map<String, Object> params) { // use sorted map to sort all parameters by natural order, this makes all flatten parameters sorted and fixed SortedMap<String, Object> sortedMap = new TreeMap<String, Object>(params); LOG.debug(sortedMap.toString()); List<Integer> hyperParamCntList = new ArrayList<Integer>(); Map<String, Object> normalParams = new HashMap<String, Object>(); List<Tuple> hyperParams = new ArrayList<GridSearch.Tuple>(); // stats on hyper parameters for(Entry<String, Object> entry: sortedMap.entrySet()) { if(entry.getKey().equals("ActivationFunc") || entry.getKey().equals("NumHiddenNodes")) { if(entry.getValue() instanceof List) { if(((List) (entry.getValue())).size() > 0 && ((List) (entry.getValue())).get(0) instanceof List) { // ActivationFunc and NumHiddenNodes in NN is already List, so as hyper parameter they should be // list of list. this.hyperParamCount += 1; hyperParams.add(new Tuple(entry.getKey(), entry.getValue())); hyperParamCntList.add(((List) entry.getValue()).size()); } else { // else as normal params normalParams.put(entry.getKey(), entry.getValue()); } } continue; } else if(entry.getValue() instanceof List) { this.hyperParamCount += 1; hyperParams.add(new Tuple(entry.getKey(), entry.getValue())); hyperParamCntList.add(((List) entry.getValue()).size()); } else { normalParams.put(entry.getKey(), entry.getValue()); } } // TODO parameter validation if(hasHyperParam()) { // compute all kinds hyper parameter composite and set into flatten Params // TODO, do we need a threshold like 30 since the cost of grid search is high this.flattenParamsCount = 1; for(Integer cnt: hyperParamCntList) { this.flattenParamsCount *= cnt; } // construct flatten params map for(int i = 0; i < this.flattenParamsCount; i++) { Map<String, Object> map = new HashMap<String, Object>(); int amplifier = 1; // find hyper parameters for(int j = hyperParamCntList.size() - 1; j >= 0; j--) { int currParamCnt = hyperParamCntList.get(j); Tuple tuple = hyperParams.get(j); Object value = ((List) (tuple.value)).get(i / amplifier % currParamCnt); map.put(tuple.key, value); amplifier *= currParamCnt; } // put normal parameters for(Entry<String, Object> entry: normalParams.entrySet()) { map.put(entry.getKey(), entry.getValue()); } this.flattenParams.add(map); } // random search if over threshold int threshold = Environment.getInt("shifu.gridsearch.threshold", 30); if(this.flattenParamsCount > threshold) { // set random search size is threshold LOG.info("Grid search numer is over threshold {}, leverage randomize search.", threshold); this.flattenParamsCount = threshold; List<Map<String, Object>> oldFlattenParams = this.flattenParams; this.flattenParams = new ArrayList<Map<String, Object>>(threshold); // just to select fixed number of elements, not random to make it can be called twice and return the // same result; int mod = oldFlattenParams.size() % threshold; int factor = oldFlattenParams.size() / threshold; for(int i = 0; i < threshold; i++) { if(i > (threshold - 1 - mod)) { this.flattenParams.add(oldFlattenParams.get((factor + 1) * i - (threshold - mod))); } else { this.flattenParams.add(oldFlattenParams.get(factor * i)); } } } } } public int hyperParamCount() { return this.hyperParamCount; } public Map<String, Object> getParams(int i) { return this.flattenParams.get(i); } public List<Map<String, Object>> getFlattenParams() { return this.flattenParams; } public boolean hasHyperParam() { return this.hyperParamCount > 0; } public boolean isGridSearchMode() { return this.hyperParamCount > 0; } private static class Tuple { public Tuple(String key, Object value) { this.key = key; this.value = value; } public String key; public Object value; } }