/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.container.meta;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import ml.shifu.shifu.container.obj.EvalConfig;
import ml.shifu.shifu.container.obj.ModelBasicConf;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.ModelNormalizeConf;
import ml.shifu.shifu.container.obj.ModelSourceDataConf;
import ml.shifu.shifu.container.obj.ModelStatsConf;
import ml.shifu.shifu.container.obj.ModelTrainConf;
import ml.shifu.shifu.container.obj.ModelVarSelectConf;
import ml.shifu.shifu.core.dtrain.gs.GridSearch;
import ml.shifu.shifu.util.Constants;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.StringUtils;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Sets;
import com.google.common.math.DoubleMath;
/**
 * MetaFactory hosts all the meta information (constraints) for ModelConfig.
 * <p>
 * The meta definitions are read once, when the class is loaded, from {@link #MODEL_META_STORE_FILE} and kept in a
 * flat map keyed by {@code group#name}; nested items append further {@code #name} segments. With that map the
 * class can validate a ModelConfig object or any of its sub-sections, and hand out a cloned copy of all meta items.
 */
public class MetaFactory {

    // flag to indicate validate result
    public static final String VALIDATE_OK = "OK";

    // key related variables
    public static final String ITEM_KEY_SEPERATOR = "#";
    public static final String DUMMY = "dummy";

    // the tags for group
    private static final String BASIC_TAG = "basic";
    private static final String DATASET_TAG = "dataSet";
    private static final String STATS_TAG = "stats";
    private static final String VARSELECT_TAG = "varSelect";
    private static final String NORMALIZE_TAG = "normalize";
    private static final String TRAIN_TAG = "train";
    private static final String EVALS_TAG = "evals";

    // values of MetaItem.type / MetaItem.elementType that the validator understands
    private static final String TYPE_TEXT = "text";
    private static final String TYPE_NUMBER = "number";
    private static final String TYPE_BOOLEAN = "boolean";
    private static final String TYPE_LIST = "list";
    private static final String TYPE_MAP = "map";
    private static final String TYPE_OBJECT = "object";

    // default MetaConfig input file
    public static final String MODEL_META_STORE_FILE = "store/ModelConfigMeta.json";

    // warehouse for meta item, keyed by 'group'#'name'[#'sub-name'...]
    private static final Map<String, MetaItem> itemsWareHouse = new HashMap<String, MetaItem>();

    /*
     * Load the MetaConfig into memory, and organize the items in a flattened format so that a 'key' can be used
     * to query a MetaItem.
     *
     * Please note, the key is composed by 'group'#'name'.
     */
    static {
        ObjectMapper jsonMapper = new ObjectMapper();
        MetaGroup[] groups = null;
        try {
            groups = jsonMapper.readValue(MetaFactory.class.getClassLoader().getResource(MODEL_META_STORE_FILE),
                    MetaGroup[].class);
        } catch (Exception e) {
            throw new RuntimeException("Fail to read model meta from " + MODEL_META_STORE_FILE, e);
        }

        if(groups != null) {
            for(MetaGroup metaGroup: groups) {
                // a group without any items contributes nothing; skip it instead of failing class loading
                if(metaGroup == null || metaGroup.getMetaList() == null) {
                    continue;
                }
                for(MetaItem metaItem: metaGroup.getMetaList()) {
                    addMetaItem(metaGroup.getGroup() + ITEM_KEY_SEPERATOR + metaItem.getName(), metaItem);
                }
            }
        }
    }

    /**
     * Get a copy of the meta warehouse. Every MetaItem is cloned into a new map, so that changes made by the
     * caller to the returned map do not touch the warehouse map itself.
     *
     * @return a stand-alone map holding a clone of each MetaItem
     */
    public static Map<String, MetaItem> getModelConfigMeta() {
        Map<String, MetaItem> retMap = new HashMap<String, MetaItem>();
        for(Entry<String, MetaItem> entry: itemsWareHouse.entrySet()) {
            retMap.put(entry.getKey(), entry.getValue().clone());
        }
        return retMap;
    }

