/*
* RapidMiner
*
* Copyright (C) 2001-2008 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.functions.kernel.jmysvm.kernel;
import java.io.Serializable;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExample;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.examples.SVMExamples;
import com.rapidminer.operator.learner.functions.kernel.jmysvm.util.Cache;
/**
* Abstract base class for all kernels.
*
* @author Stefan Rueping, Ingo Mierswa
* @version $Id: Kernel.java,v 1.5 2008/07/31 17:43:41 ingomierswa Exp $
*/
public abstract class Kernel implements Serializable {

    private static final long serialVersionUID = 6086202515099260920L;

    /** Number of bytes in one megabyte; a long so the budget arithmetic cannot overflow. */
    private static final long BYTES_PER_MB = 1048576L;

    /**
     * Bytes budgeted per cached kernel value when sizing the cache.
     * NOTE(review): the cached rows are double[] (8 bytes per value), so the
     * real footprint is roughly twice the configured budget. The historic
     * divisor of 4 is kept so that existing cache sizes do not change.
     */
    private static final long BYTES_PER_ENTRY = 4L;

    /**
     * Container for the examples, parameters etc.
     */
    protected SVMExamples the_examples;

    /**
     * dimension of the examples
     */
    protected int dim;

    /**
     * Kernel cache. Transient, i.e. not serialized; it is rebuilt lazily from
     * kernel_cache_size and examples_total when it is needed again.
     */
    protected transient Cache kernel_cache;

    /**
     * Number of elements (kernel rows) in cache
     */
    protected int kernel_cache_size;

    /**
     * Size of cache in MB
     */
    protected int cache_MB;

    /**
     * number of examples after shrinking
     */
    protected int examples_total;

    /**
     * Class constructor
     */
    public Kernel() {
    }

    /**
     * Output as String
     */
    public String toString() {
        return "abstract kernel class";
    }

    /**
     * Init the kernel: stores the example container, reads the example count
     * and dimension from it and sets up the kernel cache.
     *
     * @param examples
     *            Container for the examples.
     * @param cacheSizeMB
     *            Size of the kernel cache in MB.
     */
    public void init(SVMExamples examples, int cacheSizeMB) {
        the_examples = examples;
        examples_total = the_examples.count_examples();
        dim = the_examples.get_dim();
        init_kernel_cache(cacheSizeMB);
    }

    /**
     * Calculates kernel value of vectors x and y. Both vectors are given in
     * sparse form as parallel arrays of attribute indices and values.
     */
    public abstract double calculate_K(int[] x_index, double[] x_att, int[] y_index, double[] y_att);

    /**
     * Calculates the inner product of two sparse vectors. Only positions whose
     * index occurs in both vectors contribute. The merge walk runs from the end
     * of both arrays and assumes the indices are sorted in ascending order.
     */
    public double innerproduct(int[] x_index, double[] x_att, int[] y_index, double[] y_att) {
        double result = 0;
        int xpos = x_index.length - 1;
        int ypos = y_index.length - 1;
        while ((xpos >= 0) && (ypos >= 0)) {
            if (x_index[xpos] == y_index[ypos]) {
                result += x_att[xpos] * y_att[ypos];
                xpos--;
                ypos--;
            } else if (x_index[xpos] > y_index[ypos]) {
                xpos--;
            } else {
                ypos--;
            }
        }
        return result;
    }

    /**
     * Calculates ||x-y||^2 for two sparse vectors. An index present in only one
     * vector contributes the square of that single value. The merge walk runs
     * from the end of both arrays and assumes ascending index order.
     */
    public double norm2(int[] x_index, double[] x_att, int[] y_index, double[] y_att) {
        double result = 0;
        double tmp;
        int xpos = x_index.length - 1;
        int ypos = y_index.length - 1;
        while ((xpos >= 0) && (ypos >= 0)) {
            if (x_index[xpos] == y_index[ypos]) {
                tmp = x_att[xpos] - y_att[ypos];
                result += tmp * tmp;
                xpos--;
                ypos--;
            } else if (x_index[xpos] > y_index[ypos]) {
                tmp = x_att[xpos];
                result += tmp * tmp;
                xpos--;
            } else {
                tmp = y_att[ypos];
                result += tmp * tmp;
                ypos--;
            }
        }
        // leftovers of whichever vector has not been consumed completely
        while (xpos >= 0) {
            tmp = x_att[xpos];
            result += tmp * tmp;
            xpos--;
        }
        while (ypos >= 0) {
            tmp = y_att[ypos];
            result += tmp * tmp;
            ypos--;
        }
        return result;
    }

