/*******************************************************************************
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package smile.neighbor;
import org.junit.Before;
import org.junit.Test;
import smile.math.distance.HammingDistance;
import smile.sort.HeapSelect;
import java.io.IOException;
import java.lang.reflect.Array;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.LinkedList;
import smile.data.parser.IOUtils;
import static smile.neighbor.SNLSH.simhash64;
/**
* Test data set: http://research.microsoft.com/en-us/downloads/607d14d9-20cd-47e3-85bc-a2f65cd28042/
*
* @author Qiyang Zuo
* @since 03/31/15
*/
public class SNLSHTest {
private class Sentence extends SNLSH.AbstractSentence {
public Sentence(String line) {
this.line = line;
this.tokens = tokenize(line);
}
@Override
List<String> tokenize(String line) {
return tokenize(line, " ");
}
private List<String> tokenize(String line, String regex) {
List<String> tokens = new LinkedList<>();
if (line == null || line.isEmpty()) {
throw new IllegalArgumentException("Line should not be blank!");
}
String[] ss = line.split(regex);
for (String s : ss) {
if (s == null || s.isEmpty()) {
continue;
}
tokens.add(s);
}
return tokens;
}
}
// Toy corpus for the smoke tests: three sentences that share most of their
// words and one sentence that shares none of them.
private String[] texts = {
"This is a test case",
"This is another test case",
"This is another test case too",
"I want to be far from other cases"
};
// Sentences loaded from msrp/msr_paraphrase_test.txt; the recall tests use them as queries.
private List<Sentence> testData;
// Sentences loaded from msrp/msr_paraphrase_train.txt; the recall tests index and scan them.
private List<Sentence> trainData;
// The entries of texts wrapped as Sentence objects.
private List<Sentence> toyData;
// Filled in before() from trainData; read by the linear reference searches.
private Map<String, Long> signCache; // sentence line -> simhash64 signature of its tokens
/**
 * Builds the fixtures shared by all tests: loads the train and test
 * sentences, caches the simhash64 signature of every training sentence
 * keyed by its raw line, and wraps the toy texts as sentences.
 *
 * @throws IOException if a data file cannot be read
 */
@Before
public void before() throws IOException {
    trainData = loadData("msrp/msr_paraphrase_train.txt");
    testData = loadData("msrp/msr_paraphrase_test.txt");

    // Hash each training sentence once so the linear scans can reuse the value.
    signCache = new HashMap<>();
    for (int i = 0, n = trainData.size(); i < n; i++) {
        Sentence current = trainData.get(i);
        long signature = simhash64(current.tokens);
        signCache.put(current.line, signature);
    }

    toyData = new ArrayList<>();
    for (int i = 0; i < texts.length; i++) {
        toyData.add(new Sentence(texts[i]));
    }
}
/**
 * Reads a tab-separated test data file and turns the last two non-empty
 * fields of every row into sentences (last field first).
 * <p>
 * The first two sentences, which come from the first row of the file, are
 * left out of the returned view (presumably a header row; confirm against
 * the data file).
 *
 * @param path test data path handed to {@code IOUtils.getTestDataReader}
 * @return a sub-list view over the collected sentences, starting at index 2
 * @throws IOException if the file cannot be read
 */
private List<Sentence> loadData(String path) throws IOException {
    List<Sentence> sentences = new ArrayList<>();
    List<String> rows = IOUtils.readLines(IOUtils.getTestDataReader(path));
    for (String row : rows) {
        List<String> fields = tokenize(row, "\t");
        int last = fields.size() - 1;
        sentences.add(new Sentence(fields.get(last)));
        sentences.add(new Sentence(fields.get(last - 1)));
    }
    return sentences.subList(2, sentences.size());
}
/**
 * Brute-force k-nearest-neighbour search over {@code trainData}, using the
 * Hamming distance between simhash64 signatures. Training sentences whose
 * line equals the query's line are skipped.
 *
 * @param q the query sentence
 * @param k the number of neighbours requested
 * @return an array of length {@code k}, or of length equal to the number of
 *         accepted candidates when fewer than {@code k} were accepted
 */
private Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] linearKNN(SNLSH.AbstractSentence q, int k) {
    @SuppressWarnings("unchecked")
    Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] selected =
            (Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[]) Array.newInstance(Neighbor.class, k);
    HeapSelect<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>> selector = new HeapSelect<>(selected);

    // Pre-fill the selector with k copies of one placeholder at maximal distance,
    // so peek() always has something to compare against.
    Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> placeholder =
            new Neighbor<>(null, null, 0, Double.MAX_VALUE);
    for (int filled = 0; filled < k; filled++) {
        selector.add(placeholder);
    }

    long querySign = simhash64(q.tokens);
    int accepted = 0;
    for (Sentence sentence : trainData) {
        if (!sentence.line.equals(q.line)) {
            long candidateSign = signCache.get(sentence.line);
            double distance = HammingDistance.d(querySign, candidateSign);
            if (distance < selector.peek().distance) {
                selector.add(new Neighbor<>(sentence, sentence, 0, distance));
                accepted++;
            }
        }
    }
    selector.sort();

    if (accepted >= k) {
        return selected;
    }

    // Fewer than k candidates were accepted: keep only the last `accepted` slots.
    // NOTE(review): this assumes the placeholders occupy the leading slots after
    // sort() -- confirm against HeapSelect.sort.
    @SuppressWarnings("unchecked")
    Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] trimmed =
            (Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[]) Array.newInstance(Neighbor.class, accepted);
    System.arraycopy(selected, k - accepted, trimmed, 0, accepted);
    return trimmed;
}
/**
 * Brute-force nearest-neighbour search over {@code trainData}, using the
 * Hamming distance between simhash64 signatures. Training sentences whose
 * line equals the query's line are skipped; on ties the earliest sentence
 * in {@code trainData} wins.
 *
 * @param q the query sentence
 * @return the closest sentence as both key and value; key and value are
 *         null and the distance is {@code Double.MAX_VALUE} when no
 *         candidate was considered
 */
private Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> linearNearest(SNLSH.AbstractSentence q) {
    long querySign = simhash64(q.tokens);
    Sentence minKey = null;
    double minDist = Double.MAX_VALUE;
    for (int i = 0, n = trainData.size(); i < n; i++) {
        Sentence candidate = trainData.get(i);
        if (!candidate.line.equals(q.line)) {
            long candidateSign = signCache.get(candidate.line);
            double distance = HammingDistance.d(querySign, candidateSign);
            if (distance < minDist) {
                minKey = candidate;
                minDist = distance;
            }
        }
    }
    return new Neighbor<>(minKey, minKey, 0, minDist);
}
/**
 * Brute-force range search over {@code trainData}: appends every training
 * sentence whose simhash64 signature is within Hamming distance {@code d}
 * (inclusive) of the query's. Training sentences whose line equals the
 * query's line are skipped.
 *
 * @param q         the query sentence
 * @param d         the inclusive distance threshold
 * @param neighbors output list that receives the matches
 */
private void linearRange(Sentence q, double d, List<Neighbor<SNLSH.AbstractSentence,SNLSH.AbstractSentence>> neighbors) {
    long querySign = simhash64(q.tokens);
    for (int i = 0, n = trainData.size(); i < n; i++) {
        Sentence sentence = trainData.get(i);
        if (!sentence.line.equals(q.line)) {
            long candidateSign = signCache.get(sentence.line);
            double distance = HammingDistance.d(querySign, candidateSign);
            if (distance <= d) {
                neighbors.add(new Neighbor<>(sentence, sentence, 0, distance));
            }
        }
    }
}
/**
 * Smoke test: indexes the toy sentences, asks for the 10 nearest
 * neighbours of the first toy text and prints each result. It makes no
 * assertions.
 */
@Test
public void testKNN() {
    SNLSH<SNLSH.AbstractSentence> index = createLSH(toyData);
    SNLSH.AbstractSentence query = new Sentence(texts[0]);
    Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] found = index.knn(query, 10);

    System.out.println("-----test knn: ------");
    int rank = 0;
    for (Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> hit : found) {
        System.out.println("neighbor" + rank + " : " + hit.key.line + ". distance: " + hit.distance);
        rank++;
    }
    System.out.println("------test knn end------");
}
/**
 * Prints the recall of {@code SNLSH.knn} (k = 3) against the brute-force
 * {@code linearKNN}, averaged over all test sentences. A returned neighbour
 * counts as a hit when its value equals the value of any brute-force
 * neighbour. It makes no assertions.
 */
@Test
public void testKNNRecall() {
    final int k = 3;
    SNLSH<SNLSH.AbstractSentence> index = createLSH(trainData);
    double recall = 0.0;
    for (SNLSH.AbstractSentence query : testData) {
        Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] approx = index.knn(query, k);
        Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>[] exact = linearKNN(query, k);

        int matched = 0;
        // Both scans stop at the first null slot, if there is one.
        for (int a = 0; a < approx.length && approx[a] != null; a++) {
            boolean found = false;
            for (int e = 0; !found && e < exact.length && exact[e] != null; e++) {
                found = approx[a].value.equals(exact[e].value);
            }
            if (found) {
                matched++;
            }
        }
        recall += 1.0 * matched / k;
    }
    recall = recall / testData.size();
    System.out.println("SNLSH KNN recall is " + recall);
}
/**
 * Smoke test: indexes the toy sentences and prints the nearest neighbour
 * of the first toy text. It makes no assertions.
 */
