package ml.shifu.shifu.util.updater;
import ml.shifu.shifu.column.NSColumn;
import ml.shifu.shifu.column.NSColumnUtils;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.ColumnConfig.ColumnType;
import ml.shifu.shifu.core.validator.ModelInspector;
import org.apache.commons.collections.CollectionUtils;
import java.io.IOException;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
/**
 * Created by zhanhu on 2/22/17.
 * <p>
 * Assigns a {@link ColumnConfig.ColumnFlag} and a {@link ColumnType} to a {@link ColumnConfig}, based on the
 * target, weight, categorical, meta, force-remove and force-select column names read from a {@link ModelConfig}.
 * Use {@link #getUpdater(ModelConfig, ModelInspector.ModelStep)} to obtain the updater for a given model step.
 */
public class BasicUpdater {

    /** Target column name, taken from {@code ModelConfig#getTargetColumnName()}. */
    protected String targetColumnName;

    /** Columns listed by {@code ModelConfig#getCategoricalColumnNames()}; never {@code null}, may be empty. */
    protected Set<NSColumn> setCategorialColumns;

    /** Columns listed by {@code ModelConfig#getMetaColumnNames()}; never {@code null}, may be empty. */
    protected Set<NSColumn> setMeta;

    /** Columns listed by {@code ModelConfig#getListForceRemove()}; left empty unless force is enabled. */
    protected Set<NSColumn> setForceRemove;

    /** Columns listed by {@code ModelConfig#getListForceSelect()}; left empty unless force is enabled. */
    protected Set<NSColumn> setForceSelect;

    /** Weight column name, taken from {@code ModelConfig#getWeightColumnName()}. */
    protected String weightColumnName;

    /**
     * Builds the column-name lookup sets from the given model configuration.
     *
     * @param modelConfig
     *            model configuration supplying the column names
     * @throws IOException
     *             if one of the {@code modelConfig} getters fails (presumably while reading a column-name
     *             file - confirm against {@code ModelConfig})
     */
    public BasicUpdater(ModelConfig modelConfig) throws IOException {
        this.targetColumnName = modelConfig.getTargetColumnName();
        this.weightColumnName = modelConfig.getWeightColumnName();

        // Each list getter is invoked exactly once: this constructor declares IOException, so the getters
        // may be I/O-backed and should not be evaluated twice.
        this.setCategorialColumns = new HashSet<NSColumn>();
        addColumns(this.setCategorialColumns, modelConfig.getCategoricalColumnNames());

        this.setMeta = new HashSet<NSColumn>();
        addColumns(this.setMeta, modelConfig.getMetaColumnNames());

        this.setForceRemove = new HashSet<NSColumn>();
        // Pre-sized set; presumably force-select lists can be large - TODO confirm.
        this.setForceSelect = new HashSet<NSColumn>(512);

        // The force-remove / force-select lists are only read when the force option is enabled.
        boolean forceEnabled = Boolean.TRUE.equals(modelConfig.getVarSelect().getForceEnable());
        if(forceEnabled) {
            addColumns(this.setForceRemove, modelConfig.getListForceRemove());
            addColumns(this.setForceSelect, modelConfig.getListForceSelect());
        }
    }

    /**
     * Wraps every name in an {@link NSColumn} and adds it to {@code target}.
     *
     * @param target
     *            set receiving the columns
     * @param names
     *            column names to add; {@code null} or empty adds nothing
     */
    private static void addColumns(Set<NSColumn> target, Collection<String> names) {
        if(CollectionUtils.isNotEmpty(names)) {
            for(String name: names) {
                target.add(new NSColumn(name));
            }
        }
    }

    /**
     * Resets and re-assigns the flag and the type of the given column.
     * <p>
     * Flag precedence: target, meta, force-remove, force-select, weight; otherwise the flag is {@code null}.
     * Type precedence: weight is {@link ColumnType#N}, target is {@link ColumnType#C}, a configured
     * categorical column is {@link ColumnType#C}, anything else is {@link ColumnType#N}.
     *
     * @param columnConfig
     *            column configuration to update in place
     */
    public void updateColumnConfig(ColumnConfig columnConfig) {
        String varName = columnConfig.getColumnName();
        // Both values are fully determined by the column name, so each is set exactly once.
        columnConfig.setColumnFlag(resolveColumnFlag(varName));
        columnConfig.setColumnType(resolveColumnType(varName));
    }

    /**
     * Picks the flag for a column name.
     *
     * @param varName
     *            column name to classify
     * @return the matching flag, or {@code null} when the column is none of target, meta, force-remove,
     *         force-select or weight
     */
    private ColumnConfig.ColumnFlag resolveColumnFlag(String varName) {
        if(NSColumnUtils.isColumnEqual(this.targetColumnName, varName)) {
            return ColumnConfig.ColumnFlag.Target;
        }

        // One NSColumn instance is enough for all set lookups below.
        NSColumn column = new NSColumn(varName);
        if(this.setMeta.contains(column)) {
            return ColumnConfig.ColumnFlag.Meta;
        }
        if(this.setForceRemove.contains(column)) {
            return ColumnConfig.ColumnFlag.ForceRemove;
        }
        if(this.setForceSelect.contains(column)) {
            return ColumnConfig.ColumnFlag.ForceSelect;
        }
        if(NSColumnUtils.isColumnEqual(this.weightColumnName, varName)) {
            return ColumnConfig.ColumnFlag.Weight;
        }
        return null;
    }

    /**
     * Picks the type for a column name.
     *
     * @param varName
     *            column name to classify
     * @return {@link ColumnType#N} for the weight column, {@link ColumnType#C} for the target column or a
     *         configured categorical column, {@link ColumnType#N} otherwise
     */
    private ColumnType resolveColumnType(String varName) {
        if(NSColumnUtils.isColumnEqual(this.weightColumnName, varName)) {
            // weight column is numerical
            return ColumnType.N;
        }
        if(NSColumnUtils.isColumnEqual(this.targetColumnName, varName)) {
            // target column is categorical
            return ColumnType.C;
        }
        if(this.setCategorialColumns.contains(new NSColumn(varName))) {
            return ColumnType.C;
        }
        // meta and all other columns are numerical unless the user listed them as categorical
        return ColumnType.N;
    }

    /**
     * Creates the updater matching a model step: {@code INIT} and {@code STATS} use {@link BasicUpdater},
     * {@code VARSELECT} uses {@code VarSelUpdater}, {@code TRAIN} uses {@code TrainUpdater} and every other
     * step uses {@code VoidUpdater}.
     *
     * @param modelConfig
     *            model configuration handed to the updater's constructor
     * @param step
     *            model step to create the updater for; must not be {@code null}
     * @return a new updater instance for {@code step}
     * @throws IOException
     *             if the selected updater's constructor fails
     */
    public static BasicUpdater getUpdater(ModelConfig modelConfig, ModelInspector.ModelStep step) throws IOException {
        switch(step) {
            case INIT:
            case STATS:
                return new BasicUpdater(modelConfig);
            case VARSELECT:
                return new VarSelUpdater(modelConfig);
            case TRAIN:
                return new TrainUpdater(modelConfig);
            default:
                return new VoidUpdater(modelConfig);
        }
    }
}