    /**
     * Gets a kernel row, i.e. the kernel values of example i against the first
     * examples_total examples. The row is taken from the cache if present;
     * otherwise it is computed and put into the cache.
     *
     * @param i
     *            index of the example
     * @return the kernel row; the array belongs to the cache and may be
     *         recycled for another row by a later call
     */
    public double[] get_row(int i) {
        ensureCache();
        double[] result = (double[]) kernel_cache.get_element(i);
        if (result == null) {
            // get last cache element, don't assign new memory
            result = (double[]) kernel_cache.get_lru_element();
            if (result == null) {
                result = new double[examples_total];
            }
            calculate_K_row(result, i);
            kernel_cache.put_element(i, result);
        }
        return result;
    }

    /**
     * Inits the kernel cache. The number of cached rows is derived from the
     * budget and the current row length (examples_total), is at least 1 and at
     * most the number of examples in the container.
     *
     * @param size
     *            of the cache in MB
     */
    public void init_kernel_cache(int size) {
        cache_MB = size;
        kernel_cache_size = cacheRows(size, examples_total, the_examples.count_examples());
        kernel_cache = new Cache(kernel_cache_size, examples_total);
    }

    /**
     * Returns the configured size of the cache in MB.
     */
    public int getCacheSize() {
        return cache_MB;
    }

    /**
     * Sets the number of examples to new value. The cache is shrunk when the
     * number decreases, re-initialised when it increases and left untouched
     * when it stays the same.
     *
     * @param new_examples_total
     *            new number of examples
     */
    public void set_examples_size(int new_examples_total) {
        // number of rows that fit into cache:
        int new_kernel_cache_size = cacheRows(cache_MB, new_examples_total, new_examples_total);
        if (new_examples_total < examples_total) {
            // keep cache
            ensureCache();
            kernel_cache.shrink(new_kernel_cache_size, new_examples_total);
        } else if (new_examples_total > examples_total) {
            ensureCache();
            kernel_cache.init(new_kernel_cache_size);
        }
        kernel_cache_size = new_kernel_cache_size;
        examples_total = new_examples_total;
    }

    /**
     * Calculate K(i,j) for the examples at positions i and j of the container.
     */
    public double calculate_K(int i, int j) {
        return calculate_K(the_examples.index[i], the_examples.atts[i],
                the_examples.index[j], the_examples.atts[j]);
    }

    /**
     * Calculate K(x,y) for two explicitly given examples.
     */
    public double calculate_K(SVMExample x, SVMExample y) {
        return calculate_K(x.index, x.att, y.index, y.att);
    }

    /**
     * Fills result[k] with K(i,k) for all k below examples_total.
     *
     * @param result
     *            array receiving the row; needs at least examples_total entries
     * @param i
     *            index of the example the row is calculated for
     * @return the array passed in as result
     */
    public double[] calculate_K_row(double[] result, int i) {
        int[] x_index = the_examples.index[i];
        double[] x_att = the_examples.atts[i];
        for (int k = 0; k < examples_total; k++) {
            result[k] = calculate_K(x_index, x_att, the_examples.index[k], the_examples.atts[k]);
        }
        return result;
    }

    /**
     * swap two training examples
     *
     * @param pos1
     *            position of the first example
     * @param pos2
     *            position of the second example
     */
    public void swap(int pos1, int pos2) {
        // called after container swap
        ensureCache();
        kernel_cache.swap(pos1, pos2);
    }

    /**
     * Recreates the cache if it is missing. The field is transient, so it is
     * null on a deserialized kernel; the serialized kernel_cache_size and
     * examples_total are used to build an empty replacement.
     */
    private void ensureCache() {
        if (kernel_cache == null) {
            kernel_cache = new Cache(kernel_cache_size, examples_total);
        }
    }

    /**
     * Number of kernel rows that fit into the cache budget.
     *
     * @param cacheMB
     *            cache budget in MB
     * @param rowLength
     *            number of values per row
     * @param maxRows
     *            upper limit for the result
     * @return budget / row size, raised to at least 1 and then capped at maxRows
     */
    private static int cacheRows(int cacheMB, int rowLength, int maxRows) {
        long rows = 0;
        if (rowLength != 0) {
            // long arithmetic: cacheMB * 1048576 overflows an int from 2048 MB on
            rows = cacheMB * BYTES_PER_MB / BYTES_PER_ENTRY / rowLength;
        }
        if (rows < 1) {
            rows = 1;
        }
        if (rows > maxRows) {
            rows = maxRows;
        }
        // after the two clamps the value is always within int range
        return (int) rows;
    }
}