/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.datatable;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import com.rapidminer.example.Attribute;
import com.rapidminer.operator.learner.functions.kernel.KernelModel;
/**
* This class can be used to use a kernel model as data table. The data is directly read from the
* kernel model instead of building a copy. Please note that the method for adding new rows is not
* supported by this type of data tables.
*
* @author Ingo Mierswa
*/
public class DataTableKernelModelAdapter extends AbstractDataTable {
/** Helper class to iterated over the examples or support vectors of a {@link KernelModel}. */
private static class KernelModelIterator implements Iterator<DataTableRow> {
private int counter = 0;
private DataTableKernelModelAdapter adapter;
public KernelModelIterator(DataTableKernelModelAdapter adapter) {
this.adapter = adapter;
}
@Override
public boolean hasNext() {
return counter < adapter.getNumberOfRows();
}
@Override
public DataTableRow next() {
DataTableRow row = adapter.getRow(counter);
counter++;
return row;
}
@Override
public void remove() {
throw new RuntimeException("DataTable.KernelModelIterator: remove not supported!");
}
}
private KernelModel kernelModel;
private String[] attributeNames;
private static final String DEFAULT_REGULAR_ATTRIBUTE_PREFIX = "attribute";
private int[] sampleMapping = null;
private Map<Integer, String> index2LabelMap = new HashMap<>();
private Map<String, Integer> label2IndexMap = new HashMap<>();
public DataTableKernelModelAdapter(KernelModel kernelModel) {
super("Kernel Model Support Vectors");
this.kernelModel = kernelModel;
int labelCounter = 0;
if (this.kernelModel.isClassificationModel()) {
for (int i = 0; i < this.kernelModel.getNumberOfSupportVectors(); i++) {
String label = this.kernelModel.getClassificationLabel(i);
if (label2IndexMap.get(label) == null) {
this.label2IndexMap.put(label, labelCounter);
this.index2LabelMap.put(labelCounter, label);
labelCounter++;
}
}
}
// storing attribute names
attributeNames = new String[kernelModel.getTrainingHeader().getAttributes().size()];
int i = 0;
for (Attribute attribute : kernelModel.getTrainingHeader().getAttributes()) {
attributeNames[i] = attribute.getName();
i++;
}
}
public DataTableKernelModelAdapter(DataTableKernelModelAdapter dataTableKernelModelAdapter) {
super(dataTableKernelModelAdapter.getName());
this.kernelModel = dataTableKernelModelAdapter.kernelModel; // shallow clone
this.index2LabelMap = dataTableKernelModelAdapter.index2LabelMap; // shallow clone
this.label2IndexMap = dataTableKernelModelAdapter.label2IndexMap; // shallow clone
this.sampleMapping = null;
}
@Override
public int getNumberOfSpecialColumns() {
return KernelModelRow2DataTableRowWrapper.NUMBER_OF_SPECIAL_COLUMNS;
}
@Override
public boolean isSpecial(int index) {
return index < KernelModelRow2DataTableRowWrapper.NUMBER_OF_SPECIAL_COLUMNS;
}
@Override
public boolean isNominal(int index) {
if (index == KernelModelRow2DataTableRowWrapper.LABEL) {
return this.kernelModel.isClassificationModel();
} else {
return index == KernelModelRow2DataTableRowWrapper.SUPPORT_VECTOR;
}
}
@Override
public boolean isDate(int index) {
return false;
}
@Override
public boolean isTime(int index) {
return false;
}
@Override
public boolean isDateTime(int index) {
return false;
}
@Override
public boolean isNumerical(int index) {
return !isNominal(index);
}
@Override
public String mapIndex(int column, int value) {
if (column == KernelModelRow2DataTableRowWrapper.LABEL && this.kernelModel.isClassificationModel()) {
return index2LabelMap.get(value);
} else if (column == KernelModelRow2DataTableRowWrapper.SUPPORT_VECTOR) {
if (value == 0) {
return "no support vector";
} else {
return "support vector";
}
} else {
return null;
}
}
@Override
public int mapString(int column, String value) {
if (column == KernelModelRow2DataTableRowWrapper.LABEL && this.kernelModel.isClassificationModel()) {
return label2IndexMap.get(value);
} else if (column == KernelModelRow2DataTableRowWrapper.SUPPORT_VECTOR) {
if ("no support vector".equals(value)) {
return 0;
} else {
return 1;
}
} else {
return -1;
}
}
@Override
public int getNumberOfValues(int column) {
if (column == KernelModelRow2DataTableRowWrapper.LABEL && this.kernelModel.isClassificationModel()) {
return index2LabelMap.size();
} else if (column == KernelModelRow2DataTableRowWrapper.SUPPORT_VECTOR) {
return 2;
} else {
return -1;
}
}
@Override
public String getColumnName(int i) {
if (i < KernelModelRow2DataTableRowWrapper.NUMBER_OF_SPECIAL_COLUMNS) {
return KernelModelRow2DataTableRowWrapper.SPECIAL_COLUMN_NAMES[i];
} else {
int attributeIndex = i - KernelModelRow2DataTableRowWrapper.NUMBER_OF_SPECIAL_COLUMNS;
if (attributeIndex >= 0 && attributeIndex <= attributeNames.length) {
return attributeNames[attributeIndex];
}
return DEFAULT_REGULAR_ATTRIBUTE_PREFIX + (attributeIndex + 1);
}
}
@Override
public int getColumnIndex(String name) {
for (int i = 0; i < KernelModelRow2DataTableRowWrapper.NUMBER_OF_SPECIAL_COLUMNS; i++) {
if (KernelModelRow2DataTableRowWrapper.SPECIAL_COLUMN_NAMES[i].equals(name)) {
return i;
}
}
for (int i = 0; i < attributeNames.length; i++) {
if (attributeNames[i].equals(name)) {
return i;
}
}
if (name.startsWith(DEFAULT_REGULAR_ATTRIBUTE_PREFIX)) {
return Integer.parseInt(name.substring(DEFAULT_REGULAR_ATTRIBUTE_PREFIX.length()))
+ KernelModelRow2DataTableRowWrapper.NUMBER_OF_SPECIAL_COLUMNS - 1;
}
return -1;
}
@Override
public boolean isSupportingColumnWeights() {
return false;
}
@Override
public double getColumnWeight(int column) {
return Double.NaN;
}
@Override
public int getNumberOfColumns() {
return kernelModel.getNumberOfAttributes() + KernelModelRow2DataTableRowWrapper.NUMBER_OF_SPECIAL_COLUMNS;
}
@Override
public int getNumberOfRows() {
if (this.sampleMapping == null) {
return this.kernelModel.getNumberOfSupportVectors();
} else {
return this.sampleMapping.length;
}
}
@Override
public void add(DataTableRow row) {
throw new RuntimeException("DataTableKernelModelAdapter: adding new rows is not supported!");
}
@Override
public DataTableRow getRow(int index) {
if (this.sampleMapping == null) {
return new KernelModelRow2DataTableRowWrapper(this.kernelModel, this, index);
} else {
return new KernelModelRow2DataTableRowWrapper(this.kernelModel, this, this.sampleMapping[index]);
}
}
@Override
public Iterator<DataTableRow> iterator() {
return new KernelModelIterator(this);
}
@Override
public DataTable sample(int newSize) {
DataTableKernelModelAdapter result = new DataTableKernelModelAdapter(this);
double ratio = (double) newSize / (double) getNumberOfRows();
Random random = new Random(2001);
List<Integer> usedRows = new LinkedList<>();
for (int i = 0; i < getNumberOfRows(); i++) {
if (random.nextDouble() <= ratio) {
int index = i;
if (this.sampleMapping != null) {
index = this.sampleMapping[index];
}
usedRows.add(index);
}
}
int[] sampleMapping = new int[usedRows.size()];
int counter = 0;
Iterator<Integer> i = usedRows.iterator();
while (i.hasNext()) {
sampleMapping[counter++] = i.next();
}
result.sampleMapping = sampleMapping;
return result;
}
}