    /**
     * Validate the ModelConfig object, to make sure each item follows its constraint.
     * <p>
     * Each non-static, non-synthetic declared field (except fields named log/logger) is read through its getter
     * and checked; the field name is used as the group tag. List values are checked element by element.
     * The train section must be present, because its params decide whether grid search is enabled.
     *
     * @param modelConfig
     *            - object to validate
     * @return ValidateResult
     *         If all items are OK, the ValidateResult.status will be true;
     *         Or the ValidateResult.status will be false, ValidateResult.causes will contain the reasons
     * @throws Exception
     *             any exception in validation
     */
    public static ValidateResult validate(ModelConfig modelConfig) throws Exception {
        ValidateResult result = new ValidateResult(true);

        GridSearch gs = new GridSearch(modelConfig.getTrain().getParams());
        boolean isGridSearch = gs.hasHyperParam();

        for(Field field: modelConfig.getClass().getDeclaredFields()) {
            String fieldName = field.getName();
            // skip log instance
            if(fieldName.equalsIgnoreCase("log") || fieldName.equalsIgnoreCase("logger")) {
                continue;
            }
            // synthetic and static fields are not configuration properties and need not have a getter
            if(field.isSynthetic() || Modifier.isStatic(field.getModifiers())) {
                continue;
            }

            Object value = readProperty(modelConfig, field);
            if(value instanceof List) {
                for(Object obj: (List<?>) value) {
                    encapsulateResult(result, iterateCheck(isGridSearch, fieldName, obj));
                }
            } else {
                encapsulateResult(result, iterateCheck(isGridSearch, fieldName, value));
            }
        }
        return result;
    }

    /**
     * Validate the ModelBasicConf object, to make sure each item follows its constraint.
     *
     * @param basic
     *            - object to validate
     * @return ValidateResult
     *         If all items are OK, the ValidateResult.status will be true;
     *         Or the ValidateResult.status will be false, ValidateResult.causes will contain the reasons
     * @throws Exception
     *             any exception in validation
     */
    public static ValidateResult validate(ModelBasicConf basic) throws Exception {
        return iterateCheck(false, BASIC_TAG, basic);
    }

    /**
     * Validate the ModelSourceDataConf object, to make sure each item follows its constraint.
     *
     * @param sourceData
     *            - object to validate
     * @return ValidateResult
     *         If all items are OK, the ValidateResult.status will be true;
     *         Or the ValidateResult.status will be false, ValidateResult.causes will contain the reasons
     * @throws Exception
     *             any exception in validation
     */
    public static ValidateResult validate(ModelSourceDataConf sourceData) throws Exception {
        return iterateCheck(false, DATASET_TAG, sourceData);
    }

    /**
     * Validate the ModelStatsConf object, to make sure each item follows its constraint.
     *
     * @param stats
     *            - object to validate
     * @return ValidateResult
     *         If all items are OK, the ValidateResult.status will be true;
     *         Or the ValidateResult.status will be false, ValidateResult.causes will contain the reasons
     * @throws Exception
     *             any exception in validation
     */
    public static ValidateResult validate(ModelStatsConf stats) throws Exception {
        return iterateCheck(false, STATS_TAG, stats);
    }

    /**
     * Validate the ModelVarSelectConf object, to make sure each item follows its constraint.
     *
     * @param varselect
     *            - object to validate
     * @return ValidateResult
     *         If all items are OK, the ValidateResult.status will be true;
     *         Or the ValidateResult.status will be false, ValidateResult.causes will contain the reasons
     * @throws Exception
     *             any exception in validation
     */
    public static ValidateResult validate(ModelVarSelectConf varselect) throws Exception {
        return iterateCheck(false, VARSELECT_TAG, varselect);
    }

    /**
     * Validate the ModelNormalizeConf object, to make sure each item follows its constraint.
     *
     * @param normalizer
     *            - object to validate
     * @return ValidateResult
     *         If all items are OK, the ValidateResult.status will be true;
     *         Or the ValidateResult.status will be false, ValidateResult.causes will contain the reasons
     * @throws Exception
     *             any exception in validation
     */
    public static ValidateResult validate(ModelNormalizeConf normalizer) throws Exception {
        return iterateCheck(false, NORMALIZE_TAG, normalizer);
    }

