/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * RBFKernel.java * Copyright (C) 1999 Eibe Frank * */ package weka.classifiers.sparse; import weka.core.*; /** * The RBF kernel. * K(x, y) = e^-(gamma * <x-y, x-y>^2) * * @author Eibe Frank (eibe@cs.waikato.ac.nz) * @author Shane Legg (shane@intelligenesis.net) (sparse vector code) * @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code) * @author J. Lindgren (jtlindgr{at}cs.helsinki.fi) (RBF kernel) * @version $$ */ public class RBFKernel extends Kernel { /** The precalculated dotproducts of <inst_i,inst_i> */ private double m_kernelPrecalc[]; /** Counts the number of kernel evaluations. */ private int m_kernelEvals = 0; /** The size of the cache (a prime number) */ private int m_cacheSize; /** Kernel cache */ private double[] m_storage; private long[] m_keys; /** Gamma for the RBF kernel. */ private double m_gamma = 0.01; /** The number of instance in the dataset */ private int m_numInsts; /** * Constructor. Initializes m_kernelPrecalc[]. */ public RBFKernel(Instances data, int cacheSize, double gamma) throws Exception { m_gamma = gamma; m_data = data; m_numInsts = m_data.numInstances(); m_cacheSize = cacheSize; m_storage = new double[m_cacheSize]; m_keys = new long[m_cacheSize]; m_kernelPrecalc=new double[data.numInstances()]; for(int i=0;i<data.numInstances();i++) m_kernelPrecalc[i]=dotProd(data.instance(i),data.instance(i)); } /** * Implements the abstract function of Kernel. */ public double eval(int id1, int id2, Instance inst1) throws Exception { double result = 0; long key = -1; int location = -1; // we can only cache if we know the indexes if (id1 >= 0) { if (id1 > id2) { key = (long)id1 * m_numInsts + id2; } else { key = (long)id2 * m_numInsts + id1; } if (key < 0) { throw new Exception("Cache overflow detected!"); } location = (int)(key % m_keys.length); if (m_keys[location] == (key + 1)) { return m_storage[location]; } } Instance inst2 = m_data.instance(id2); double precalc1; if(id1 == -1) precalc1 = dotProd(inst1,inst1); else precalc1 = m_kernelPrecalc[id1]; result = Math.exp(m_gamma * (2. * dotProd(inst1, inst2) - precalc1 - m_kernelPrecalc[id2])); m_kernelEvals++; // store result in cache if (key != -1){ m_storage[location] = result; m_keys[location] = (key + 1); } return result; } /** * Calculates a dot product between two instances * * @param inst1 the first instance * @param inst2 the second instance * @return the dot product of the two instances. * @exception Exception if an error occurs */ private double dotProd(Instance inst1, Instance inst2) throws Exception { double result=0; // we can do a fast dot product int n1 = inst1.numValues(); int n2 = inst2.numValues(); int classIndex = m_data.classIndex(); for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) { int ind1 = inst1.index(p1); int ind2 = inst2.index(p2); if (ind1 == ind2) { if (ind1 != classIndex) { result += inst1.valueSparse(p1) * inst2.valueSparse(p2); } p1++; p2++; } else if (ind1 > ind2) { p2++; } else { p1++; } } return(result); } /** * Frees the cache used by the kernel. */ public void clean(){ m_storage = null; m_keys = null; } /** * Returns the number of time Eval has been called. * * @return the number of kernel evaluation. */ public int numEvals(){ return m_kernelEvals; } }