/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.tree;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import org.apache.commons.lang.ArrayUtils;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.studio.internal.Resources;
import com.rapidminer.tools.Tools;
/**
 * Handles selections of attributes and examples of a {@link ColumnExampleTable}. Creates start
 * selections and updates them.
 *
 * <p>
 * A "selection" is a map from an attribute number to an array of example (row) indices. For the
 * numerical attributes, the start selection holds the example indices ordered ascending by the
 * values of that attribute; the split methods keep the relative order inside every array.
 *
 * @author Gisa Schaefer
 *
 */
public class SelectionCreator {

    /** The table whose attribute columns are read when sorting and splitting selections. */
    private final ColumnExampleTable columnTable;

    /**
     * Creates a selection creator working on the given table.
     *
     * @param columnTable
     *            the table whose columns are used to sort and split example indices
     */
    public SelectionCreator(ColumnExampleTable columnTable) {
        this.columnTable = columnTable;
    }

    /**
     * Creates an example index start selection for each numerical attribute, or if there is none,
     * only one.
     *
     * @return a map containing for each numerical attribute an example index array such that the
     *         associated attribute values are in ascending order. If the table reports no regular
     *         numerical attribute, the map contains the single entry {@code 0 -> [0..n-1]}.
     */
    public Map<Integer, int[]> getStartSelection() {
        Map<Integer, int[]> selection = new HashMap<>();
        if (columnTable.getNumberOfRegularNumericalAttributes() == 0) {
            selection.put(0, createFullArray(columnTable.getNumberOfExamples()));
            return selection;
        }
        Integer[] allRows = createFullBigArray(columnTable.getNumberOfExamples());
        // the numerical columns are requested for the attribute numbers following the nominal ones
        int firstNumerical = columnTable.getNumberOfRegularNominalAttributes();
        int attributeCount = columnTable.getTotalNumberOfRegularAttributes();
        for (int j = firstNumerical; j < attributeCount; j++) {
            selection.put(j, sortRowsByColumn(allRows, columnTable.getNumericalAttributeColumn(j)));
        }
        return selection;
    }

    /**
     * Creates in parallel an example index start selection for each numerical attribute, or if
     * there is none, only one.
     *
     * @param operator
     *            the operator for which the calculation is done
     * @return a map containing for each numerical attribute an example index array such that the
     *         associated attribute values are in ascending order. If the table reports no regular
     *         numerical attribute, the map contains the single entry {@code 0 -> [0..n-1]}.
     * @throws OperatorException
     *             if one of the sorting tasks failed with a cause that is neither a
     *             {@link RuntimeException} nor an {@link Error}; those two are rethrown as they are
     */
    public Map<Integer, int[]> getStartSelectionParallel(Operator operator) throws OperatorException {
        Map<Integer, int[]> selection = new HashMap<>();
        if (columnTable.getNumberOfRegularNumericalAttributes() == 0) {
            selection.put(0, createFullArray(columnTable.getNumberOfExamples()));
            return selection;
        }
        final Integer[] allRows = createFullBigArray(columnTable.getNumberOfExamples());
        int firstNumerical = columnTable.getNumberOfRegularNominalAttributes();
        int attributeCount = columnTable.getTotalNumberOfRegularAttributes();

        // one sorting task per numerical attribute, in attribute order
        List<Callable<int[]>> tasks = new ArrayList<>();
        for (int j = firstNumerical; j < attributeCount; j++) {
            final double[] attributeColumn = columnTable.getNumericalAttributeColumn(j);
            tasks.add(new Callable<int[]>() {

                @Override
                public int[] call() {
                    return sortRowsByColumn(allRows, attributeColumn);
                }
            });
        }

        List<int[]> results;
        try {
            results = Resources.getConcurrencyContext(operator).call(tasks);
        } catch (ExecutionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            } else if (cause instanceof Error) {
                throw (Error) cause;
            } else if (cause == null) {
                // an ExecutionException is not guaranteed to carry a cause; report the exception
                // itself instead of failing with a NullPointerException
                throw new OperatorException(e.getMessage(), e);
            } else {
                throw new OperatorException(cause.getMessage(), cause);
            }
        }

