/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.actor.worker; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Random; import ml.shifu.shifu.container.CaseScoreResult; import ml.shifu.shifu.container.ColumnScoreObject; import ml.shifu.shifu.container.ValueObject; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.exception.ShifuErrorCode; import ml.shifu.shifu.exception.ShifuException; import ml.shifu.shifu.message.ColumnScoreMessage; import ml.shifu.shifu.message.RunModelResultMessage; import ml.shifu.shifu.message.StatsPartRawDataMessage; import ml.shifu.shifu.message.StatsValueObjectMessage; import ml.shifu.shifu.util.CommonUtils; import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import akka.actor.ActorRef; /** * DataPrepareWorker class convert data into all kinds of format. * StatsPartRawDataMessage - convert row-based data into column-based training data for calculating stats * RunModelResultMessage - convert model-result from row-based to column-based * NormPartRawDataMessage - filter data for normalization */ public class DataPrepareWorker extends AbstractWorkerActor { private static Logger log = LoggerFactory.getLogger(DataPrepareWorker.class); private Random random = new Random(System.currentTimeMillis()); private Map<Integer, ActorRef> columnNumToActorMap; private String[] trainDataHeader; private int weightedColumnNum = -1; public DataPrepareWorker(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, ActorRef parentActorRef, ActorRef nextActorRef) throws IOException { super(modelConfig, columnConfigList, parentActorRef, nextActorRef); trainDataHeader = CommonUtils.getFinalHeaders(modelConfig); } public DataPrepareWorker(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, ActorRef parentActorRef, Map<Integer, ActorRef> columnNumToActorMap) throws IOException { this(modelConfig, columnConfigList, parentActorRef, (ActorRef) null); this.columnNumToActorMap = columnNumToActorMap; if(!StringUtils.isEmpty(this.modelConfig.getDataSet().getWeightColumnName())) { String weightColumnName = this.modelConfig.getDataSet().getWeightColumnName(); for(int i = 0; i < this.columnConfigList.size(); i++) { ColumnConfig config = this.columnConfigList.get(i); if(config.getColumnName().equals(weightColumnName)) { this.weightedColumnNum = i; break; } } } } /* * (non-Javadoc) * * @see akka.actor.UntypedActor#onReceive(java.lang.Object) */ @Override public void handleMsg(Object message) { if(message instanceof StatsPartRawDataMessage) { StatsPartRawDataMessage partData = (StatsPartRawDataMessage) message; Map<Integer, List<ValueObject>> columnVoListMap = buildColumnVoListMap(partData.getRawDataList().size()); DataPrepareStatsResult rt = convertRawDataIntoValueObject(partData.getRawDataList(), columnVoListMap); int totalMsgCnt = partData.getTotalMsgCnt(); for(Map.Entry<Integer, List<ValueObject>> entry: columnVoListMap.entrySet()) { Integer columnNum = entry.getKey(); log.info("send {} with {} value object", columnNum, entry.getValue().size()); columnNumToActorMap.get(columnNum).tell( new StatsValueObjectMessage(totalMsgCnt, columnNum, entry.getValue(), rt.getMissingMap() .containsKey(columnNum) ? rt.getMissingMap().get(columnNum) : 0, rt.getTotal()), getSelf()); } } else if(message instanceof RunModelResultMessage) { RunModelResultMessage msg = (RunModelResultMessage) message; Map<Integer, List<ColumnScoreObject>> columnScoreListMap = buildColumnScoreListMap(); convertModelResultIntoColScore(msg.getScoreResultList(), columnScoreListMap); int totalMsgCnt = msg.getTotalStreamCnt(); for(Entry<Integer, List<ColumnScoreObject>> column: columnScoreListMap.entrySet()) { columnNumToActorMap.get(column.getKey()).tell( new ColumnScoreMessage(totalMsgCnt, column.getKey(), column.getValue()), getSelf()); } } else { unhandled(message); } } /* * Create the Map<ColumnID, List<ValueObject>> to prepare the data for calculating stats of each column * If the input message doesn't contain any data, the actor won't send message into next-actor who is waiting the * message. * Under this situation, it will cause AKKA to wait infinitely. * * @return initialed map for final candidate columns */ private Map<Integer, List<ValueObject>> buildColumnVoListMap(int capacity) { Map<Integer, List<ValueObject>> columnVoListMap = new HashMap<Integer, List<ValueObject>>(); for(ColumnConfig columnConfig: columnConfigList) { if(columnConfig.isCandidate()) { columnVoListMap.put(columnConfig.getColumnNum(), new ArrayList<ValueObject>(capacity)); } } return columnVoListMap; } /* * Create the Map<ColumnID, List<ColumnScore>> to prepare the data for calculating average score of each column * If the input message doesn't contain any data, the actor won't send message into next-actor who is waiting the * message. * Under this situation, it will cause AKKA to wait infinitely. * * @return initialed map for final select columns */ private Map<Integer, List<ColumnScoreObject>> buildColumnScoreListMap() { Map<Integer, List<ColumnScoreObject>> columnScoreListMap = new HashMap<Integer, List<ColumnScoreObject>>(); for(ColumnConfig columnConfig: columnConfigList) { if(columnConfig.isCandidate() && columnConfig.isFinalSelect()) { columnScoreListMap.put(columnConfig.getColumnNum(), new ArrayList<ColumnScoreObject>()); } } return columnScoreListMap; } /* * Convert raw data into @ValueObject for calculating stats * * @param rawDataList * - raw data for training * @param columnVoListMap * <column-id --> @ValueObject list> * @throws ShifuException * if the data field length is not equal header length */ private DataPrepareStatsResult convertRawDataIntoValueObject(List<String> rawDataList, Map<Integer, List<ValueObject>> columnVoListMap) throws ShifuException { double sampleRate = modelConfig.getBinningSampleRate(); long total = 0l; Map<Integer, Long> missingMap = new HashMap<Integer, Long>(); for(String line: rawDataList) { total++; String[] raw = CommonUtils.split(line, modelConfig.getDataSetDelimiter()); if(raw.length != columnConfigList.size()) { log.error("Expected Columns: " + columnConfigList.size() + ", but got: " + raw.length); throw new ShifuException(ShifuErrorCode.ERROR_NO_EQUAL_COLCONFIG); } String tag = CommonUtils.trimTag(raw[targetColumnNum]); if(modelConfig.isBinningSampleNegOnly()) { if(modelConfig.getNegTags().contains(tag) && random.nextDouble() > sampleRate) { continue; } } else { if(random.nextDouble() > sampleRate) { continue; } } for(int i = 0; i < raw.length; i++) { if(!columnNumToActorMap.containsKey(i)) { // ignore non-used columns continue; } ValueObject vo = new ValueObject(); if(i >= columnConfigList.size()) { log.error("The input size is longer than expected, need to check your data"); continue; } ColumnConfig config = columnConfigList.get(i); if(config.isNumerical()) { // NUMERICAL try { vo.setValue(Double.valueOf(raw[i].trim())); vo.setRaw(null); } catch (Exception e) { log.debug("Column " + config.getColumnNum() + ": " + config.getColumnName() + " is expected to be NUMERICAL, however received: " + raw[i]); incMap(i, missingMap); continue; } } else if(config.isCategorical()) { // CATEGORICAL if(raw[i] == null || StringUtils.isEmpty(raw[i]) || modelConfig.getDataSet().getMissingOrInvalidValues() .contains(raw[i].toLowerCase().trim())) { incMap(i, missingMap); } vo.setRaw(raw[i].trim()); vo.setValue(null); } else { // AUTO TYPE try { vo.setValue(Double.valueOf(raw[i])); vo.setRaw(null); } catch (Exception e) { incMap(i, missingMap); vo.setRaw(raw[i]); vo.setValue(null); } } if(this.weightedColumnNum != -1) { try { vo.setWeight(Double.valueOf(raw[weightedColumnNum])); } catch (NumberFormatException e) { vo.setWeight(1.0); } vo.setWeight(1.0); } vo.setTag(tag); List<ValueObject> voList = columnVoListMap.get(i); if(voList == null) { voList = new ArrayList<ValueObject>(); columnVoListMap.put(i, voList); } voList.add(vo); } } DataPrepareStatsResult rt = new DataPrepareStatsResult(total, missingMap); return rt; } private void incMap(int index, Map<Integer, Long> mapping) { Long count = mapping.get(index); if(count == null) { mapping.put(index, Long.valueOf(1)); } else { mapping.put(index, count + 1); } } public static class DataPrepareStatsResult { public DataPrepareStatsResult(long total, Map<Integer, Long> missingMap) { this.total = total; this.missingMap = missingMap; } private long total; private Map<Integer, Long> missingMap; public long getTotal() { return total; } public void setTotal(long total) { this.total = total; } public Map<Integer, Long> getMissingMap() { return missingMap; } public void setMissingMap(Map<Integer, Long> missingMap) { this.missingMap = missingMap; } } /* * Convert model result data into column-based * * @param evalDataList * evaluation result list * @param columnScoreListMap * (column-id, List<ColumnScoreObject>) */ private void convertModelResultIntoColScore(List<CaseScoreResult> scoreResultList, Map<Integer, List<ColumnScoreObject>> columnScoreListMap) { for(CaseScoreResult scoreResult: scoreResultList) { Map<String, String> rawDataMap = CommonUtils.convertDataIntoMap(scoreResult.getInputData(), super.modelConfig.getDataSetDelimiter(), this.trainDataHeader); for(ColumnConfig config: columnConfigList) { if(config.isFinalSelect()) { ColumnScoreObject columnScore = new ColumnScoreObject(config.getColumnNum(), rawDataMap.get(config .getColumnName())); columnScore.setScores(scoreResult.getScores()); columnScore.setMaxScore(scoreResult.getMaxScore()); columnScore.setMinScore(scoreResult.getMinScore()); columnScore.setAvgScore(scoreResult.getAvgScore()); List<ColumnScoreObject> csList = columnScoreListMap.get(config.getColumnNum()); if(csList == null) { csList = new ArrayList<ColumnScoreObject>(); columnScoreListMap.put(config.getColumnNum(), csList); } csList.add(columnScore); } } } } }