/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.actor.worker;
import akka.actor.ActorRef;
import ml.shifu.shifu.container.CaseScoreResult;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.EvalConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.model.ModelSpec;
import ml.shifu.shifu.fs.PathFinder;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.message.EvalResultMessage;
import ml.shifu.shifu.message.RunModelResultMessage;
import ml.shifu.shifu.util.CommonUtils;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang.StringUtils;
import org.encog.ml.BasicML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.*;
import java.util.Map.Entry;
/**
 * Worker actor that collects the score results of the evaluation data and
 * writes them, one '|'-separated row per case, into the eval score file.
 * <p>
 * The score file starts with a header line (target, weight, model score
 * columns, sub-model score columns, meta columns) followed by one row per
 * scored case. Once every stream has delivered all of its messages the writer
 * is closed and a single {@link EvalResultMessage} is sent to the next actor.
 */
public class ScoreModelWorker extends AbstractWorkerActor {

    private static final Logger log = LoggerFactory.getLogger(ScoreModelWorker.class);

    private final EvalConfig evalConfig;
    /** Column names of the evaluation data, used to map raw fields to names. */
    private String[] header;
    private final BufferedWriter scoreWriter;
    // private Reasoner reasoner;
    /** Number of distinct stream ids seen so far. */
    private int receivedStreamCnt;
    /** Per-stream bookkeeping, keyed by stream id. */
    private final Map<Integer, StreamBulletin> resultMap;
    /** Sub-model name to number of models; sorted so the header order is stable. */
    private final Map<String, Integer> subModelsCnt;

    /**
     * Prepares the score output: makes sure the eval set directory exists,
     * removes a previous score output, opens the score writer and writes the
     * header line.
     *
     * @param modelConfig
     *            the model configuration
     * @param columnConfigList
     *            the column configurations
     * @param parentActorRef
     *            the parent actor
     * @param nextActorRef
     *            the actor notified when scoring is finished
     * @param evalConfig
     *            the evaluation set to score
     * @throws IOException
     *             if the output cannot be prepared or the header cannot be written
     */
    public ScoreModelWorker(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, ActorRef parentActorRef,
            ActorRef nextActorRef, EvalConfig evalConfig) throws IOException {
        super(modelConfig, columnConfigList, parentActorRef, nextActorRef);

        this.evalConfig = evalConfig;
        this.receivedStreamCnt = 0;
        this.resultMap = new HashMap<Integer, StreamBulletin>();
        this.subModelsCnt = new TreeMap<String, Integer>();

        PathFinder pathFinder = new PathFinder(modelConfig);

        // make sure local directory - evals/<EvalSetName> exists
        ShifuFileUtils.createDirIfNotExists(pathFinder.getEvalSetPath(evalConfig), evalConfig.getDataSet().getSource());

        // clear output - evals/<EvalSetName>/EvalScore at first,
        // for it may be directory
        ShifuFileUtils.deleteFile(pathFinder.getEvalScorePath(evalConfig), evalConfig.getDataSet().getSource());

        // create score writer
        this.scoreWriter = ShifuFileUtils.getWriter(pathFinder.getEvalScorePath(evalConfig), evalConfig.getDataSet()
                .getSource());

        // From here on the writer is open: release it if anything below fails,
        // otherwise the file handle would leak with the half-built actor.
        boolean initialized = false;
        try {
            // load the header for evaluation data
            this.header = CommonUtils.getFinalHeaders(evalConfig);

            List<ModelSpec> subModels = CommonUtils.loadSubModels(modelConfig, this.columnConfigList, evalConfig,
                    evalConfig.getDataSet().getSource(), evalConfig.getGbtConvertToProb());
            if(CollectionUtils.isNotEmpty(subModels)) {
                for(ModelSpec modelSpec: subModels) {
                    log.info("get sub model {}|{}", modelSpec.getModelName(), modelSpec.getModels().size());
                    subModelsCnt.put(modelSpec.getModelName(), modelSpec.getModels().size());
                }
            }

            writeScoreHeader();
            initialized = true;
        } finally {
            if(!initialized) {
                closeWriterQuietly();
            }
        }
    }

