/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.actor.worker;
import akka.actor.ActorRef;
import ml.shifu.shifu.container.CaseScoreResult;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.EvalConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.ModelRunner;
import ml.shifu.shifu.core.model.ModelSpec;
import ml.shifu.shifu.message.RunModelDataMessage;
import ml.shifu.shifu.message.RunModelResultMessage;
import ml.shifu.shifu.util.CommonUtils;
import org.apache.commons.collections.CollectionUtils;
import org.encog.ml.BasicML;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
* RunModelWorker class computes the score for input data
*/
public class RunModelWorker extends AbstractWorkerActor {
private ModelRunner modelRunner;
public RunModelWorker(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, EvalConfig evalConfig,
ActorRef parentActorRef, ActorRef nextActorRef) throws IOException {
super(modelConfig, columnConfigList, parentActorRef, nextActorRef);
List<BasicML> models = CommonUtils.loadBasicModels(modelConfig, evalConfig, SourceType.LOCAL);
String[] header = null;
String delimiter = null;
if( null == evalConfig
|| null == evalConfig.getDataSet().getHeaderPath()
|| null == evalConfig.getDataSet().getHeaderDelimiter()) {
header = CommonUtils.getFinalHeaders(modelConfig);
delimiter = modelConfig.getDataSetDelimiter();
} else {
header = CommonUtils.getFinalHeaders(evalConfig);
delimiter = evalConfig.getDataSet().getDataDelimiter();
}
modelRunner = new ModelRunner(modelConfig, columnConfigList, header, delimiter, models);
boolean gbtConvertToProp = ((evalConfig == null) ? false : evalConfig.getGbtConvertToProb());
SourceType sourceType = ((evalConfig == null) ?
modelConfig.getDataSet().getSource() : evalConfig.getDataSet().getSource());
List<ModelSpec> subModels = CommonUtils.loadSubModels(modelConfig, this.columnConfigList, evalConfig,
sourceType, gbtConvertToProp);
if(CollectionUtils.isNotEmpty(subModels)) {
for(ModelSpec modelSpec: subModels) {
this.modelRunner.addSubModels(modelSpec);
}
}
}
/*
* (non-Javadoc)
*
* @see akka.actor.UntypedActor#onReceive(java.lang.Object)
*/
@Override
public void handleMsg(Object message) {
if(message instanceof RunModelDataMessage) {
RunModelDataMessage msg = (RunModelDataMessage) message;
List<String> evalDataList = msg.getEvalDataList();
List<CaseScoreResult> scoreDataList = new ArrayList<CaseScoreResult>(evalDataList.size());
for(String evalData: evalDataList) {
CaseScoreResult scoreData = calculateModelScore(evalData);
if(scoreData != null) {
scoreData.setInputData(evalData);
scoreDataList.add(scoreData);
}
}
nextActorRef.tell(
new RunModelResultMessage(msg.getStreamId(), msg.getTotalStreamCnt(), msg.getMsgId(), msg
.isLastMsg(), scoreDataList), getSelf());
} else {
unhandled(message);
}
}
/**
* Call model runner to compute result score
*
* @param evalData
* - data to run model
* @return - the score result
*/
private CaseScoreResult calculateModelScore(String evalData) {
return modelRunner.compute(evalData);
}
}