        // the i-th result belongs to the i-th submitted task, i.e. to attribute firstNumerical + i
        for (int j = firstNumerical; j < attributeCount; j++) {
            selection.put(j, results.get(j - firstNumerical));
        }
        return selection;
    }

    /**
     * Returns a copy of the given row indices ordered ascending by the associated column values.
     * The ordering is the one of {@link Double#compare(double, double)}, so rows with NaN values
     * end up last. The passed array is not modified.
     *
     * @param allRows
     *            the row indices to order
     * @param column
     *            the column values, indexed by row index
     * @return the ordered row indices as a new primitive array
     */
    private static int[] sortRowsByColumn(Integer[] allRows, final double[] column) {
        Integer[] sortedRows = Arrays.copyOf(allRows, allRows.length);
        Arrays.sort(sortedRows, new Comparator<Integer>() {

            @Override
            public int compare(Integer a, Integer b) {
                return Double.compare(column[a], column[b]);
            }
        });
        return ArrayUtils.toPrimitive(sortedRows);
    }

    /**
     * Splits the selected examples according to the bestAttribute and, if the attribute is
     * numerical, the bestSplitValue.
     *
     * @param allSelectedExamples
     *            the current selection, mapping attribute numbers to example index arrays
     * @param bestAttribute
     *            the number of the attribute to split at
     * @param bestSplitValue
     *            the split value; only used if the table does not report bestAttribute as nominal
     * @return a collection of maps mapping the numerical attribute number to the sorted array
     *         containing the selected example numbers
     */
    public Collection<Map<Integer, int[]>> getSplits(Map<Integer, int[]> allSelectedExamples, int bestAttribute,
            double bestSplitValue) {
        if (columnTable.representsNominalAttribute(bestAttribute)) {
            return calculateSplits(allSelectedExamples, bestAttribute);
        }
        return calculateSplits(allSelectedExamples, bestAttribute, bestSplitValue);
    }

    /**
     * Splits for every numerical attribute the sorted index array according to the bestSplitValue
     * at the bestAttribute. Groups by smaller or equal to bestSplitValue, greater than
     * bestSplitValue and value is NaN.
     *
     * <p>
     * Whether a NaN group is created is decided by looking only at the last entry of the index
     * array stored for bestAttribute; this relies on that array being sorted by the attribute
     * values with NaNs at the end. An empty array for bestAttribute results in two empty groups.
     *
     * @param allSelectedExamples
     *            the current selection, mapping attribute numbers to example index arrays; must
     *            contain an entry for bestAttribute
     * @param bestAttribute
     *            the number of the numerical attribute to split at
     * @param bestSplitValue
     *            the value separating the first group from the second
     * @return a list containing first the selection of the examples whose value is smaller than or
     *         equal to bestSplitValue, then the selection of those with a greater value and, only
     *         if NaNs were detected, the selection of the examples whose value is NaN
     */
    public Collection<Map<Integer, int[]>> calculateSplits(Map<Integer, int[]> allSelectedExamples, int bestAttribute,
            double bestSplitValue) {
        double[] attributeColumn = columnTable.getNumericalAttributeColumn(bestAttribute);

        // check if the selected examples contain NaN values of the attribute column - because of
        // sorting they should be at the end; guard against an empty selection
        int[] sortedByBestAttribute = allSelectedExamples.get(bestAttribute);
        boolean existNaNs = sortedByBestAttribute.length > 0
                && Double.isNaN(attributeColumn[sortedByBestAttribute[sortedByBestAttribute.length - 1]]);

        Map<Integer, int[]> smallerSplit = new HashMap<>();
        Map<Integer, int[]> biggerSplit = new HashMap<>();
        Map<Integer, int[]> naNSplit = new HashMap<>();
        List<Map<Integer, int[]>> results = new ArrayList<>(3);
        results.add(smallerSplit);
        results.add(biggerSplit);
        if (existNaNs) {
            results.add(naNSplit);
        }

        // scratch buffers shared by all attributes; sized for the longest index array so that
        // arrays of differing length cannot overflow them
        int maximalLength = 0;
        for (int[] selectedExamples : allSelectedExamples.values()) {
            maximalLength = Math.max(maximalLength, selectedExamples.length);
        }
        int[] smaller = new int[maximalLength];
        int[] bigger = new int[maximalLength];
        int[] naNs = new int[maximalLength];

        for (Map.Entry<Integer, int[]> entry : allSelectedExamples.entrySet()) {
            int smallerPosition = 0;
            int biggerPosition = 0;
            int naNsPosition = 0;
            // distribute the rows in their given order, so every part stays sorted
            for (int row : entry.getValue()) {
                double value = attributeColumn[row];
                if (Double.isNaN(value)) {
                    naNs[naNsPosition++] = row;
                } else if (Tools.isLessEqual(value, bestSplitValue)) {
                    smaller[smallerPosition++] = row;
                } else {
                    bigger[biggerPosition++] = row;
                }
            }
            smallerSplit.put(entry.getKey(), Arrays.copyOf(smaller, smallerPosition));
            biggerSplit.put(entry.getKey(), Arrays.copyOf(bigger, biggerPosition));
            if (existNaNs) {
                naNSplit.put(entry.getKey(), Arrays.copyOf(naNs, naNsPosition));
            }
        }
        return results;
    }

    /**
     * Splits for every numerical attribute the sorted index array according to the value at the
     * best attribute. Groups the splitted arrays by the value at the best attribute.
     *
     * @param allSelectedExamples
     *            the current selection, mapping attribute numbers to example index arrays
     * @param bestAttribute
     *            the number of the nominal attribute to split at
     * @return one selection per value of bestAttribute that occurs among the selected examples;
     *         inside every selection the example indices keep their relative order
     */
    public Collection<Map<Integer, int[]>> calculateSplits(Map<Integer, int[]> allSelectedExamples, int bestAttribute) {
        byte[] attributeColumn = columnTable.getNominalAttributeColumn(bestAttribute);
        Map<Byte, Map<Integer, int[]>> results = new HashMap<>();
        for (Map.Entry<Integer, int[]> entry : allSelectedExamples.entrySet()) {
            // put every row in the list associated to its value
            Map<Byte, List<Integer>> valueLists = new HashMap<>();
            for (int row : entry.getValue()) {
                Byte value = attributeColumn[row];
                List<Integer> rows = valueLists.get(value);
                if (rows == null) {
                    rows = new ArrayList<>();
                    valueLists.put(value, rows);
                }
                rows.add(row);
            }
            // store the pair (value, list) as (value, (attribute number, array(list)))
            for (Map.Entry<Byte, List<Integer>> group : valueLists.entrySet()) {
                List<Integer> rows = group.getValue();
                int[] rowArray = ArrayUtils.toPrimitive(rows.toArray(new Integer[rows.size()]));
                Map<Integer, int[]> split = results.get(group.getKey());
                if (split == null) {
                    split = new HashMap<>();
                    results.put(group.getKey(), split);
                }
                split.put(entry.getKey(), rowArray);
            }
        }
        return results.values();
    }

    /**
     * If the bestAttribute is nominal, its number is removed from the selectedAttributes, otherwise
     * it stays the same.
     *
     * @param selectedAttributes
     *            the attribute numbers that are currently selected
     * @param bestAttribute
     *            the number of the attribute that was used for splitting
     * @return a new array without bestAttribute if the table reports it as nominal, otherwise the
     *         passed selectedAttributes array itself (not a copy)
     */
    public int[] updateRemainingAttributes(int[] selectedAttributes, int bestAttribute) {
        if (columnTable.representsNominalAttribute(bestAttribute)) {
            return removeAttribute(bestAttribute, selectedAttributes);
        }
        return selectedAttributes;
    }

    /**
     * Creates a new array containing all entries of selectedAttributes except for
     * attributeNumberToDelete.
     *
     * @param attributeNumberToDelete
     *            the attribute number to leave out
     * @param selectedAttributes
     *            the attribute numbers to copy; not modified
     * @return a new array with the remaining attribute numbers in their original order. If
     *         attributeNumberToDelete is not contained, the result is a copy of
     *         selectedAttributes; if it is contained several times, all occurrences are left out.
     */
    public int[] removeAttribute(int attributeNumberToDelete, int[] selectedAttributes) {
        // collect into a buffer of full length and trim afterwards, so that the result has the
        // right size no matter how often (if at all) the attribute occurs
        int[] kept = new int[selectedAttributes.length];
        int keptCount = 0;
        for (int attribute : selectedAttributes) {
            if (attribute != attributeNumberToDelete) {
                kept[keptCount++] = attribute;
            }
        }
        return Arrays.copyOf(kept, keptCount);
    }

    /**
     * Create a selection array containing all rows, i.e. containing all consecutive numbers
     * [0..length-1]
     *
     * @param length
     *            the number of rows
     * @return a new array of the given length with {@code array[i] == i}
     */
    public int[] createFullArray(int length) {
        int[] fullSelection = new int[length];
        for (int i = 0; i < length; i++) {
            fullSelection[i] = i;
        }
        return fullSelection;
    }

    /**
     * Create an Integer array containing all consecutive numbers [0..length-1]
     *
     * @param length
     *            the number of entries
     * @return a new array of the given length with {@code array[i] == i}
     */
    public Integer[] createFullBigArray(int length) {
        Integer[] fullSelection = new Integer[length];
        for (int i = 0; i < length; i++) {
            fullSelection[i] = i;
        }
        return fullSelection;
    }

    /**
     * Returns a value of the map.
     *
     * @param map
     *            a non-empty map
     * @return the first value delivered by the iterator over the values of the map
     */
    public static int[] getArbitraryValue(Map<Integer, int[]> map) {
        return map.values().iterator().next();
    }
}