/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.actor.worker;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ml.shifu.shifu.container.WeightAmplifier;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.core.DataSampler;
import ml.shifu.shifu.core.Normalizer;
import ml.shifu.shifu.message.NormPartRawDataMessage;
import ml.shifu.shifu.message.NormResultDataMessage;
import ml.shifu.shifu.util.CommonUtils;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.jexl2.Expression;
import org.apache.commons.jexl2.JexlContext;
import org.apache.commons.jexl2.JexlEngine;
import org.apache.commons.jexl2.MapContext;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import akka.actor.ActorRef;
/**
* DataNormalizeWorker class is to normalize the train data
* Notice, the last field of normalized data is the weight of the training data.
* The weight is set in @ModelConfig.normalize.weightAmplifier. It could be some column
*/
public class DataNormalizeWorker extends AbstractWorkerActor {
private static Logger log = LoggerFactory.getLogger(DataNormalizeWorker.class);
private Expression weightExpr;
public DataNormalizeWorker(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, ActorRef parentActorRef,
ActorRef nextActorRef) {
super(modelConfig, columnConfigList, parentActorRef, nextActorRef);
weightExpr = createExpression(modelConfig.getWeightColumnName());
}
/*
* (non-Javadoc)
*
* @see akka.actor.UntypedActor#onReceive(java.lang.Object)
*/
@Override
public void handleMsg(Object message) {
if(message instanceof NormPartRawDataMessage) {
NormPartRawDataMessage msg = (NormPartRawDataMessage) message;
List<String> rawDataList = msg.getRawDataList();
int targetMsgCnt = msg.getTotalMsgCnt();
List<List<Double>> normalizedDataList = normalizeData(rawDataList);
nextActorRef.tell(new NormResultDataMessage(targetMsgCnt, rawDataList, normalizedDataList), this.getSelf());
} else {
unhandled(message);
}
}
/*
* Normalize the list training data from List<String> to List<Double>
*
* @param rawDataList
* @return the data after normalization
*/
private List<List<Double>> normalizeData(List<String> rawDataList) {
List<List<Double>> normalizedDataList = new ArrayList<List<Double>>();
for(String rawInput: rawDataList) {
String[] rf = CommonUtils.split(rawInput, modelConfig.getDataSetDelimiter());
List<Double> normRecord = normalizeRecord(rf);
if(CollectionUtils.isNotEmpty(normRecord)) {
normalizedDataList.add(normRecord);
}
}
return normalizedDataList;
}
/**
* Normalize the training data record
*
* @param rfs
* - record fields
* @return the data after normalization
*/
private List<Double> normalizeRecord(String[] rfs) {
List<Double> retDouList = new ArrayList<Double>();
if(rfs == null || rfs.length == 0) {
return null;
}
String tag = CommonUtils.trimTag(rfs[this.targetColumnNum]);
boolean isNotSampled = DataSampler.isNotSampled(modelConfig.getPosTags(), modelConfig.getNegTags(),
modelConfig.getNormalizeSampleRate(), modelConfig.isNormalizeSampleNegOnly(), tag);
if(isNotSampled) {
return null;
}
JexlContext jc = new MapContext();
Double cutoff = modelConfig.getNormalizeStdDevCutOff();
for(int i = 0; i < rfs.length; i++) {
ColumnConfig config = columnConfigList.get(i);
if(weightExpr != null) {
jc.set(config.getColumnName(), rfs[i]);
}
if(this.targetColumnNum == i) {
if(modelConfig.getPosTags().contains(tag)) {
retDouList.add(Double.valueOf(1));
} else if(modelConfig.getNegTags().contains(tag)) {
retDouList.add(Double.valueOf(0));
} else {
log.error("Invalid data! The target value is not listed - " + tag);
// Return null to skip such record.
return null;
}
} else if(!CommonUtils.isGoodCandidate(config)) {
retDouList.add(null);
} else {
String val = (rfs[i] == null) ? "" : rfs[i];
retDouList.add(Normalizer.normalize(config, val, cutoff, modelConfig.getNormalizeType()));
}
}
double weight = 1.0d;
if(weightExpr != null) {
Object result = weightExpr.evaluate(jc);
if(result instanceof Integer) {
weight = ((Integer) result).doubleValue();
} else if(result instanceof Double) {
weight = ((Double) result).doubleValue();
} else if(result instanceof String) {
// add to parse String data
try {
weight = Double.parseDouble((String) result);
} catch (NumberFormatException e) {
// Not a number, use default
if(System.currentTimeMillis() % 100 == 0) {
log.warn("Weight column type is String and value cannot be parsed with {}, use default 1.0d.",
result);
}
weight = 1.0d;
}
}
}
retDouList.add(weight);
return retDouList;
}
/*
* Create expressions for multi weight settings
*/
protected Map<Expression, Double> createExpressionMap(List<WeightAmplifier> weightExprList) {
Map<Expression, Double> ewMap = new HashMap<Expression, Double>();
if(CollectionUtils.isNotEmpty(weightExprList)) {
JexlEngine jexl = new JexlEngine();
for(WeightAmplifier we: weightExprList) {
ewMap.put(jexl.createExpression(we.getTargetExpression()), we.getTargetWeight());
}
}
return ewMap;
}
/*
* Create the expression for weight setting
*/
private Expression createExpression(String weightAmplifier) {
if(StringUtils.isNotBlank(weightAmplifier)) {
JexlEngine jexl = new JexlEngine();
return jexl.createExpression(weightAmplifier);
}
return null;
}
}