    /**
     * Validate the ModelTrainConf object, to make sure each item follows its constraint.
     * Grid search handling is not applied here (the check always runs with isGridSearch = false).
     *
     * @param train
     *            - object to validate
     * @return ValidateResult
     *         If all items are OK, the ValidateResult.status will be true;
     *         Or the ValidateResult.status will be false, ValidateResult.causes will contain the reasons
     * @throws Exception
     *             any exception in validation
     */
    public static ValidateResult validate(ModelTrainConf train) throws Exception {
        return iterateCheck(false, TRAIN_TAG, train);
    }

    /**
     * Validate the List(EvalConfig) object, to make sure each item follows its constraint.
     * A null list has nothing to validate and yields an OK result, the same way a null object does in
     * {@link #iterateCheck(boolean, String, Object)}.
     *
     * @param evalList
     *            - object list to validate
     * @return ValidateResult
     *         If all items are OK, the ValidateResult.status will be true;
     *         Or the ValidateResult.status will be false, ValidateResult.causes will contain the reasons
     * @throws Exception
     *             any exception in validation
     */
    public static ValidateResult validate(List<EvalConfig> evalList) throws Exception {
        ValidateResult result = new ValidateResult(true);
        if(evalList == null) {
            return result;
        }
        for(EvalConfig evalConfig: evalList) {
            // the explicit static type keeps this call bound to validate(EvalConfig) even for null elements
            encapsulateResult(result, validate(evalConfig));
        }
        return result;
    }

    /**
     * Validate the EvalConfig, to make sure each item follows its constraint.
     *
     * @param eval
     *            - object to validate
     * @return ValidateResult
     *         If all items are OK, the ValidateResult.status will be true;
     *         Or the ValidateResult.status will be false, ValidateResult.causes will contain the reasons
     * @throws Exception
     *             any exception in validation
     */
    public static ValidateResult validate(EvalConfig eval) throws Exception {
        return iterateCheck(false, EVALS_TAG, eval);
    }

    /**
     * Iterate each property of the object, get the value through its getter and validate it.
     * <p>
     * The fields declared by the object's class and by its direct super class (when that is not Object) are
     * checked; synthetic fields, static fields and fields annotated with JsonIgnore are skipped.
     *
     * @param isGridSearch
     *            - if grid search, ignore validation in train#params as they are set all as list
     * @param ptag
     *            - the prefix of key to search @MetaItem
     * @param obj
     *            - the object to validate; null is treated as valid
     * @return ValidateResult
     *         If all items are OK, the ValidateResult.status will be true;
     *         Or the ValidateResult.status will be false, ValidateResult.causes will contain the reasons
     * @throws Exception
     *             any exception in validation, e.g. a field without a public getter
     */
    public static ValidateResult iterateCheck(boolean isGridSearch, String ptag, Object obj) throws Exception {
        ValidateResult result = new ValidateResult(true);
        if(obj == null) {
            return result;
        }

        Class<?> cls = obj.getClass();
        Field[] fields = cls.getDeclaredFields();

        // getSuperclass() is null when obj is a plain java.lang.Object, so guard before comparing
        Class<?> parentCls = cls.getSuperclass();
        if(parentCls != null && !parentCls.equals(Object.class)) {
            fields = (Field[]) ArrayUtils.addAll(fields, parentCls.getDeclaredFields());
        }

        for(Field field: fields) {
            if(field.isSynthetic() || Modifier.isStatic(field.getModifiers()) || isJsonIngoreField(field)) {
                continue;
            }
            Object value = readProperty(obj, field);
            encapsulateResult(result, validate(isGridSearch, ptag + ITEM_KEY_SEPERATOR + field.getName(), value));
        }
        return result;
    }

    /**
     * Read the value of a field through its public getter ("get" + capitalized field name).
     *
     * @param target
     *            - the object owning the property
     * @param field
     *            - the field whose getter should be invoked
     * @return the value returned by the getter
     * @throws Exception
     *             if the getter doesn't exist or fails
     */
    private static Object readProperty(Object target, Field field) throws Exception {
        Method getter = target.getClass().getMethod("get" + getMethodName(field.getName()));
        return getter.invoke(target);
    }

