/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ColumnConfig.ColumnConfigComparator;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.core.dtrain.DTrainUtils;
import ml.shifu.shifu.util.CommonUtils;
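/**
 * Selects the variables (columns) to use for modeling. Candidates are ranked by KS, IV, a KS/IV mix or an
 * epsilon-box Pareto sort over (KS, IV), and the chosen columns are flagged as final-selected in their
 * {@link ColumnConfig}.
 */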
public class VariableSelector {
/** Class-wide logger; never reassigned, hence final. */
private static final Logger log = LoggerFactory.getLogger(VariableSelector.class);
/** Model configuration supplying the variable-selection settings; assigned once in the constructor. */
private final ModelConfig modelConfig;
/** Column configurations to choose from; assigned once in the constructor and kept by reference. */
private final List<ColumnConfig> columnConfigList;
/**
 * Epsilons (box sizes) used by the Pareto sort: index 0 is applied to KS, every later index to IV. The
 * literal below is only a placeholder: the constructor always overwrites it with the configured epsilons,
 * and the Pareto sort methods restore { 0.01, 0.05 } when the configured value is null.
 */
private double[] epsilonArray = new double[] { 0.01d, 0.05d };
public VariableSelector(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) {
log.info("Creating VariableSelector...");
this.modelConfig = modelConfig;
this.columnConfigList = columnConfigList;
this.epsilonArray = modelConfig.getVarSelect().getEpsilons();
}
// TODO it should support some DSL like "KS > 2 and IV and PSI <= 0.1"
/**
 * Placeholder for selecting variables through a filter expression. Not implemented yet: the argument is
 * ignored and nothing is selected.
 *
 * @param input
 *            filter expression (currently unused)
 * @return always {@code null}
 */
public List<ColumnConfig> selectByFilter(String input) {
// VarSelParser parser = new VarSelParser(new CommonTokenStream(new VarSelLexer(new ANTLRInputStream(input))));
return null;
}
/**
 * Mutable value holder for one candidate column: its column number, its KS and IV values, and an optional
 * {@code box} array (left {@code null} by the constructor).
 */
public static class Tuple {
    /** Column number of the candidate. */
    public int columnNum;
    /** KS value of the candidate. */
    public double ks;
    /** IV value of the candidate. */
    public double iv;
    /** Box coordinates attached to this tuple; {@code null} until someone assigns them. */
    public double[] box;

    /**
     * @param columnNum
     *            column number of the candidate
     * @param ks
     *            KS value
     * @param iv
     *            IV value
     */
    public Tuple(int columnNum, double ks, double iv) {
        this.columnNum = columnNum;
        this.ks = ks;
        this.iv = iv;
    }

    /**
     * Creates an independent copy of {@code tuple}; the {@code box} array, when present, is duplicated rather
     * than shared.
     *
     * @param tuple
     *            the tuple to copy, must not be {@code null}
     * @return a new tuple with the same values
     */
    public static Tuple clone(Tuple tuple) {
        Tuple copy = new Tuple(tuple.columnNum, tuple.ks, tuple.iv);
        double[] sourceBox = tuple.box;
        if(sourceBox == null) {
            return copy;
        }
        copy.box = Arrays.copyOf(sourceBox, sourceBox.length);
        return copy;
    }

    /*
     * (non-Javadoc)
     * 
     * @see java.lang.Object#toString()
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("Tuple [columnNum=");
        text.append(columnNum);
        text.append(", ks=").append(ks);
        text.append(", iv=").append(iv);
        text.append(", box=").append(Arrays.toString(box));
        text.append("]");
        return text.toString();
    }
}
/**
 * Derives the variable-selection filter number from the configured filter-out ratio and stores it through
 * {@code modelConfig.getVarSelect().setFilterNum(..)}. Does nothing when
 * {@code modelConfig.getVarSelectFilterNum()} is already positive.
 * <p>
 * The base count is element 0 of {@code DTrainUtils.getInputOutputCandidateCounts(columnConfigList)}, or
 * element 2 when element 0 is zero (presumably "selected inputs" vs. "candidate inputs"; confirm against
 * DTrainUtils). The result is {@code (int) (count * (1 - filterOutRatio))}, computed in float arithmetic.
 *
 * @param modelConfig
 *            model configuration to read the ratio from and write the filter number to
 * @param columnConfigList
 *            column configurations used to count the input columns
 */
public static void setFilterNumberByFilterOutRatio(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) {
    boolean filterNumAlreadySet = modelConfig.getVarSelectFilterNum() > 0;
    if(!filterNumAlreadySet) {
        int[] counts = DTrainUtils.getInputOutputCandidateCounts(columnConfigList);
        int inputNodeCount;
        if(counts[0] != 0) {
            inputNodeCount = counts[0];
        } else {
            inputNodeCount = counts[2];
        }

        Float filterOutRatio = modelConfig.getVarSelect().getFilterOutRatio();
        float keepRatio = 1.0f - filterOutRatio;
        int targetCnt = (int) (inputNodeCount * keepRatio);
        modelConfig.getVarSelect().setFilterNum(targetCnt);
    }
    // else: user already set the filter number, so the filter-out ratio is ignored
}
/**
 * Selects variables according to the configured filter and flags the chosen columns with
 * {@code setFinalSelect(true)}.
 * <p>
 * Steps:
 * <ol>
 * <li>Derive the filter number from the filter-out ratio if it is not set yet.</li>
 * <li>Classify every column: meta/target and force-removed columns are skipped; force-selected columns are
 * selected immediately (only if mean and stdDev are present); remaining good candidates (numerical, or
 * categorical when categorical variables are not disabled) become ranking candidates.</li>
 * <li>If filtering is disabled, only the force-selected columns are flagged and the method returns.</li>
 * <li>Otherwise all finalSelect flags are reset and candidates are picked until the expected number
 * ({@code min(candidateCount, filterNum)}, force-selected columns included in the count) is reached, using
 * the {@code VarSelectFilterBy} key: {@code ks}, {@code iv}, {@code mix} (alternating KS and IV ranking) or
 * {@code pareto} (Pareto result first, then KS ranking).</li>
 * </ol>
 * Column numbers are used as indexes into the column config list when flagging the result.
 *
 * @return the selector's column config list (the same list instance), with the finalSelect flags updated
 * @throws IllegalArgumentException
 *             if filtering is enabled, more variables still have to be picked, and the filter key is
 *             {@code null} or not one of ks, iv, mix, pareto
 */
public List<ColumnConfig> selectByFilter() {
    log.info(" - Method: Filter");

    int ptrKs = 0, ptrIv = 0, ptrPareto = 0, cntByForce = 0;

    VariableSelector.setFilterNumberByFilterOutRatio(this.modelConfig, this.columnConfigList);

    log.info("Start Variable Selection...");
    log.info("\t VarSelectEnabled: " + modelConfig.getVarSelectFilterEnabled());
    log.info("\t VarSelectBy: " + modelConfig.getVarSelectFilterBy());
    log.info("\t VarSelectNum: " + modelConfig.getVarSelectFilterNum());

    List<Integer> selectedColumnNumList = new ArrayList<Integer>();

    List<ColumnConfig> ksList = new ArrayList<ColumnConfig>();
    List<ColumnConfig> ivList = new ArrayList<ColumnConfig>();
    List<Tuple> paretoList = new ArrayList<Tuple>();

    int cntSelected = 0;

    for(ColumnConfig config: this.columnConfigList) {
        if(config.isMeta() || config.isTarget()) {
            log.info("\t Skip meta, weight or target column: " + config.getColumnName());
        } else if(config.isForceRemove()) {
            log.info("\t ForceRemove: " + config.getColumnName());
        } else if(config.isForceSelect()) {
            log.info("\t ForceSelect: " + config.getColumnName());
            if(config.getMean() == null || config.getStdDev() == null) {
                // TODO - check the mean of categorical variable could be null
                log.info("\t ForceSelect Failed: mean/stdDev must not be null");
            } else {
                selectedColumnNumList.add(config.getColumnNum());
                cntSelected++;
                cntByForce++;
            }
        } else if(!CommonUtils.isGoodCandidate(config)) {
            log.info("\t Incomplete info(please check KS, IV, Mean, or StdDev fields): " + config.getColumnName());
        } else if((config.isCategorical() && !modelConfig.isCategoricalDisabled()) || config.isNumerical()) {
            ksList.add(config);
            ivList.add(config);
            // Only columns with stats take part in the Pareto sort; missing KS/IV count as 0.
            if(config.getColumnStats() != null) {
                Double ks = config.getKs();
                Double iv = config.getIv();
                paretoList.add(new Tuple(config.getColumnNum(), ks == null ? 0d : ks, iv == null ? 0d : iv));
            }
        }
    }

    // not enabled filter, so only select forceSelect columns
    if(!this.modelConfig.getVarSelectFilterEnabled()) {
        log.info("Summary:");
        log.info("\tSelected Variables: " + cntSelected);
        if(cntByForce != 0) {
            log.info("\t- By Force: " + cntByForce);
        }

        for(int n: selectedColumnNumList) {
            this.columnConfigList.get(n).setFinalSelect(true);
        }
        return columnConfigList;
    }

    String key = this.modelConfig.getVarSelectFilterBy();
    // Constant-first comparisons are null-safe; the key never changes, so evaluate it once.
    boolean byKs = "ks".equalsIgnoreCase(key);
    boolean byIv = "iv".equalsIgnoreCase(key);
    boolean byMix = "mix".equalsIgnoreCase(key);
    boolean byPareto = "pareto".equalsIgnoreCase(key);

    Collections.sort(ksList, new ColumnConfigComparator("ks"));
    Collections.sort(ivList, new ColumnConfigComparator("iv"));
    List<Tuple> newParetoList = sortByPareto(paretoList);

    int expectedVarNum = Math.min(ksList.size(), modelConfig.getVarSelectFilterNum());
    log.info("Expected selected columns:" + expectedVarNum);

    // With an unrecognized key no branch of the selection loop would ever advance cntSelected, so the loop
    // would never terminate. Fail fast instead, before any finalSelect flag is touched.
    if(cntSelected < expectedVarNum && !(byKs || byIv || byMix || byPareto)) {
        throw new IllegalArgumentException("Unsupported variable selection filter key: " + key
                + " (expected one of ks, iv, mix, pareto)");
    }

    // reset to false at first.
    for(ColumnConfig columnConfig: this.columnConfigList) {
        if(columnConfig.isFinalSelect()) {
            columnConfig.setFinalSelect(false);
        }
    }

    ColumnConfig config = null;
    while(cntSelected < expectedVarNum) {
        if(byKs) {
            config = ksList.get(ptrKs);
            selectedColumnNumList.add(config.getColumnNum());
            ptrKs++;
            log.info("\t SelectedByKS=" + config.getKs() + "(Rank=" + ptrKs + "): " + config.getColumnName());
            cntSelected++;
        } else if(byIv) {
            config = ivList.get(ptrIv);
            selectedColumnNumList.add(config.getColumnNum());
            ptrIv++;
            log.info("\t SelectedByIV=" + config.getIv() + "(Rank=" + ptrIv + "): " + config.getColumnName());
            cntSelected++;
        } else if(byMix) {
            // Alternate: next best by KS, then next best by IV, skipping columns already taken.
            config = ksList.get(ptrKs);
            if(selectedColumnNumList.contains(config.getColumnNum())) {
                log.info("\t Variable Already Selected: " + config.getColumnName());
                ptrKs++;
            } else {
                selectedColumnNumList.add(config.getColumnNum());
                ptrKs++;
                log.info("\t SelectedByKS=" + config.getKs() + "(Rank=" + ptrKs + "): " + config.getColumnName());
                cntSelected++;
            }

            if(cntSelected == expectedVarNum) {
                break;
            }

            config = ivList.get(ptrIv);
            if(selectedColumnNumList.contains(config.getColumnNum())) {
                log.info("\t Variable Already Selected: " + config.getColumnName());
                ptrIv++;
            } else {
                selectedColumnNumList.add(config.getColumnNum());
                ptrIv++;
                log.info("\t SelectedByIV=" + config.getIv() + "(Rank=" + ptrIv + "): " + config.getColumnName());
                cntSelected++;
            }
        } else if(byPareto) {
            if(ptrPareto >= newParetoList.size()) {
                // Pareto result exhausted: fill up from the KS ranking.
                config = ksList.get(ptrKs);
                if(selectedColumnNumList.contains(config.getColumnNum())) {
                    log.info("\t Variable Already Selected: " + config.getColumnName());
                } else {
                    selectedColumnNumList.add(config.getColumnNum());
                    // Parenthesized so the two numbers are added instead of being concatenated as text.
                    log.info("\t SelectedByKS=" + config.getKs() + "(Rank=" + (ptrKs + newParetoList.size())
                            + "): " + config.getColumnName());
                    cntSelected++;
                }
                ptrKs++;
            } else {
                int columnNum = newParetoList.get(ptrPareto).columnNum;
                selectedColumnNumList.add(columnNum);
                log.info("\t SelectedByPareto " + columnConfigList.get(columnNum).getColumnName());
                ptrPareto++;
                cntSelected++;
            }
        }
    }

    log.info("Summary:");
    log.info("\t Selected Variables: " + cntSelected);
    if(cntByForce != 0) {
        log.info("\t - By Force: " + cntByForce);
    }
    if(ptrPareto != 0) {
        log.info("\t - By Pareto: " + ptrPareto);
    }
    if(ptrKs != 0) {
        log.info("\t - By KS: " + ptrKs);
    }
    if(ptrIv != 0) {
        log.info("\t - By IV: " + ptrIv);
    }

    // update column config list and set finalSelect to true
    for(int n: selectedColumnNumList) {
        this.columnConfigList.get(n).setFinalSelect(true);
    }
    return columnConfigList;
}
/**
 * Archive of tuples kept by an epsilon-box dominance sort over two objectives (KS at index 0, IV at every
 * other index). Each objective value is mapped to a box index {@code floor(value / epsilon)}. A tuple whose
 * box index is lower in at least one objective and not higher in any is treated as dominating the other.
 * <p>
 * NOTE(review): because lower box indexes win, the sort favors smaller objective values; confirm callers
 * pass values in the orientation they intend.
 */
private static class Archives {

    /** Box size per objective. */
    public double[] epsilons;

    /** Tuples currently held by the archive; each carries its own {@code box}. */
    public List<Tuple> tuples = new ArrayList<VariableSelector.Tuple>();

    /**
     * @param epsilons
     *            box size per objective; index 0 applies to KS, the remaining indexes to IV
     */
    public Archives(double[] epsilons) {
        this.epsilons = epsilons;
    }

    /** Objective value for dimension {@code dim}: KS for dimension 0, IV otherwise. */
    private static double objectiveOf(Tuple tuple, int dim) {
        return dim == 0 ? tuple.ks : tuple.iv;
    }

    /**
     * Offers a candidate to the archive. The candidate is dropped when an archived tuple dominates it (or
     * shares its box and lies strictly closer to the box corner); archived tuples it beats are removed. If
     * the candidate survives, its {@code box} field is set and a copy of it is appended to {@link #tuples}.
     *
     * @param currTuple
     *            candidate tuple; its {@code box} is overwritten when it is accepted
     */
    public void sortInto(Tuple currTuple) {
        final int dims = epsilons.length;

        double[] candidateBox = new double[dims];
        for(int d = 0; d < dims; d++) {
            candidateBox[d] = Math.floor(objectiveOf(currTuple, d) / epsilons[d]);
        }

        int pos = 0;
        while(pos < tuples.size()) {
            Tuple archived = tuples.get(pos);
            double[] archivedBox = archived.box;

            boolean archiveAhead = false; // archived tuple has a lower box index somewhere
            boolean candidateAhead = false; // candidate has a lower box index somewhere
            for(int d = 0; d < dims && !(archiveAhead && candidateAhead); d++) {
                if(archivedBox[d] < candidateBox[d]) {
                    archiveAhead = true;
                } else if(archivedBox[d] > candidateBox[d]) {
                    candidateAhead = true;
                }
            }

            if(archiveAhead && candidateAhead) {
                // neither dominates: keep the archived tuple and look at the next one
                pos++;
                continue;
            }
            if(archiveAhead) {
                // candidate solution was dominated
                return;
            }
            if(!candidateAhead) {
                // same box: the tuple closer to the box corner wins, ties go to the candidate
                double sdist = 0d;
                double adist = 0d;
                for(int d = 0; d < dims; d++) {
                    double corner = candidateBox[d] * epsilons[d];
                    double candidateGap = objectiveOf(currTuple, d) - corner;
                    double archivedGap = objectiveOf(archived, d) - corner;
                    sdist += candidateGap * candidateGap;
                    adist += archivedGap * archivedGap;
                }
                if(adist < sdist) {
                    return;
                }
            }

            // candidate beats the archived tuple; the next tuple slides into this position
            tuples.remove(pos);
        }

        // if you get here, then no archive solution has dominated this one
        currTuple.box = candidateBox;
        tuples.add(Tuple.clone(currTuple));
    }
}
/**
 * Feeds the given tuples, in iteration order, into a fresh epsilon-box archive and returns what the archive
 * retains. If no epsilons are configured, the defaults { 0.01, 0.05 } are installed on this selector first.
 * <p>
 * Side effect: every tuple accepted by the archive gets its {@code box} field set; the returned list holds
 * copies of the accepted tuples.
 *
 * @param paretoList
 *            candidate tuples
 * @return the tuples retained by the archive
 */
public List<Tuple> sortByPareto(List<Tuple> paretoList) {
    // TODO
    double[] epsilons = this.epsilonArray;
    if(epsilons == null) {
        epsilons = new double[] { 0.01d, 0.05d };
        this.epsilonArray = epsilons;
    }

    Archives archive = new Archives(epsilons);
    for(Tuple candidate: paretoList) {
        archive.sortInto(candidate);
    }
    return archive.tuples;
}
/**
 * Builds tuples from the given column configs and runs {@link #sortByPareto(List)} on them. Entries that
 * are {@code null} or have no column stats are skipped. A missing KS or IV is treated as 0.
 * <p>
 * NOTE(review): the IV value is negated ({@code 0 - iv}) while KS is passed through unchanged; confirm this
 * asymmetry is intended.
 *
 * @param list
 *            column configs to convert
 * @return the tuples retained by the Pareto sort
 */
public List<Tuple> sortByParetoCC(List<ColumnConfig> list) {
    if(null == this.epsilonArray) {
        this.epsilonArray = new double[] { 0.01d, 0.05d };
    }

    List<Tuple> candidates = new ArrayList<VariableSelector.Tuple>();
    for(ColumnConfig column: list) {
        if(column == null || column.getColumnStats() == null) {
            continue;
        }
        Double ks = column.getKs();
        Double iv = column.getIv();

        double ksValue = 0d;
        if(ks != null) {
            ksValue = ks;
        }
        double ivValue = 0d;
        if(iv != null) {
            ivValue = 0 - iv;
        }
        candidates.add(new Tuple(column.getColumnNum(), ksValue, ivValue));
    }
    return sortByPareto(candidates);
}
}