/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.common;
import rapaio.core.SamplingTools;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
/**
* User: Aurelian Tutuianu <paderati@yahoo.com>
*/
public class VarSelector implements Serializable {
private static final long serialVersionUID = -6800363806127327947L;
private static final int M_ALL = 0;
private static final int M_AUTO = -1;
public static VarSelector ALL = new VarSelector(M_ALL);
public static VarSelector AUTO = new VarSelector(M_AUTO);
private final int mVars;
private int mCount = 0;
private Set<String> varNames = new HashSet<>();
public VarSelector() {
this(M_ALL);
}
public VarSelector(int mVars) {
this.mVars = mVars;
}
public VarSelector newInstance() {
VarSelector sel = new VarSelector(mVars);
sel.varNames.addAll(varNames);
return sel;
}
public VarSelector withVarNames(final String... varNames) {
this.varNames = Arrays.stream(varNames).collect(Collectors.toSet());
if (mVars == M_ALL) {
this.mCount = this.varNames.size();
} else if (mVars == M_AUTO) {
this.mCount = Math.max((int) Math.sqrt(this.varNames.size()), 1);
} else {
this.mCount = mVars;
}
return this;
}
public String name() {
if (mVars == M_ALL)
return "VarSelector[ALL]";
if (mVars == M_AUTO)
return "VarSelector[AUTO]";
return "VarSelector[" + mVars + "]";
}
public String[] nextVarNames() {
if (mVars == M_ALL) {
return varNames.toArray(new String[varNames.size()]);
}
int m = Math.min(mCount, varNames.size());
int[] indexes = SamplingTools.sampleWOR(varNames.size(), m);
String[] result = new String[m];
String[] arr = varNames.toArray(new String[varNames.size()]);
for (int i = 0; i < indexes.length; i++) {
result[i] = arr[indexes[i]];
}
return result;
}
public String[] nextAllVarNames() {
int m = varNames.size();
int[] indexes = SamplingTools.sampleWOR(varNames.size(), m);
String[] result = new String[m];
String[] arr = varNames.toArray(new String[varNames.size()]);
for (int i = 0; i < indexes.length; i++) {
result[i] = arr[indexes[i]];
}
return result;
}
public int mCount() {
return mCount;
}
public void removeVarNames(Collection<String> varName) {
this.varNames.removeAll(varName);
}
public void addVarNames(Collection<String> varName) {
this.varNames.addAll(varName);
}
}