    /**
     * Handles a {@link RunModelResultMessage}: records it against its stream,
     * appends one score row per case to the score file and, when all streams
     * are complete, closes the file and notifies the next actor. Any other
     * message is passed to {@code unhandled}.
     *
     * @param message
     *            the received message
     * @throws IOException
     *             if writing to or closing the score file fails
     */
    @Override
    public void handleMsg(Object message) throws IOException {
        if(!(message instanceof RunModelResultMessage)) {
            unhandled(message);
            return;
        }

        log.debug("Received model score data for evaluation");
        RunModelResultMessage msg = (RunModelResultMessage) message;

        StreamBulletin bulletin = resultMap.get(msg.getStreamId());
        if(bulletin == null) {
            // first message of a stream we have not seen before
            receivedStreamCnt++;
            bulletin = new StreamBulletin(msg.getStreamId());
            resultMap.put(msg.getStreamId(), bulletin);
        }
        bulletin.receiveMsge(msg.getMsgId(), msg.isLastMsg());

        // A message without results still counts for stream completion above.
        List<CaseScoreResult> caseScoreResultList = msg.getScoreResultList();
        if(CollectionUtils.isNotEmpty(caseScoreResultList)) {
            writeScoreRows(caseScoreResultList);
        }

        if(receivedStreamCnt == msg.getTotalStreamCnt() && hasAllMessageResult(resultMap)) {
            log.info("Finish running scoring, the score file - {} is stored in {}.", new PathFinder(modelConfig)
                    .getEvalScorePath(evalConfig).toString(), evalConfig.getDataSet().getSource().name());
            scoreWriter.close();
            // only one message will be sent
            nextActorRef.tell(new EvalResultMessage(1), this.getSelf());
        }
    }

    /**
     * Writes one row per scored case: tag, weight, model scores, sub-model
     * scores and meta column values, separated by '|'.
     *
     * @param caseScoreResultList
     *            the non-empty score results of one message
     * @throws IOException
     *             if writing to the score file fails
     */
    private void writeScoreRows(List<CaseScoreResult> caseScoreResultList) throws IOException {
        // Configuration lookups are the same for every row, so do them once.
        String targetColumnName = modelConfig.getTargetColumnName(evalConfig);
        String weightColumnName = evalConfig.getDataSet().getWeightColumnName();
        boolean hasWeightColumn = StringUtils.isNotBlank(weightColumnName);
        // Must be the same list writeScoreHeader() uses, otherwise the rows
        // would not line up with the header columns.
        List<String> metaColumns = evalConfig.getAllMetaColumns(modelConfig);

        StringBuilder buf = new StringBuilder();
        for(CaseScoreResult csResult: caseScoreResultList) {
            buf.setLength(0);

            Map<String, String> rawDataMap = CommonUtils.convertDataIntoMap(csResult.getInputData(), evalConfig
                    .getDataSet().getDataDelimiter(), header);

            // get the tag
            String tag = CommonUtils.trimTag(rawDataMap.get(targetColumnName));
            buf.append(tag);

            // append weight column value, defaulting to 1.0 when no weight column is configured
            buf.append('|');
            if(hasWeightColumn) {
                buf.append(StringUtils.trimToEmpty(rawDataMap.get(weightColumnName)));
            } else {
                buf.append("1.0");
            }

            if(CollectionUtils.isNotEmpty(csResult.getScores())) {
                addModelScoreData(buf, csResult);
            }

            // NOTE(review): the header orders sub-models by name (TreeMap); rows follow the
            // iteration order of getSubModelScores() - confirm both orders agree.
            Map<String, CaseScoreResult> subModelScores = csResult.getSubModelScores();
            if(MapUtils.isNotEmpty(subModelScores)) {
                for(CaseScoreResult subCs: subModelScores.values()) {
                    addModelScoreData(buf, subCs);
                }
            }

            // append meta data
            if(CollectionUtils.isNotEmpty(metaColumns)) {
                for(String columnName: metaColumns) {
                    buf.append('|').append(StringUtils.trimToEmpty(rawDataMap.get(columnName)));
                }
            }

            scoreWriter.write(buf.append('\n').toString());
        }
    }

