/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;
import java.util.List;
import java.util.Random;
import java.util.Set;
import ml.shifu.shifu.util.CommonUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * DataSampler class
 * - Output: column 0 is tag, following final select vars
 *
 * <p>
 * Static helpers that decide, record by record, whether a row is kept or filtered out. A row is filtered out
 * when its target tag is missing, when the tag is not one of the accepted tags, or when a random draw rejects
 * it according to the sample rate.
 *
 * <p>
 * All methods share a single {@link Random} instance seeded from the wall clock at class-load time, so results
 * are not reproducible across runs.
 */
public class DataSampler {

    private static final Logger log = LoggerFactory.getLogger(DataSampler.class);

    /** Shared random source for every sampling decision made by this class. */
    private static final Random RANDOM = new Random(System.currentTimeMillis());

    /**
     * Check whether the data should be filtered out or not. The data will be filtered out if
     * - the target value is missing (null)
     * - or target tag is not in positive tag list or negative tag list
     * - or not be sampled
     *
     * @param targetColumnNum
     *            index of the target column inside {@code data}; must not be null
     * @param posTags
     *            accepted positive tags
     * @param negTags
     *            accepted negative tags
     * @param data
     *            the record; the element at {@code targetColumnNum} is converted with {@code toString()} and
     *            trimmed through {@code CommonUtils.trimTag} before being compared with the tag lists
     * @param sampleRate
     *            a record is kept when a uniform draw from [0, 1) is not greater than this value; must not be
     *            null
     * @param sampleNegOnly
     *            when true only negative-tagged records are subject to the random draw; must not be null
     * @return null - if the data should be filtered out
     *         data itself - if the data should not be filtered out
     */
    public static List<Object> filter(Integer targetColumnNum, List<String> posTags, List<String> negTags,
            List<Object> data, Double sampleRate, Boolean sampleNegOnly) {
        // A null target cell used to blow up on toString(); treat it as a missing tag so that the record is
        // filtered out, as the contract above promises.
        String tag = normalizeTag(data.get(targetColumnNum));
        if(isNotSampled(posTags, negTags, sampleRate, sampleNegOnly, tag)) {
            return null;
        }
        return data;
    }

    /**
     * Check whether the fields should be filtered out or not. The data will be filtered out if
     * - the target value is missing (null)
     * - or target tag is not in positive tag list or negative tag list
     * - or not be sampled
     *
     * @param targetColumnNum
     *            index of the target column inside {@code fields}
     * @param posTags
     *            accepted positive tags
     * @param negTags
     *            accepted negative tags
     * @param fields
     *            the record as raw string fields
     * @param sampleRate
     *            a record is kept when a uniform draw from [0, 1) is not greater than this value
     * @param sampleNegOnly
     *            when true only negative-tagged records are subject to the random draw
     * @return true - if the data should be filtered out
     *         false - if the data should not be filtered out
     */
    public static boolean filter(int targetColumnNum, List<String> posTags, List<String> negTags, String[] fields,
            double sampleRate, boolean sampleNegOnly) {
        String tag = normalizeTag(fields[targetColumnNum]);
        return isNotSampled(posTags, negTags, sampleRate, sampleNegOnly, tag);
    }

    /**
     * To decide whether the data should be filtered out or not. Both unselected data or invalid tag will be
     * filtered out.
     *
     * @param posTags
     *            accepted positive tags
     * @param negTags
     *            accepted negative tags
     * @param sampleRate
     *            a record is kept when a uniform draw from [0, 1) is not greater than this value
     * @param sampleNegOnly
     *            when true, records whose tag is not in {@code negTags} are always kept and only negative-tagged
     *            records go through the random draw
     * @param tag
     *            the already trimmed target tag; null is logged and filtered out
     * @return true - if the data should be filtered out
     *         false - if the data should not be filtered out
     */
    public static boolean isNotSampled(List<String> posTags, List<String> negTags, double sampleRate,
            boolean sampleNegOnly, String tag) {
        if(isMissing(tag)) {
            return true;
        }

        if(!posTags.contains(tag) && !negTags.contains(tag)) {
            log.error("Invalid target column value - {}", tag);
            return true;
        }

        if(sampleNegOnly && !negTags.contains(tag)) {
            // Only negatives are down-sampled: anything else is kept without consuming a random draw.
            return false;
        }
        return isRejectedByDraw(sampleRate);
    }

    /**
     * To decide whether the data should be filtered out or not, for both binary and non-binary targets.
     *
     * @param isBinary
     *            true to validate the tag against {@code posTags}/{@code negTags}, false to validate it against
     *            {@code tags}
     * @param tags
     *            accepted tags, only consulted when {@code isBinary} is false
     * @param posTags
     *            accepted positive tags, only consulted when {@code isBinary} is true
     * @param negTags
     *            accepted negative tags, only consulted when {@code isBinary} is true
     * @param sampleRate
     *            a record is kept when a uniform draw from [0, 1) is not greater than this value
     * @param sampleNegOnly
     *            only honoured when {@code isBinary} is true; then records whose tag is not in {@code negTags}
     *            are always kept and only negative-tagged records go through the random draw
     * @param tag
     *            the already trimmed target tag; null is logged and filtered out
     * @return true - if the data should be filtered out
     *         false - if the data should not be filtered out
     */
    public static boolean isNotSampled(boolean isBinary, Set<String> tags, Set<String> posTags, Set<String> negTags,
            double sampleRate, boolean sampleNegOnly, String tag) {
        if(isMissing(tag)) {
            return true;
        }

        // Each tag collection is touched only in its own mode, so the unused one is never dereferenced.
        boolean knownTag = isBinary ? (posTags.contains(tag) || negTags.contains(tag)) : tags.contains(tag);
        if(!knownTag) {
            log.error("Invalid target column value - {}", tag);
            return true;
        }

        if(isBinary && sampleNegOnly && !negTags.contains(tag)) {
            // Only negatives are down-sampled: anything else is kept without consuming a random draw.
            return false;
        }
        return isRejectedByDraw(sampleRate);
    }

    /** Converts a raw target cell into a trimmed tag, mapping a null cell to a null tag. */
    private static String normalizeTag(Object rawTarget) {
        if(rawTarget == null) {
            // Never hand null to CommonUtils.trimTag; the callers report a null tag themselves.
            return null;
        }
        return CommonUtils.trimTag(rawTarget.toString());
    }

    /** Logs and reports a missing (null) tag. */
    private static boolean isMissing(String tag) {
        if(tag == null) {
            log.error("Tag is null.");
            return true;
        }
        return false;
    }

    /** Draws once from the shared random source; true means the record is dropped. */
    private static boolean isRejectedByDraw(double sampleRate) {
        return RANDOM.nextDouble() > sampleRate;
    }
}