/* * Copyright [2013-2016] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.init; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; import ml.shifu.shifu.util.updater.ColumnConfigUpdater; import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import ml.shifu.common.Step; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.core.validator.ModelInspector.ModelStep; import ml.shifu.shifu.util.CommonUtils; /** * Init Step in Shifu to call ColumnConfig initialization. * * @author Zhang David (pengzhang@paypal.com) */ public class InitStep extends Step<List<ColumnConfig>> { private final static Logger LOG = LoggerFactory.getLogger(InitStep.class); public InitStep(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, Map<String, Object> otherConfigs) { super(ModelStep.INIT, modelConfig, columnConfigList, otherConfigs); } /* * (non-Javadoc) * * @see ml.shifu.common.Step#process() */ @Override public List<ColumnConfig> process() throws IOException { // 1. init from header files initColumnConfigList(); return this.columnConfigList; } private int initColumnConfigList() throws IOException { String[] fields = null; boolean isSchemaProvided = true; if(StringUtils.isNotBlank(modelConfig.getHeaderPath())) { fields = CommonUtils.getHeaders(modelConfig.getHeaderPath(), modelConfig.getHeaderDelimiter(), modelConfig .getDataSet().getSource()); String[] dataInFirstLine = CommonUtils.takeFirstLine(modelConfig.getDataSetRawPath(), modelConfig.getDataSetDelimiter(), modelConfig.getDataSet().getSource()); if(fields.length != dataInFirstLine.length) { throw new IllegalArgumentException( "Header length and data length are not consistent, please check you header setting and data set setting."); } } else { fields = CommonUtils.takeFirstLine(modelConfig.getDataSetRawPath(), StringUtils.isBlank(modelConfig .getHeaderDelimiter()) ? modelConfig.getDataSetDelimiter() : modelConfig.getHeaderDelimiter(), modelConfig.getDataSet().getSource()); if(StringUtils.join(fields, "").contains(modelConfig.getTargetColumnName())) { // if first line contains target column name, we guess it is csv format and first line is header. isSchemaProvided = true; // first line of data meaning second line in data files excluding first header line String[] dataInFirstLine = CommonUtils.takeFirstTwoLines(modelConfig.getDataSetRawPath(), StringUtils.isBlank(modelConfig.getHeaderDelimiter()) ? modelConfig.getDataSetDelimiter() : modelConfig.getHeaderDelimiter(), modelConfig.getDataSet().getSource())[1]; if(dataInFirstLine != null && fields.length != dataInFirstLine.length) { throw new IllegalArgumentException( "Header length and data length are not consistent, please check you header setting and data set setting."); } LOG.warn("No header path is provided, we will try to read first line and detect schema."); LOG.warn("Schema in ColumnConfig.json are named as first line of data set path."); } else { isSchemaProvided = false; LOG.warn("No header path is provided, we will try to read first line and detect schema."); LOG.warn("Schema in ColumnConfig.json are named as index 0, 1, 2, 3 ..."); LOG.warn("Please make sure weight column and tag column are also taking index as name."); } } columnConfigList = new ArrayList<ColumnConfig>(); for(int i = 0; i < fields.length; i++) { ColumnConfig config = new ColumnConfig(); config.setColumnNum(i); if(isSchemaProvided) { config.setColumnName(CommonUtils.getRelativePigHeaderColumnName(fields[i])); } else { config.setColumnName(i + ""); } columnConfigList.add(config); } ColumnConfigUpdater.updateColumnConfigFlags(modelConfig, columnConfigList, ModelStep.INIT); boolean hasTarget = false; for(ColumnConfig config: columnConfigList) { if(config.isTarget()) { hasTarget = true; } } if(!hasTarget) { LOG.error("Target is not valid: " + modelConfig.getTargetColumnName()); LOG.error("Please check your header file {} and your header delimiter {}", modelConfig.getHeaderPath(), modelConfig.getHeaderDelimiter()); return 1; } return 0; } }