    /**
     * Checks whether every known stream has delivered all of its messages.
     *
     * @param resultMsgMap
     *            the per-stream bookkeeping
     * @return {@code true} if no stream is still incomplete
     */
    private boolean hasAllMessageResult(Map<Integer, StreamBulletin> resultMsgMap) {
        for(StreamBulletin bulletin: resultMsgMap.values()) {
            if(!bulletin.isMessageEnd()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Appends the average, max, min and median score followed by every single
     * model score of {@code cs}, each prefixed with '|'.
     *
     * @param buf
     *            the row being built
     * @param cs
     *            the score result to append
     */
    private void addModelScoreData(StringBuilder buf, CaseScoreResult cs) {
        buf.append('|').append(cs.getAvgScore());
        buf.append('|').append(cs.getMaxScore());
        buf.append('|').append(cs.getMinScore());
        buf.append('|').append(cs.getMedianScore());

        // individual model scores; a result without scores contributes only the summary columns
        if(CollectionUtils.isNotEmpty(cs.getScores())) {
            for(Double score: cs.getScores()) {
                buf.append('|').append(score);
            }
        }
    }

    /**
     * Write the file header for score file
     *
     * @throws IOException
     *             if any I/O exception
     */
    private void writeScoreHeader() throws IOException {
        StringBuilder buf = new StringBuilder();

        String targetColumnName = modelConfig.getTargetColumnName(evalConfig);
        buf.append(targetColumnName == null ? "tag" : targetColumnName);

        String weightColumnName = evalConfig.getDataSet().getWeightColumnName();
        buf.append('|').append(StringUtils.isBlank(weightColumnName) ? "weight" : weightColumnName);

        // NOTE(review): models are loaded from SourceType.LOCAL here while the sub-models are
        // loaded from the eval data set source - confirm this is intended.
        List<BasicML> models = CommonUtils.loadBasicModels(modelConfig, evalConfig, SourceType.LOCAL);
        if(CollectionUtils.isNotEmpty(models)) {
            addModelScoreHeader(buf, models.size(), "");
        }

        if(MapUtils.isNotEmpty(this.subModelsCnt)) {
            for(Map.Entry<String, Integer> entry: this.subModelsCnt.entrySet()) {
                Integer smCnt = entry.getValue();
                if(smCnt > 0) {
                    addModelScoreHeader(buf, smCnt, entry.getKey());
                }
            }
        }

        // append meta data
        List<String> metaColumns = evalConfig.getAllMetaColumns(modelConfig);
        if(CollectionUtils.isNotEmpty(metaColumns)) {
            for(String columnName: metaColumns) {
                buf.append('|').append(columnName);
            }
        }

        scoreWriter.write(buf.append('\n').toString());
    }

    /**
     * Appends the header columns of one model group: mean, max, min, median
     * and one {@code model<i>} column per model.
     *
     * @param buf
     *            the header being built
     * @param modelCnt
     *            the number of models in the group
     * @param modelName
     *            the group name used as column prefix; blank for no prefix
     */
    private void addModelScoreHeader(StringBuilder buf, Integer modelCnt, String modelName) {
        buf.append('|').append(addModelNameAsNS(modelName, "mean"));
        buf.append('|').append(addModelNameAsNS(modelName, "max"));
        buf.append('|').append(addModelNameAsNS(modelName, "min"));
        buf.append('|').append(addModelNameAsNS(modelName, "median"));
        for(int i = 0; i < modelCnt; i++) {
            buf.append('|').append(addModelNameAsNS(modelName, "model" + i));
        }
    }

    /**
     * Prefixes {@code scoreName} with {@code modelName::} unless the model name is blank.
     */
    private String addModelNameAsNS(String modelName, String scoreName) {
        return (StringUtils.isBlank(modelName) ? scoreName : modelName + "::" + scoreName);
    }

    /**
     * Closes the score writer without letting a failure of {@code close()}
     * hide the exception that is already propagating.
     */
    private void closeWriterQuietly() {
        if(scoreWriter == null) {
            return;
        }
        try {
            scoreWriter.close();
        } catch (IOException e) {
            log.warn("Fail to close score writer.", e);
        }
    }

    /**
     * Tracks which messages of one stream have arrived.
     * <p>
     * Completion is detected by comparing the sum of the received message ids
     * with {@code n * (n + 1) / 2}, where {@code n} is the id of the last
     * message. NOTE(review): this presumes the ids of a stream are 1..n and
     * each is delivered exactly once - confirm against the sender.
     */
    public static class StreamBulletin {

        private final int streamId;
        /** Expected sum of all message ids; known once the last message arrived. */
        private long targetSum;
        /** Sum of the message ids received so far. */
        private long totalSum;
        private boolean hasLastMsg;

        public StreamBulletin(int streamId) {
            this.streamId = streamId;
            this.totalSum = 0;
            this.targetSum = 0;
            this.hasLastMsg = false;
        }

        /**
         * Records the arrival of one message.
         *
         * @param msgId
         *            the id of the message within the stream
         * @param isLastMsg
         *            whether the message is flagged as the last one of the stream
         */
        public void receiveMsge(int msgId, boolean isLastMsg) {
            if(isLastMsg) {
                hasLastMsg = true;
                // Compute in long: in int arithmetic msgId * (msgId + 1) wraps
                // around for msgId > 46340 and the stream would never complete.
                targetSum = (long) msgId * ((long) msgId + 1L) / 2L;
            }
            totalSum += msgId;
        }

        public int getParallId() {
            return this.streamId;
        }

        /**
         * @return {@code true} once the last message was seen and the received
         *         ids add up to the expected sum
         */
        public boolean isMessageEnd() {
            return hasLastMsg && (totalSum == targetSum);
        }
    }
}