@Test
public void testNearest() {
    SNLSH<SNLSH.AbstractSentence> index = createLSH(toyData);
    System.out.println("----------test nearest start:-------");
    SNLSH.AbstractSentence query = new Sentence(texts[0]);
    Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> closest = index.nearest(query);
    System.out.println("neighbor" + " : " + closest.key.line + " distance: " + closest.distance);
    System.out.println("----------test nearest end-------");
}
/**
 * Prints the fraction of test sentences for which {@code SNLSH.nearest}
 * returns the same value as the brute-force {@code linearNearest}. It
 * makes no assertions.
 */
@Test
public void testNearestRecall() {
    SNLSH<SNLSH.AbstractSentence> index = createLSH(trainData);
    int agreements = 0;
    for (SNLSH.AbstractSentence query : testData) {
        Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> approx = index.nearest(query);
        Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> exact = linearNearest(query);
        if (approx.value.equals(exact.value)) {
            agreements++;
        }
    }
    double recall = (double) agreements / testData.size();
    System.out.println("SNLSH Nearest recall is " + recall);
}
/**
 * Smoke test: indexes the toy sentences, runs a range query with radius 10
 * around the first toy text and prints every result. It makes no
 * assertions.
 */
@Test
public void testRange() {
    SNLSH<SNLSH.AbstractSentence> index = createLSH(toyData);
    List<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>> results = new ArrayList<>();
    index.range(new Sentence(texts[0]), 10, results);

    System.out.println("-------test range begin-------");
    for (int i = 0; i < results.size(); i++) {
        Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> hit = results.get(i);
        System.out.println(hit.key.line + " distance: " + hit.distance);
    }
    System.out.println("-----test range end ----------");
}
/**
 * Prints the recall of {@code SNLSH.range} (radius 15) against the
 * brute-force {@code linearRange}, averaged over all test sentences.
 * Queries for which the brute-force search finds nothing add zero to the
 * sum but still count in the denominator. It makes no assertions.
 */
@Test
public void testRangeRecall() {
    final double radius = 15.0;
    SNLSH<SNLSH.AbstractSentence> index = createLSH(trainData);
    double recall = 0.0;
    for (Sentence query : testData) {
        List<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>> approx = new ArrayList<>();
        index.range(query, radius, approx);
        List<Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence>> exact = new ArrayList<>();
        linearRange(query, radius, exact);

        int matched = 0;
        for (Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> candidate : approx) {
            for (Neighbor<SNLSH.AbstractSentence, SNLSH.AbstractSentence> reference : exact) {
                if (candidate.value.equals(reference.value)) {
                    matched++;
                    break;
                }
            }
        }
        if (!exact.isEmpty()) {
            recall += 1.0 * matched / exact.size();
        }
    }
    recall = recall / testData.size();
    System.out.println("SNLSH range recall is " + recall);
}
/**
 * Creates an {@code SNLSH} with constructor argument 8 and inserts every
 * given sentence, using the sentence as both key and value.
 *
 * @param data the sentences to index
 * @return the populated index
 */
private SNLSH<SNLSH.AbstractSentence> createLSH(List<Sentence> data) {
    SNLSH<SNLSH.AbstractSentence> index = new SNLSH<>(8);
    for (Sentence entry : data) {
        index.put(entry, entry);
    }
    return index;
}
/**
 * Splits {@code line} around matches of {@code regex} and returns the
 * non-empty pieces in their original order.
 *
 * @param line  the text to split; must be neither null nor empty
 * @param regex the delimiter, interpreted as a regular expression by
 *              {@link String#split(String)}
 * @return the non-empty tokens; empty when the line consists only of delimiters
 * @throws IllegalArgumentException if {@code line} is null or empty
 */
private List<String> tokenize(String line, String regex) {
    if (line == null || line.isEmpty()) {
        throw new IllegalArgumentException("Line should not be blank!");
    }
    // ArrayList rather than LinkedList: callers read fields by index.
    List<String> tokens = new ArrayList<>();
    // String.split never produces null elements, so only empties are filtered.
    for (String piece : line.split(regex)) {
        if (!piece.isEmpty()) {
            tokens.add(piece);
        }
    }
    return tokens;
}
}