/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.container.obj;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ml.shifu.shifu.core.alg.LogisticRegressionTrainer;
import ml.shifu.shifu.core.alg.SVMTrainer;
import ml.shifu.shifu.core.dtrain.CommonConstants;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
/**
* {@link ModelTrainConf} is the "train" part of ModelConfig.json.
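*
* <p>
* A minimal illustrative sketch of the corresponding "train" section (the values shown are the defaults defined
* in this class; the keys inside "params" follow the NN defaults built by createParamsByAlg, and the literal key
* strings are assumptions here since the constant values are not defined in this file):
*
* <pre>
* "train" : {
*     "baggingNum" : 5,
*     "baggingWithReplacement" : true,
*     "baggingSampleRate" : 1.0,
*     "validSetRate" : 0.2,
*     "numTrainEpochs" : 100,
*     "algorithm" : "NN",
*     "params" : {
*         "LearningRate" : 0.1,
*         "NumHiddenLayers" : 1,
*         "NumHiddenNodes" : [ 50 ],
*         "ActivationFunc" : [ "tanh" ]
*     }
* }
* </pre>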
*/
@JsonIgnoreProperties(ignoreUnknown = true)
public class ModelTrainConf {
/**
* Different training algorithms supported in Shifu. SVM is actually not implemented well. DT is replaced by RF and
* GBT.
*
* @author Zhang David (pengzhang@paypal.com)
*/
public static enum ALGORITHM {
NN, LR, SVM, DT, RF, GBT
}
/**
* Multiple classification algorithm. NATIVE is supported in NN/RF. ONEVSALL/ONEVSREST works by running multiple
* regressions.
*
* @author Zhang David (pengzhang@paypal.com)
*/
@JsonDeserialize(using = MultipleClassificationDeserializer.class)
public static enum MultipleClassification {
NATIVE, // means using NN regression or RF classification directly, not one-vs-all or one-vs-one
ONEVSALL, ONEVSREST, // the same as ONEVSALL
ONEVSONE; // ONEVSONE is not implemented yet.
/**
* Get {@link MultipleClassification} by string; matching is case-insensitive.
*
* @param strategy
*            the multi-classification strategy name
* @return the matching {@link MultipleClassification}
* @throws IllegalArgumentException
*             if no matching enum constant is found
*/
public static MultipleClassification of(String strategy) {
for(MultipleClassification element: values()) {
if(element.toString().equalsIgnoreCase(strategy)) {
return element;
}
}
throw new IllegalArgumentException("cannot find such enum in MultipleClassification");
}
}
/**
* How many bagging jobs in training.
*/
private Integer baggingNum = Integer.valueOf(5);
// default is true, as bagging usually samples with replacement.
/**
* Bagging sampling with replacement; this only takes effect in NN. In RF, bagging sampling with replacement is
* always enabled regardless of this flag. In GBT, bagging sampling with replacement is always disabled regardless
* of this flag.
*/
private Boolean baggingWithReplacement = Boolean.TRUE;
/**
* In each bagging job, sampling is done according to this sample rate.
*/
private Double baggingSampleRate = Double.valueOf(1.0);
/**
* After bagging sampling, this rate of records is used for validation.
*/
private Double validSetRate = Double.valueOf(0.2);
/**
* Only sample negative records out; this works together with {@link #baggingSampleRate}.
*/
private Boolean sampleNegOnly = Boolean.FALSE;
/**
* Convergence threshold to decide whether training has converged. 0 means the early-stop feature is not enabled.
*/
private Double convergenceThreshold = Double.valueOf(0.0);
/**
* Iterations used in training.
*/
private Integer numTrainEpochs = Integer.valueOf(100);
/**
* For NN only: how many epochs are trained in one iteration.
*/
private Integer epochsPerIteration = Integer.valueOf(1);
/**
* Whether training data is located on disk. This parameter is deprecated: in NN/LR, MemoryDiskList is used, so
* disk is used automatically when there is not enough memory. In GBDT/RF, because predictions on the data change
* in each tree, only the in-memory list is supported.
*/
@Deprecated
private Boolean trainOnDisk = Boolean.FALSE;
/**
* If true, training data and validation data are kept fixed in training, even when another job is started.
*/
private Boolean fixInitInput = Boolean.FALSE;
/**
* Only works in regression; if true, positive and negative records are sampled independently.
*/
private Boolean stratifiedSample = Boolean.FALSE;
/**
* Whether to continue model training based on an existing model in the model path; this is like warm start in
* scikit-learn.
*/
private Boolean isContinuous = Boolean.FALSE;
/**
* Only works in NN; training and validation data are swapped in different epochs.
*/
private Boolean isCrossOver = Boolean.FALSE;
/**
* How many threads in each worker; this enables multi-threaded execution inside workers.
*/
private Integer workerThreadCount = 4;
/**
* If set to a value in (1, 20], k-fold cross validation is enabled. Jobs are started to train on the k-fold
* training data, and the final average validation error is printed to the console.
*/
private Integer numKFold = -1;
/**
* Up-sampling weight for positive tags; this is to address class imbalance.
*/
private Double upSampleWeight = Double.valueOf(1d);
/**
* Algorithm: LR, NN, RF, GBT
*/
private String algorithm = "NN";
/**
* Model parameters for training, such as learning rate, tree depth, etc.
*/
private Map<String, Object> params;
/**
* Multiple classification method: NATIVE or ONEVSALL(ONEVSREST)
*/
private MultipleClassification multiClassifyMethod = MultipleClassification.NATIVE;
private Map<String, String> customPaths;
public ModelTrainConf() {
customPaths = new HashMap<String, String>(1);
/*
* Since most users won't use this feature, the custom paths are hidden when creating a new model.
*/
/*
* customPaths.put(Constants.KEY_PRE_TRAIN_STATS_PATH, null);
* customPaths.put(Constants.KEY_SELECTED_RAW_DATA_PATH, null);
* customPaths.put(Constants.KEY_NORMALIZED_DATA_PATH, null);
* customPaths.put(Constants.KEY_TRAIN_SCORES_PATH, null);
* customPaths.put(Constants.KEY_BIN_AVG_SCORE_PATH, null);
*/
}
public Integer getBaggingNum() {
return baggingNum;
}
public void setBaggingNum(Integer baggingNum) {
this.baggingNum = baggingNum;
}
public Boolean getBaggingWithReplacement() {
return baggingWithReplacement;
}
public void setBaggingWithReplacement(Boolean baggingWithReplacement) {
this.baggingWithReplacement = baggingWithReplacement;
}
public Double getBaggingSampleRate() {
return baggingSampleRate;
}
public void setBaggingSampleRate(Double baggingSampleRate) {
this.baggingSampleRate = baggingSampleRate;
}
public Double getValidSetRate() {
return validSetRate;
}
public void setValidSetRate(Double validSetRate) {
this.validSetRate = validSetRate;
}
@JsonIgnore
public Boolean getTrainOnDisk() {
return trainOnDisk;
}
public void setTrainOnDisk(Boolean trainOnDisk) {
this.trainOnDisk = trainOnDisk;
}
@JsonIgnore
public Boolean getFixInitInput() {
return fixInitInput;
}
public void setFixInitInput(Boolean fixInitInput) {
this.fixInitInput = fixInitInput;
}
public Integer getNumTrainEpochs() {
return numTrainEpochs;
}
public void setNumTrainEpochs(Integer numTrainEpochs) {
this.numTrainEpochs = numTrainEpochs;
}
public String getAlgorithm() {
return algorithm;
}
public void setAlgorithm(String algorithm) {
this.algorithm = algorithm;
}
public Map<String, Object> getParams() {
return params;
}
public void setParams(Map<String, Object> params) {
this.params = params;
}
public Map<String, String> getCustomPaths() {
return customPaths;
}
public void setCustomPaths(Map<String, String> customPaths) {
this.customPaths = customPaths;
}
/**
* @return the epochsPerIteration
*/
@JsonIgnore
public Integer getEpochsPerIteration() {
return epochsPerIteration;
}
/**
* @param epochsPerIteration
* the epochsPerIteration to set
*/
@JsonProperty
public void setEpochsPerIteration(Integer epochsPerIteration) {
this.epochsPerIteration = epochsPerIteration;
}
/**
* As the threshold is an optional setting, use {@link JsonIgnore} to skip it when ModelConfig.json is initially
* written out.
*
* @return Convergence threshold.
*/
@JsonIgnore
public Double getConvergenceThreshold() {
return convergenceThreshold;
}
@JsonProperty
public void setConvergenceThreshold(Double convergenceThreshold) {
this.convergenceThreshold = convergenceThreshold;
}
@JsonIgnore
public Boolean getIsCrossOver() {
return isCrossOver;
}
/**
* @param isCrossOver
* the isCrossOver to set
*/
@JsonProperty
public void setIsCrossOver(Boolean isCrossOver) {
this.isCrossOver = isCrossOver;
}
/**
* @return the isContinuous
*/
public Boolean getIsContinuous() {
return isContinuous;
}
/**
* @param isContinuous
* the isContinuous to set
*/
public void setIsContinuous(Boolean isContinuous) {
this.isContinuous = isContinuous;
}
/**
* @return the workerThreadCount
*/
public Integer getWorkerThreadCount() {
return workerThreadCount;
}
/**
* @param workerThreadCount
* the workerThreadCount to set
*/
public void setWorkerThreadCount(Integer workerThreadCount) {
this.workerThreadCount = workerThreadCount;
}
/**
* @return the upSampleWeight
*/
@JsonIgnore
public Double getUpSampleWeight() {
return upSampleWeight;
}
/**
* @param upSampleWeight
* the upSampleWeight to set
*/
public void setUpSampleWeight(Double upSampleWeight) {
this.upSampleWeight = upSampleWeight;
}
/**
* @return the multiClassifyMethod
*/
@JsonIgnore
public MultipleClassification getMultiClassifyMethod() {
return multiClassifyMethod;
}
/**
* @param multiClassifyMethod
* the multiClassifyMethod to set
*/
@JsonProperty
public void setMultiClassifyMethod(MultipleClassification multiClassifyMethod) {
this.multiClassifyMethod = multiClassifyMethod;
}
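/**
* @return true if the multi-classification method is ONEVSALL (or its alias ONEVSREST)
*/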
@JsonIgnore
public boolean isOneVsAll() {
return this.multiClassifyMethod == MultipleClassification.ONEVSALL
|| this.multiClassifyMethod == MultipleClassification.ONEVSREST;
}
/**
* @return the sampleNegOnly
*/
@JsonIgnore
public Boolean getSampleNegOnly() {
return sampleNegOnly;
}
/**
* @param sampleNegOnly
* the sampleNegOnly to set
*/
@JsonProperty
public void setSampleNegOnly(Boolean sampleNegOnly) {
this.sampleNegOnly = sampleNegOnly;
}
/**
* @return the stratifiedSample
*/
@JsonIgnore
public Boolean getStratifiedSample() {
return stratifiedSample;
}
/**
* @param stratifiedSample
*            the stratifiedSample to set
*/
@JsonProperty
public void setStratifiedSample(Boolean stratifiedSample) {
this.stratifiedSample = stratifiedSample;
}
/**
* @return the numKFold
*/
@JsonIgnore
public Integer getNumKFold() {
return numKFold;
}
/**
* @param numKFold
* the numKFold to set
*/
@JsonProperty
public void setNumKFold(Integer numKFold) {
this.numKFold = numKFold;
}
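/**
* Equality only compares {@link #algorithm}, {@link #baggingNum}, {@link #numTrainEpochs} and
* {@link #validSetRate}; other fields are ignored.
*/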
@Override
public boolean equals(Object obj) {
if(obj == null || !(obj instanceof ModelTrainConf)) {
return false;
}
ModelTrainConf other = (ModelTrainConf) obj;
if(this == other) {
return true;
}
return this.algorithm.equals(other.getAlgorithm()) && this.baggingNum.equals(other.getBaggingNum())
&& this.getNumTrainEpochs().equals(other.getNumTrainEpochs())
&& this.validSetRate.equals(other.getValidSetRate());
}
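/**
* Creates a copy of this configuration. The {@link #customPaths} and {@link #params} maps are copied into new
* HashMap instances. Note that not every field is carried over (for example, baggingWithReplacement,
* sampleNegOnly, stratifiedSample, isCrossOver and numKFold are not copied).
*/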
@Override
public ModelTrainConf clone() {
ModelTrainConf other = new ModelTrainConf();
other.setAlgorithm(algorithm);
other.setBaggingNum(baggingNum);
other.setBaggingSampleRate(baggingSampleRate);
other.setConvergenceThreshold(convergenceThreshold);
if(customPaths != null) {
other.setCustomPaths(new HashMap<String, String>(customPaths));
}
other.setEpochsPerIteration(epochsPerIteration);
other.setFixInitInput(fixInitInput);
other.setIsContinuous(isContinuous);
other.setMultiClassifyMethod(multiClassifyMethod);
other.setNumTrainEpochs(numTrainEpochs);
if(params != null) {
other.setParams(new HashMap<String, Object>(params));
}
other.setTrainOnDisk(trainOnDisk);
other.setUpSampleWeight(upSampleWeight);
other.setValidSetRate(validSetRate);
other.setWorkerThreadCount(workerThreadCount);
return other;
}
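/**
* Creates a default {@code params} map for the given algorithm. For tree ensembles (RF/GBT), the
* {@code numTrainEpochs} of the given train configuration is also raised to 1000 so that enough trees are
* built by default.
*
* @param alg
*            the training algorithm to create default parameters for
* @param trainConf
*            train configuration whose epoch number may be adjusted for RF/GBT
* @return the default parameter map for the given algorithm
*/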
public static Map<String, Object> createParamsByAlg(ALGORITHM alg, ModelTrainConf trainConf) {
Map<String, Object> params = new HashMap<String, Object>();
if(ALGORITHM.NN.equals(alg)) {
params.put(CommonConstants.PROPAGATION, "R");
params.put(CommonConstants.LEARNING_RATE, 0.1);
params.put(CommonConstants.NUM_HIDDEN_LAYERS, 1);
List<Integer> nodes = new ArrayList<Integer>();
nodes.add(50);
params.put(CommonConstants.NUM_HIDDEN_NODES, nodes);
List<String> func = new ArrayList<String>();
func.add("tanh");
params.put(CommonConstants.ACTIVATION_FUNC, func);
params.put("RegularizedConstant", 0.0);
} else if(ALGORITHM.SVM.equals(alg)) {
params.put(SVMTrainer.SVM_KERNEL, "linear");
params.put(SVMTrainer.SVM_GAMMA, 1.0);
params.put(SVMTrainer.SVM_CONST, 1.0);
} else if(ALGORITHM.RF.equals(alg)) {
params.put("TreeNum", "10");
params.put("FeatureSubsetStrategy", "TWOTHIRDS");
params.put("MaxDepth", 10);
params.put("MinInstancesPerNode", 1);
params.put("MinInfoGain", 0.0);
params.put("Impurity", "variance");
params.put("Loss", "squared");
trainConf.setNumTrainEpochs(1000);
} else if(ALGORITHM.GBT.equals(alg)) {
params.put("TreeNum", "100");
params.put("FeatureSubsetStrategy", "TWOTHIRDS");
params.put("MaxDepth", 7);
params.put("MinInstancesPerNode", 5);
params.put("MinInfoGain", 0.0);
params.put("DropoutRate", 0.0);
params.put("Impurity", "variance");
params.put(CommonConstants.LEARNING_RATE, 0.05);
params.put("Loss", "squared");
trainConf.setNumTrainEpochs(1000);
} else if(ALGORITHM.LR.equals(alg)) {
params.put(LogisticRegressionTrainer.LEARNING_RATE, 0.1);
params.put("RegularizedConstant", 0.0);
params.put("L1orL2", "NONE");
}
return params;
}
}