    /**
     * Check whether the field carries the JsonIgnore annotation. The annotation is matched by class name.
     *
     * @param field
     *            - the field to inspect
     * @return true if the field is annotated with JsonIgnore
     */
    private static boolean isJsonIngoreField(Field field) {
        String jsonIgnoreName = JsonIgnore.class.getName();
        for(Annotation annotation: field.getAnnotations()) {
            if(annotation.annotationType().getName().equals(jsonIgnoreName)) {
                return true;
            }
        }
        return false;
    }

    // names of the items whose validation is skipped when grid search is enabled
    static Set<String> filterSet = Sets.newHashSet(new String[] { "NumHiddenLayers", "ActivationFunc",
            "NumHiddenNodes", "LearningRate", "DropoutRate", "RegularizedConstant", "L1orL2", "MaxDepth",
            "MinInstancesPerNode", "MinInfoGain", "MaxStatsMemoryMB", "TreeNum", "Impurity", "FeatureSubsetStrategy",
            "Loss", "LearningDecay", "Propagation", "GBTSampleWithReplacement", "Kernel", "Const", "Gamma",
            "EnableEarlyStop", "ValidationTolerance", "MaxLeaves", "MaxBatchSplitSize" });

    /**
     * Check whether the last segment of the item key (the text after the last key separator, or the whole key
     * when there is no separator) is one of the names in {@link #filterSet}.
     *
     * @param itemKey
     *            - the key to inspect
     * @return true if validation of this key should be skipped for grid search
     */
    private static boolean filterOut(String itemKey) {
        if(itemKey == null) {
            return false;
        }
        // lastIndexOf never fails on keys such as "#", where String.split would produce an empty array
        int lastSeparator = itemKey.lastIndexOf(ITEM_KEY_SEPERATOR);
        String name = (lastSeparator < 0) ? itemKey : itemKey.substring(lastSeparator + ITEM_KEY_SEPERATOR.length());
        return filterSet.contains(name);
    }

    /**
     * Validate the input value. Find the @MetaItem from warehouse, and do the validation according to the
     * type of the MetaItem (text, number, boolean, list or map). Items of any other type are accepted as is.
     *
     * @param isGridSearch
     *            - if grid search, ignore validation in train#params as they are set all as list
     * @param itemKey
     *            - the key to locate MetaItem
     * @param itemValue
     *            - the value to validate
     * @return if validate OK, return "OK"
     *         or return the cause - String
     * @throws Exception
     *             any exception in validation
     */
    public static String validate(boolean isGridSearch, String itemKey, Object itemValue) throws Exception {
        if(isGridSearch && filterOut(itemKey)) {
            return VALIDATE_OK;
        }

        MetaItem itemMeta = itemsWareHouse.get(itemKey);
        if(itemMeta == null) {
            return itemKey + " - not found meta info.";
        }

        // constant-first comparisons, so that a meta item without a type is simply not validated
        String type = itemMeta.getType();
        if(TYPE_TEXT.equals(type)) {
            return validateText(itemKey, itemMeta, itemValue);
        } else if(TYPE_NUMBER.equals(type)) {
            return validateNumber(itemKey, itemMeta, itemValue);
        } else if(TYPE_BOOLEAN.equals(type)) {
            return validateBoolean(itemKey, itemValue);
        } else if(TYPE_LIST.equals(type)) {
            return validateList(isGridSearch, itemKey, itemMeta, itemValue);
        } else if(TYPE_MAP.equals(type)) {
            return validateMap(isGridSearch, itemKey, itemMeta, itemValue);
        }
        return VALIDATE_OK;
    }

