/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.actor.worker; import java.util.ArrayList; import java.util.LinkedList; import java.util.List; import java.util.Scanner; import ml.shifu.guagua.util.NumberFormatUtils; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.core.dtrain.CommonConstants; import ml.shifu.shifu.core.dtrain.DTrainUtils; import ml.shifu.shifu.message.NormPartRawDataMessage; import ml.shifu.shifu.message.RunModelDataMessage; import ml.shifu.shifu.message.ScanEvalDataMessage; import ml.shifu.shifu.message.ScanNormInputDataMessage; import ml.shifu.shifu.message.ScanStatsRawDataMessage; import ml.shifu.shifu.message.ScanTrainDataMessage; import ml.shifu.shifu.message.StatsPartRawDataMessage; import ml.shifu.shifu.message.TrainPartDataMessage; import ml.shifu.shifu.util.CommonUtils; import ml.shifu.shifu.util.Environment; import org.encog.ml.data.MLDataPair; import org.encog.ml.data.basic.BasicMLData; import org.encog.ml.data.basic.BasicMLDataPair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import akka.actor.ActorRef; import com.google.common.base.Splitter; /** * DataLoadWorker class is used to load data from all kinds of source. * Its input is data scanner. The output are usually List. */ public class DataLoadWorker extends AbstractWorkerActor { private static Logger log = LoggerFactory.getLogger(DataLoadWorker.class); /** * Default splitter used to split input record. Use one instance to prevent more news in Splitter.on. */ private static final Splitter DEFAULT_SPLITTER = Splitter.on(CommonConstants.DEFAULT_COLUMN_SEPARATOR); /** * Basic input node count for NN model */ private int inputNodeCount; /** * {@link #candidateCount} is used to check if no variable is selected. If {@link #inputNodeCount} equals * {@link #candidateCount}, which means no column is selected or all columns are selected. */ private int candidateCount; public DataLoadWorker(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, ActorRef parentActorRef, ActorRef nextActorRef) { super(modelConfig, columnConfigList, parentActorRef, nextActorRef); int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(this.columnConfigList); this.inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0]; this.candidateCount = inputOutputIndex[2]; } /* * (non-Javadoc) * * @see akka.actor.UntypedActor#onReceive(java.lang.Object) */ @Override public void handleMsg(Object message) { if(message instanceof ScanStatsRawDataMessage) { log.info("DataLoaderActor Starting ..."); ScanStatsRawDataMessage msg = (ScanStatsRawDataMessage) message; Scanner scanner = msg.getScanner(); int totalMsgCnt = msg.getTotalMsgCnt(); List<String> rawDataList = readDataIntoList(scanner); log.info("DataLoaderActor Finished: Loaded " + rawDataList.size() + " Records."); nextActorRef.tell(new StatsPartRawDataMessage(totalMsgCnt, rawDataList), getSelf()); } else if(message instanceof ScanNormInputDataMessage) { log.info("DataLoaderActor Starting ..."); ScanNormInputDataMessage msg = (ScanNormInputDataMessage) message; Scanner scanner = msg.getScanner(); int totalMsgCnt = msg.getTotalMsgCnt(); List<String> rawDataList = readDataIntoList(scanner); log.info("DataLoaderActor Finished: Loaded " + rawDataList.size() + " Records."); nextActorRef.tell(new NormPartRawDataMessage(totalMsgCnt, rawDataList), getSelf()); } else if(message instanceof ScanTrainDataMessage) { ScanTrainDataMessage msg = (ScanTrainDataMessage) message; Scanner scanner = msg.getScanner(); int totalMsgCnt = msg.getTotalMsgCnt(); List<MLDataPair> mlDataPairList = readTrainingData(scanner, msg.isDryRun()); log.info("DataLoaderActor Finished: Loaded " + mlDataPairList.size() + " Records for Training."); nextActorRef.tell(new TrainPartDataMessage(totalMsgCnt, msg.isDryRun(), mlDataPairList), getSelf()); } else if(message instanceof ScanEvalDataMessage) { log.info("DataLoaderActor Starting ..."); ScanEvalDataMessage msg = (ScanEvalDataMessage) message; Scanner scanner = msg.getScanner(); int streamId = msg.getStreamId(); int totalStreamCnt = msg.getTotalStreamCnt(); splitDataIntoMultiMessages(streamId, totalStreamCnt, scanner, Environment.getInt(Environment.RECORD_CNT_PER_MESSAGE, 100000)); /* * List<String> evalDataList = readDataIntoList(scanner); * * log.info("DataLoaderActor Finished: Loaded " + evalDataList.size() + " Records."); * nextActorRef.tell( new RunModelDataMessage(totalMsgCnt, evalDataList), getSelf()); */ } else { unhandled(message); } } private long splitDataIntoMultiMessages(int streamId, int totalStreamCnt, Scanner scanner, int recordCntPerMsg) { long recordCnt = 0; int msgId = 0; List<String> rawDataList = new LinkedList<String>(); while(scanner.hasNextLine()) { String raw = scanner.nextLine(); recordCnt++; rawDataList.add(raw); if(recordCnt % recordCntPerMsg == 0) { log.info("Read " + recordCnt + " Records."); nextActorRef.tell(new RunModelDataMessage(streamId, totalStreamCnt, (msgId++), false, rawDataList), getSelf()); rawDataList = new LinkedList<String>(); } } log.info("Totally read " + recordCnt + " Records."); // anyhow, sent the last message to let next actor know - it's done nextActorRef.tell(new RunModelDataMessage(streamId, totalStreamCnt, (msgId++), true, rawDataList), getSelf()); return recordCnt; } /** * Read data into String list * * @param scanner * - input partition * @return list of data */ public List<String> readDataIntoList(Scanner scanner) { List<String> rawDataList = new LinkedList<String>(); int cntTotal = 0; while(scanner.hasNextLine()) { String raw = scanner.nextLine(); rawDataList.add(raw); cntTotal++; if(cntTotal % 100000 == 0) { log.info("Read " + cntTotal + " records."); } } log.info("Totally read " + cntTotal + " records."); return rawDataList; } /** * Read the normalized training data for model training * * @param scanner * - input partition * @param isDryRun * - is for test running? * @return List of data */ public List<MLDataPair> readTrainingData(Scanner scanner, boolean isDryRun) { List<MLDataPair> mlDataPairList = new ArrayList<MLDataPair>(); int numSelected = 0; for(ColumnConfig config: columnConfigList) { if(config.isFinalSelect()) { numSelected++; } } int cnt = 0; while(scanner.hasNextLine()) { if((cnt++) % 100000 == 0) { log.info("Read " + (cnt) + " Records."); } String line = scanner.nextLine(); if(isDryRun) { MLDataPair dummyPair = new BasicMLDataPair(new BasicMLData(new double[1]), new BasicMLData( new double[1])); mlDataPairList.add(dummyPair); continue; } // the normalized training data is separated by | by default double[] inputs = new double[numSelected]; double[] ideal = new double[1]; double significance = 0.0d; int index = 0, inputsIndex = 0, outputIndex = 0; for(String input: DEFAULT_SPLITTER.split(line.trim())) { double doubleValue = NumberFormatUtils.getDouble(input.trim(), 0.0d); if(index == this.columnConfigList.size()) { significance = NumberFormatUtils .getDouble(input.trim(), CommonConstants.DEFAULT_SIGNIFICANCE_VALUE); break; } else { ColumnConfig columnConfig = this.columnConfigList.get(index); if(columnConfig != null && columnConfig.isTarget()) { ideal[outputIndex++] = doubleValue; } else { if(this.inputNodeCount == this.candidateCount) { // all variables are not set final-select if(CommonUtils.isGoodCandidate(columnConfig)) { inputs[inputsIndex++] = doubleValue; } } else { // final select some variables if(columnConfig != null && !columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) { inputs[inputsIndex++] = doubleValue; } } } } index++; } MLDataPair pair = new BasicMLDataPair(new BasicMLData(inputs), new BasicMLData(ideal)); pair.setSignificance(significance); mlDataPairList.add(pair); } return mlDataPairList; } }