/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.guagua;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import ml.shifu.guagua.hadoop.io.GuaguaInputSplit;
import ml.shifu.guagua.mapreduce.GuaguaInputFormat;
import ml.shifu.shifu.core.dtrain.CommonConstants;
import org.apache.commons.lang.ArrayUtils;
import org.apache.hadoop.fs.BlockLocation;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.input.InvalidInputException;
import org.apache.hadoop.mapreduce.security.TokenCache;
import org.apache.hadoop.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * {@link GuaguaInputFormat} that can append a cross-validation data set to each worker split.
 *
 * <p>
 * Training splits (including the master split) come from {@link GuaguaInputFormat#getSplits(JobContext)}. When the
 * configuration key {@link CommonConstants#CROSS_VALIDATION_DIR} is non-blank, the files under those comma-separated
 * paths are cut into at most one group per non-master training split. Each group holds roughly the same number of
 * bytes, and its {@link FileSplit}s are appended to the corresponding {@link GuaguaInputSplit}. A parallel
 * {@code Boolean[]} extension marks each file split as training ({@code false}) or validation ({@code true}).
 */
public class ShifuInputFormat extends GuaguaInputFormat {

    private static final Logger LOG = LoggerFactory.getLogger(ShifuInputFormat.class);

    /**
     * Rejects paths whose file name starts with "_" or ".".
     */
    private static final PathFilter HIDDEN_FILE_FILTER = new PathFilter() {
        @Override
        public boolean accept(Path p) {
            String name = p.getName();
            return !name.startsWith("_") && !name.startsWith(".");
        }
    };

    /**
     * Accepts a path only when every wrapped filter accepts it.
     */
    private static class MultiPathFilter implements PathFilter {

        private final List<PathFilter> filters;

        public MultiPathFilter(List<PathFilter> filters) {
            this.filters = filters;
        }

        @Override
        public boolean accept(Path path) {
            for(PathFilter filter: filters) {
                if(!filter.accept(path)) {
                    return false;
                }
            }
            return true;
        }
    }

    /**
     * Builds the training splits through the parent input format, which also handles master-split setup and input
     * combining. If a cross-validation directory is configured, validation file splits are then appended to every
     * non-master split.
     *
     * @param job
     *            the job context
     * @return the (possibly augmented) list of splits returned by the parent input format
     * @throws IOException
     *             if split computation or validation-file listing fails
     */
    @Override
    public List<InputSplit> getSplits(JobContext job) throws IOException {
        List<InputSplit> newSplits = super.getSplits(job);
        // Use the same key getInputPaths() reads, so the gate and the path list can never disagree.
        String testDirs = job.getConfiguration().get(CommonConstants.CROSS_VALIDATION_DIR, "");
        LOG.info("Validation dir is {};", testDirs);
        if(org.apache.commons.lang.StringUtils.isNotBlank(testDirs)) {
            this.addCrossValidationDataset(newSplits, job);
        }
        return newSplits;
    }

    /**
     * Creates a {@link FileSplit} covering {@code [offset, offset + length)} of {@code file}.
     *
     * <p>
     * Locality hints are the hosts of the block locations that {@code fs} reports for that range. Duplicate host
     * names are dropped, and first-seen order is kept.
     */
    private FileSplit getFileSplit(FileSystem fs, FileStatus file, long offset, long length) throws IOException {
        BlockLocation[] blkLocations = fs.getFileBlockLocations(file, offset, length);
        Set<String> hosts = new LinkedHashSet<String>();
        if(blkLocations != null) {
            for(BlockLocation location: blkLocations) {
                hosts.addAll(Arrays.asList(location.getHosts()));
            }
        }
        return new FileSplit(file.getPath(), offset, length, hosts.toArray(new String[hosts.size()]));
    }

    /**
     * Cuts the cross-validation files into at most {@code count} groups of roughly equal byte size.
     *
     * <p>
     * Files are consumed in listing order. Each group is filled up to {@code totalLength / count + 1} bytes, and a
     * file that crosses a group boundary is cut into several {@link FileSplit}s. Pig/Hadoop meta files are skipped.
     *
     * @param job
     *            the job context; its configuration supplies the validation paths and file systems
     * @param count
     *            the desired number of groups (the caller passes the number of non-master training splits)
     * @return the groups in file order; empty when {@code count <= 0} or no data files were found
     * @throws IOException
     *             if the validation paths cannot be listed or block locations cannot be read
     */
    protected List<List<FileSplit>> getCrossValidationSplits(JobContext job, int count) throws IOException {
        LOG.debug("Split validation with count: {}", count);
        List<List<FileSplit>> validationList = new ArrayList<List<FileSplit>>();

        // Filter meta files once and total the remaining bytes.
        List<FileStatus> dataFiles = new ArrayList<FileStatus>();
        long lengthSum = 0L;
        for(FileStatus file: listCrossValidationStatus(job)) {
            if(isPigOrHadoopMetaFile(file.getPath())) {
                continue;
            }
            dataFiles.add(file);
            lengthSum += file.getLen();
        }

        if(count <= 0) {
            // No split to attach groups to; also avoids a division by zero below. Listing above still ran so that
            // missing validation paths are reported.
            LOG.debug("Total # of validationList: {}", validationList.size());
            return validationList;
        }

        // The +1 guarantees count * size > lengthSum, so at most 'count' groups are produced.
        long size = lengthSum / count + 1;
        // Number of bytes already placed in the currently open group.
        long remaining = 0L;
        List<FileSplit> current = new ArrayList<FileSplit>();
        for(FileStatus file: dataFiles) {
            FileSystem fs = file.getPath().getFileSystem(job.getConfiguration());
            long length = file.getLen();
            if(length + remaining >= size) {
                // Top up the open group with the head of this file and close it.
                long cut = Math.min(size - remaining, length);
                current.add(getFileSplit(fs, file, 0L, cut));
                validationList.add(current);
                current = new ArrayList<FileSplit>();

                long offset = cut;
                remaining = length - cut;
                // Emit full-size groups from the rest of this file.
                while(remaining >= size) {
                    current.add(getFileSplit(fs, file, offset, size));
                    validationList.add(current);
                    current = new ArrayList<FileSplit>();
                    remaining -= size;
                    offset += size;
                }
                // The tail of this file starts the next open group.
                if(remaining > 0) {
                    current.add(getFileSplit(fs, file, offset, remaining));
                }
            } else {
                // The whole file fits into the open group.
                current.add(getFileSplit(fs, file, 0L, length));
                remaining += length;
            }
        }
        if(!current.isEmpty()) {
            validationList.add(current);
        }
        LOG.debug("Total # of validationList: {}", validationList.size());
        return validationList;
    }

    /**
     * Appends the cross-validation groups, in order, to the non-master training splits.
     *
     * <p>
     * Each augmented split gets a {@code Boolean[]} extension parallel to its file splits: {@code false} for the
     * original training splits and {@code true} for the appended validation splits. If there are fewer groups than
     * non-master splits, the trailing splits are left untouched.
     *
     * @param trainingSplit
     *            splits produced by the parent input format; every element is cast to {@link GuaguaInputSplit}
     * @param context
     *            the job context
     * @throws IOException
     *             if the validation data cannot be listed or split
     */
    protected void addCrossValidationDataset(List<InputSplit> trainingSplit, JobContext context) throws IOException {
        List<GuaguaInputSplit> trainingNoMaster = new ArrayList<GuaguaInputSplit>();
        for(InputSplit split: trainingSplit) {
            GuaguaInputSplit guaguaInput = (GuaguaInputSplit) split;
            if(!guaguaInput.isMaster()) {
                trainingNoMaster.add(guaguaInput);
            }
        }

        List<List<FileSplit>> csSplits = this.getCrossValidationSplits(context, trainingNoMaster.size());
        // getCrossValidationSplits returns at most trainingNoMaster.size() groups, so get(i) is in range.
        for(int i = 0; i < csSplits.size(); i++) {
            List<FileSplit> oneInput = csSplits.get(i);
            GuaguaInputSplit guaguaInput = trainingNoMaster.get(i);
            int trainingSize = guaguaInput.getFileSplits().length;
            FileSplit[] finalSplits = (FileSplit[]) ArrayUtils.addAll(guaguaInput.getFileSplits(),
                    oneInput.toArray(new FileSplit[0]));
            guaguaInput.setFileSplits(finalSplits);

            Boolean[] validationFlags = new Boolean[finalSplits.length];
            for(int j = 0; j < finalSplits.length; j++) {
                // Training splits come first; everything after them is validation data.
                validationFlags[j] = j >= trainingSize;
            }
            guaguaInput.setExtensions(validationFlags);
        }
        LOG.info("Training input split size is: {}.", trainingSplit.size());
        LOG.info("Validation input split size is {}.", csSplits.size());
    }

    /**
     * Lists the files under the configured cross-validation paths.
     *
     * <p>
     * Hidden files and directories (names starting with "_" or ".") are skipped. The job's input path filter, if any,
     * is also applied. Each path is globbed. Directories matched by the glob are expanded one level, or fully when
     * {@code mapreduce.input.fileinputformat.input.dir.recursive} is {@code true}.
     *
     * @param job
     *            the job context
     * @return the matching file statuses
     * @throws IOException
     *             if no paths are configured, or {@link InvalidInputException} if any path is missing or matches
     *             nothing
     */
    @SuppressWarnings("deprecation")
    protected List<FileStatus> listCrossValidationStatus(JobContext job) throws IOException {
        List<FileStatus> result = new ArrayList<FileStatus>();
        Path[] dirs = getInputPaths(job);
        if(dirs.length == 0) {
            throw new IOException("No input paths specified in job");
        }

        // Get tokens for all the required FileSystems.
        TokenCache.obtainTokensForNamenodes(job.getCredentials(), dirs, job.getConfiguration());

        // Whether to descend recursively into sub-directories.
        boolean recursive = job.getConfiguration().getBoolean("mapreduce.input.fileinputformat.input.dir.recursive",
                false);

        List<IOException> errors = new ArrayList<IOException>();

        // Combine the hidden-file filter with the user-provided job filter, if any.
        List<PathFilter> filters = new ArrayList<PathFilter>();
        filters.add(HIDDEN_FILE_FILTER);
        PathFilter jobFilter = getInputPathFilter(job);
        if(jobFilter != null) {
            filters.add(jobFilter);
        }
        PathFilter inputFilter = new MultiPathFilter(filters);

        for(Path p: dirs) {
            FileSystem fs = p.getFileSystem(job.getConfiguration());
            FileStatus[] matches = fs.globStatus(p, inputFilter);
            if(matches == null) {
                errors.add(new IOException("Input path does not exist: " + p));
            } else if(matches.length == 0) {
                errors.add(new IOException("Input Pattern " + p + " matches 0 files"));
            } else {
                for(FileStatus globStat: matches) {
                    if(globStat.isDir()) {
                        for(FileStatus fileStatus: fs.listStatus(globStat.getPath())) {
                            if(inputFilter.accept(fileStatus.getPath())) {
                                if(recursive && fileStatus.isDir()) {
                                    addInputPathRecursive(result, fs, fileStatus.getPath(), inputFilter);
                                } else {
                                    result.add(fileStatus);
                                }
                            }
                        }
                    } else {
                        result.add(globStat);
                    }
                }
            }
        }

        if(!errors.isEmpty()) {
            throw new InvalidInputException(errors);
        }
        LOG.info("Total validation paths to process : {}", result.size());
        return result;
    }

    /**
     * Recursively adds every file under {@code path} that {@code inputFilter} accepts to {@code result}.
     *
     * <p>
     * A rejected directory is not descended into.
     */
    @SuppressWarnings("deprecation")
    private void addInputPathRecursive(List<FileStatus> result, FileSystem fs, Path path, PathFilter inputFilter)
            throws IOException {
        for(FileStatus fileStatus: fs.listStatus(path)) {
            if(inputFilter.accept(fileStatus.getPath())) {
                if(fileStatus.isDir()) {
                    addInputPathRecursive(result, fs, fileStatus.getPath(), inputFilter);
                } else {
                    result.add(fileStatus);
                }
            }
        }
    }

    /**
     * Parses the comma-separated cross-validation paths from {@link CommonConstants#CROSS_VALIDATION_DIR}.
     *
     * <p>
     * Splitting and unescaping use Hadoop {@link StringUtils}. NOTE(review): this static method presumably hides the
     * parent's training-input {@code getInputPaths(JobContext)}; calls from this class resolve here.
     *
     * @param context
     *            the job context
     * @return the configured validation paths; empty when the key is unset or blank
     */
    public static Path[] getInputPaths(JobContext context) {
        String dirs = context.getConfiguration().get(CommonConstants.CROSS_VALIDATION_DIR, "");
        LOG.info("crossValidation_dir:{}", dirs);
        String[] list = StringUtils.split(dirs);
        Path[] result = new Path[list.length];
        for(int i = 0; i < list.length; i++) {
            result[i] = new Path(StringUtils.unEscapeString(list[i]));
        }
        return result;
    }
}