/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dvarsel; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; import java.util.List; import java.util.Properties; import ml.shifu.guagua.hadoop.io.GuaguaLineRecordReader; import ml.shifu.guagua.hadoop.io.GuaguaWritableAdapter; import ml.shifu.guagua.io.GuaguaFileSplit; import ml.shifu.guagua.worker.AbstractWorkerComputable; import ml.shifu.guagua.worker.WorkerContext; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.RawSourceData; import ml.shifu.shifu.core.DataPurifier; import ml.shifu.shifu.core.Normalizer; import ml.shifu.shifu.core.dtrain.CommonConstants; import ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet; import ml.shifu.shifu.core.dvarsel.dataset.TrainingRecord; import ml.shifu.shifu.util.CommonUtils; import ml.shifu.shifu.util.Constants; import org.apache.commons.lang.StringUtils; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Created on 11/24/2014. */ public class VarSelWorker extends AbstractWorkerComputable<VarSelMasterResult, VarSelWorkerResult, GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<Text>> { private static final Logger LOG = LoggerFactory.getLogger(VarSelWorker.class); private ModelConfig modelConfig; private List<ColumnConfig> columnConfigList; private AbstractWorkerConductor workerConductor; private long count = 0; private int inputNodeCount; private int outputNodeCount; private DataPurifier dataPurifier; private int targetColumnId = -1; private int weightColumnId = -1; private TrainingDataSet trainingDataSet; private long posRecordCount = 0; private long totalRecordCount = 0; @Override public void initRecordReader(GuaguaFileSplit fileSplit) throws IOException { this.setRecordReader(new GuaguaLineRecordReader()); this.getRecordReader().initialize(fileSplit); } @Override public void init(WorkerContext<VarSelMasterResult, VarSelWorkerResult> workerContext) { Properties props = workerContext.getProps(); try { RawSourceData.SourceType sourceType = RawSourceData.SourceType.valueOf(props.getProperty( CommonConstants.MODELSET_SOURCE_TYPE, RawSourceData.SourceType.HDFS.toString())); this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType); this.columnConfigList = CommonUtils.loadColumnConfigList( props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType); String conductorClsName = props.getProperty(Constants.VAR_SEL_WORKER_CONDUCTOR); this.workerConductor = (AbstractWorkerConductor) Class.forName(conductorClsName) .getDeclaredConstructor(ModelConfig.class, List.class) .newInstance(this.modelConfig, this.columnConfigList); } catch (IOException e) { throw new RuntimeException("Fail to load ModelConfig or List<ColumnConfig>", e); } catch (ClassNotFoundException e) { throw new RuntimeException("Invalid Master Conductor class", e); } catch (InstantiationException e) { throw new RuntimeException("Fail to create instance", e); } catch (IllegalAccessException e) { throw new RuntimeException("Illegal access when creating instance", e); } catch (NoSuchMethodException e) { throw new RuntimeException("Fail to call method when creating instance", e); } catch (InvocationTargetException e) { throw new RuntimeException("Fail to invoke when creating instance", e); } List<Integer> normalizedColumnIdList = this.getNormalizedColumnIdList(); this.inputNodeCount = normalizedColumnIdList.size(); this.outputNodeCount = this.getTargetColumnCount(); trainingDataSet = new TrainingDataSet(normalizedColumnIdList); try { dataPurifier = new DataPurifier(modelConfig); } catch (IOException e) { throw new RuntimeException("Fail to create DataPurifier", e); } this.targetColumnId = CommonUtils.getTargetColumnNum(this.columnConfigList); if(StringUtils.isNotBlank(modelConfig.getWeightColumnName())) { for(ColumnConfig columnConfig: columnConfigList) { if(columnConfig.getColumnName().equalsIgnoreCase(modelConfig.getWeightColumnName().trim())) { this.weightColumnId = columnConfig.getColumnNum(); break; } } } } @Override public VarSelWorkerResult doCompute(WorkerContext<VarSelMasterResult, VarSelWorkerResult> workerContext) { if(!workerConductor.isInitialized()) { LOG.info("There are {} records in current worker, with {} positive records.", totalRecordCount, posRecordCount); workerConductor.retainData(trainingDataSet); } VarSelMasterResult masterResult = workerContext.getLastMasterResult(); if(masterResult == null) { // no working set, wait master to send the working set return workerConductor.getDefaultWorkerResult(); } if(masterResult.isHalt()) { // finish variable selection, stop working return null; } LOG.info("Get result from master, the base seed count is - {}", masterResult.getSeedList().size()); workerConductor.consumeMasterResult(masterResult); return workerConductor.generateVarSelResult(); } @Override public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<VarSelMasterResult, VarSelWorkerResult> workerContext) { if((++this.count) % 100000 == 0) { LOG.info("Read {} records.", this.count); } String record = currentValue.getWritable().toString(); String[] fields = CommonUtils.split(record, this.modelConfig.getDataSetDelimiter()); String tag = CommonUtils.trimTag(fields[this.targetColumnId]); if(this.dataPurifier.isFilterOut(record) && isPosOrNegTag(this.modelConfig, tag)) { this.totalRecordCount++; if(this.modelConfig.getPosTags().contains(tag)) { this.posRecordCount++; } double[] inputs = new double[this.inputNodeCount]; double[] ideal = new double[this.outputNodeCount]; double significance = CommonConstants.DEFAULT_SIGNIFICANCE_VALUE; if(this.weightColumnId >= 0) { try { significance = Double.parseDouble(fields[this.weightColumnId]); } catch (Exception e) { // user may set wrong field, just used default. } } ideal[0] = (this.modelConfig.getPosTags().contains(tag) ? 1.0d : 0.0d); int i = 0; for(Integer columnId: this.trainingDataSet.getDataColumnIdList()) { inputs[i++] = Normalizer.normalize(columnConfigList.get(columnId), fields[columnId]); } trainingDataSet.addTrainingRecord(new TrainingRecord(inputs, ideal, significance)); } } private boolean isPosOrNegTag(ModelConfig config, String tag) { return config.getPosTags().contains(tag) || config.getNegTags().contains(tag); } private List<Integer> getNormalizedColumnIdList() { List<Integer> normalizedColumnIdList = new ArrayList<Integer>(); for(ColumnConfig config: columnConfigList) { if(CommonUtils.isGoodCandidate(config)) { normalizedColumnIdList.add(config.getColumnNum()); } } return normalizedColumnIdList; } private int getTargetColumnCount() { int targetCount = 0; for(ColumnConfig config: columnConfigList) { if(config.isTarget()) { targetCount++; } } return targetCount; } }