    /**
     * Check a text value against the max length, min length and option list of its meta item.
     */
    private static String validateText(String itemKey, MetaItem itemMeta, Object itemValue) {
        String value = (itemValue == null) ? null : itemValue.toString();

        if(itemMeta.getMaxLength() != null && value != null && value.length() > itemMeta.getMaxLength()) {
            return itemKey + " - the length of value exceeds the max length : " + itemMeta.getMaxLength();
        }

        if(itemMeta.getMinLength() != null) {
            if(value == null) {
                return itemKey + " - then shouldn't be null";
            }
            if(value.length() < itemMeta.getMinLength()) {
                return itemKey + " - the length of value less than min length : " + itemMeta.getMinLength();
            }
        }

        if(CollectionUtils.isNotEmpty(itemMeta.getOptions())) {
            boolean isOptionValue = false;
            for(ValueOption itemOption: itemMeta.getOptions()) {
                // compare on the string form, so that a non-String option value cannot break the check
                Object optValue = itemOption.getValue();
                if(optValue != null && optValue.toString().equalsIgnoreCase(value)) {
                    isOptionValue = true;
                    break;
                }
            }
            if(!isOptionValue) {
                return itemKey + " - the value couldn't be found in the option value list - "
                        + convertOptionIntoString(itemMeta.getOptions());
            }
        }
        return VALIDATE_OK;
    }

    /**
     * Check a numeric value: it must be parsable as a double and, if the meta item declares options,
     * equal to one of them within Constants.TOLERANCE. A null value is only rejected when options exist.
     */
    private static String validateNumber(String itemKey, MetaItem itemMeta, Object itemValue) {
        boolean hasOptions = CollectionUtils.isNotEmpty(itemMeta.getOptions());

        if(itemValue == null) {
            return hasOptions ? (itemKey + " - the value couldn't be null.") : VALIDATE_OK;
        }

        double value;
        try {
            value = Double.parseDouble(itemValue.toString());
        } catch (NumberFormatException e) {
            return itemKey + " - the value is not number format.";
        }

        if(hasOptions) {
            boolean isOptionValue = false;
            for(ValueOption itemOption: itemMeta.getOptions()) {
                Object rawOption = itemOption.getValue();
                if(rawOption == null) {
                    continue;
                }
                double optValue = Double.parseDouble(rawOption.toString());
                if(DoubleMath.fuzzyEquals(value, optValue, Constants.TOLERANCE)) {
                    isOptionValue = true;
                    break;
                }
            }
            if(!isOptionValue) {
                return itemKey + " - the value couldn't be found in the option value list - "
                        + convertOptionIntoString(itemMeta.getOptions());
            }
        }
        return VALIDATE_OK;
    }

    /**
     * Check a boolean value: it must be non-null and its string form must be true or false (case-insensitive).
     */
    private static String validateBoolean(String itemKey, Object itemValue) {
        if(itemValue == null) {
            return itemKey + " - the value couldn't be null. Only true/false are perimited.";
        }
        String text = itemValue.toString();
        if(!text.equalsIgnoreCase("true") && !text.equalsIgnoreCase("false")) {
            return itemKey + " - the value is illegal. Only true/false are perimited.";
        }
        return VALIDATE_OK;
    }

    /**
     * Check every element of a list value. Object elements are checked property by property under the same
     * key; other elements are checked against the "dummy" child meta item. The first failure is returned.
     */
    private static String validateList(boolean isGridSearch, String itemKey, MetaItem itemMeta, Object itemValue)
            throws Exception {
        if(itemValue == null || itemMeta.getElement() == null) {
            return VALIDATE_OK;
        }
        if(!(itemValue instanceof List)) {
            // report a wrongly typed value as a validation failure instead of a ClassCastException
            return itemKey + " - the value is not a list.";
        }

        boolean isObjectElement = TYPE_OBJECT.equals(itemMeta.getElementType());
        for(Object obj: (List<?>) itemValue) {
            if(isObjectElement) {
                ValidateResult result = iterateCheck(isGridSearch, itemKey, obj);
                if(!result.getStatus()) {
                    return result.getCauses().get(0);
                }
            } else {
                String validateStr = validate(isGridSearch, itemKey + ITEM_KEY_SEPERATOR + DUMMY, obj);
                if(!VALIDATE_OK.equals(validateStr)) {
                    return validateStr;
                }
            }
        }
        return VALIDATE_OK;
    }

