/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/** A measure of distance between two <CODE>SparseVector</CODE>s based on
their normalized dot product (1 - cosine similarity).
@author Aron Culotta <A HREF="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</A>
*/
package cc.mallet.types;
import java.util.HashMap;
import cc.mallet.types.SparseVector;
/**
	Computes
	<pre>
	1 - [&lt;x,y&gt; / sqrt (&lt;x,x&gt;*&lt;y,y&gt;)]
	</pre>
	aka 1 - cosine similarity.
	<p>
	The four-argument {@code distance} overload caches each vector's self
	dot-product &lt;x,x&gt; under an integer key supplied by the caller, so that
	repeated comparisons involving the same vector do not recompute it.
	The cache is keyed only by that integer: two different vectors passed
	with the same key share one cached value, and a cached value is reused
	as-is even if the vector has changed since it was stored. This class
	never removes entries from the cache.
	<p>
	When the normalizing denominator is exactly zero (for example, a vector
	whose entries are all zero) the cosine is undefined; both overloads then
	return 1.0 rather than NaN, the same value the cached overload returns
	for a null vector.
*/
public class NormalizedDotProductMetric implements CachedMetric {

	// Self dot-products <x,x> used for normalization, keyed by the
	// caller-supplied hash code.
	HashMap<Integer, Double> hash;

	/** Creates a metric with an empty self dot-product cache. */
	public NormalizedDotProductMetric () {
		this.hash = new HashMap<Integer, Double> ();
	}

	/**
		Computes 1 - cosine similarity of the two vectors without using the cache.

		@param a first vector; must not be null
		@param b second vector; must not be null
		@return 1 - &lt;a,b&gt; / (|a| * |b|), or 1.0 if the product of the norms is zero
	*/
	public double distance (SparseVector a, SparseVector b) {
		// gmann : twoNorm() more efficient than a.dotProduct(a)
		return toDistance (a.dotProduct (b), a.twoNorm () * b.twoNorm ());
	}

	/**
		Computes 1 - cosine similarity of the two vectors, reusing cached self
		dot-products where available and caching any that had to be computed.

		@param a first vector; may be null
		@param hashCodeA cache key under which &lt;a,a&gt; is stored
		@param b second vector; may be null
		@param hashCodeB cache key under which &lt;b,b&gt; is stored
		@return 1.0 if either vector is null or the normalizer is zero;
		        otherwise 1 - &lt;a,b&gt; / sqrt (&lt;a,a&gt; * &lt;b,b&gt;)
	*/
	public double distance (SparseVector a, int hashCodeA,
	                        SparseVector b, int hashCodeB) {
		// A missing vector is treated as maximally distant; bail out before
		// touching the cache.
		if (a == null || b == null)
			return 1.0;

		Integer keyA = Integer.valueOf (hashCodeA);
		Integer keyB = Integer.valueOf (hashCodeB);

		// Look up both entries before filling either, so that a miss on both
		// sides computes each vector's own self dot-product even when the two
		// keys are equal.
		Double cachedA = hash.get (keyA);
		Double cachedB = hash.get (keyB);
		if (cachedA == null) {
			cachedA = Double.valueOf (a.dotProduct (a));
			hash.put (keyA, cachedA);
		}
		if (cachedB == null) {
			cachedB = Double.valueOf (b.dotProduct (b));
			hash.put (keyB, cachedB);
		}

		double normalizer = Math.sqrt (cachedA.doubleValue () * cachedB.doubleValue ());
		return toDistance (a.dotProduct (b), normalizer);
	}

	// Turns a dot product and its normalizer into 1 - cosine, mapping the
	// undefined zero-normalizer case to 1.0 instead of NaN.
	private static double toDistance (double dot, double normalizer) {
		if (normalizer == 0.0)
			return 1.0;
		return 1.0 - dot / normalizer;
	}
}