    /**
     * Check every entry of a map value against the child meta item named after the entry key.
     * The first failure is returned.
     */
    private static String validateMap(boolean isGridSearch, String itemKey, MetaItem itemMeta, Object itemValue)
            throws Exception {
        if(itemValue == null || itemMeta.getElement() == null) {
            return VALIDATE_OK;
        }
        if(!(itemValue instanceof Map)) {
            // report a wrongly typed value as a validation failure instead of a ClassCastException
            return itemKey + " - the value is not a map.";
        }

        for(Entry<?, ?> entry: ((Map<?, ?>) itemValue).entrySet()) {
            String validateStr = validate(isGridSearch, itemKey + ITEM_KEY_SEPERATOR + entry.getKey(),
                    entry.getValue());
            if(!VALIDATE_OK.equals(validateStr)) {
                return validateStr;
            }
        }
        return VALIDATE_OK;
    }

    /**
     * Add the MetaItem into warehouse.
     * If the type of MetaItem is list, map or object, the child elements are added recursively:
     * a list of non-object elements registers its first element under the "dummy" name, every other
     * container registers each child under the child's own name.
     *
     * @param key
     *            - the key to store MetaItem
     * @param metaItem
     *            - object to store
     */
    private static void addMetaItem(String key, MetaItem metaItem) {
        itemsWareHouse.put(key, metaItem);

        String type = metaItem.getType();
        boolean isList = StringUtils.equals(type, TYPE_LIST);
        boolean isContainer = isList || StringUtils.equals(type, TYPE_MAP) || StringUtils.equals(type, TYPE_OBJECT);
        if(!isContainer || !CollectionUtils.isNotEmpty(metaItem.getElement())) {
            return;
        }

        if(isList && !StringUtils.equals(metaItem.getElementType(), TYPE_OBJECT)) {
            // all elements of such a list share one definition, which is stored under the dummy name
            addMetaItem(key + ITEM_KEY_SEPERATOR + DUMMY, metaItem.getElement().get(0));
            return;
        }

        for(MetaItem sub: metaItem.getElement()) {
            addMetaItem(key + ITEM_KEY_SEPERATOR + sub.getName(), sub);
        }
    }

    /**
     * Convert the value of ValueOption list into String
     * For example, if the mode options are ["local", "hdfs"], the output will be local/hdfs
     *
     * @param options
     *            - ValueOption list
     * @return - String of ValueOption list, separated by '/'
     */
    private static String convertOptionIntoString(List<ValueOption> options) {
        StringBuilder builder = new StringBuilder();
        for(int i = 0; i < options.size(); i++) {
            if(i > 0) {
                builder.append("/");
            }
            // String.valueOf keeps the message buildable even if an option has no value
            builder.append(String.valueOf(options.get(i).getValue()));
        }
        return builder.toString();
    }

    /**
     * Encapsulate the validate result string into @ValidateResult.
     * If the validateStr is not "OK", set the ValidateResult.status to false, and add @validateStr into
     * ValidateResult.causes
     *
     * @param result
     *            - result set
     * @param validateStr
     *            - validate result string
     */
    private static void encapsulateResult(ValidateResult result, String validateStr) {
        if(result == null || VALIDATE_OK.equals(validateStr)) {
            return;
        }
        result.setStatus(false);
        result.getCauses().add(validateStr);
    }

    /**
     * Encapsulate validate result into total result.
     * The status of total result will be false, if there is one false.
     * The total result will contain all causes
     *
     * @param totalResult
     *            the total result
     * @param result
     *            the current result
     */
    private static void encapsulateResult(ValidateResult totalResult, ValidateResult result) {
        if(totalResult == null || result == null) {
            return;
        }
        totalResult.setStatus(totalResult.getStatus() && result.getStatus());
        totalResult.getCauses().addAll(result.getCauses());
    }

    /**
     * Get the method-style name of the property. (UPPER the first character:))
     *
     * @param fieldName
     *            the field name
     * @return first character Upper style
     */
    private static String getMethodName(String fieldName) {
        return StringUtils.capitalize(fieldName);
    }
}