package edu.uncc.cs.watsonsim.search;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.fusesource.lmdbjni.BufferCursor;
import org.fusesource.lmdbjni.Database;
import org.fusesource.lmdbjni.Entry;
import org.fusesource.lmdbjni.Env;
import org.fusesource.lmdbjni.Transaction;
import static org.fusesource.lmdbjni.Constants.*;
import edu.uncc.cs.watsonsim.Environment;
import edu.uncc.cs.watsonsim.KV;
import edu.uncc.cs.watsonsim.Passage;
import edu.uncc.cs.watsonsim.Phrase;
import edu.uncc.cs.watsonsim.Question;
import edu.uncc.cs.watsonsim.nlp.DenseVectors;
/**
 * Searcher that ranks stored document vectors against the mean dense vector
 * of a question's tokens, using a full linear scan of an LMDB database.
 *
 * <p>Vectors are read from the "wiki-vectors" database inside the LMDB
 * environment at {@code data/wiki-vectors.lmdb}. The keys of the best
 * {@link #K} entries are wrapped as "meandv" passages and handed to
 * {@code fillFromSources}.
 *
 * <p>NOTE(review): every instance opens its own {@link Env} in the
 * constructor and nothing in this class ever closes it. LMDB documents
 * caveats about opening one environment several times in a process, so
 * confirm how many instances of this searcher are created.
 */
public class MeanDVSearch extends Searcher {
    /** How many results to return. */
    public static final int K = 20;
    /** Dimensions in a dense vector. */
    public static final int N = DenseVectors.N;
    /** Length of the working arrays: K ranked slots plus one spare slot (see {@link #bubble}). */
    public static final int LEN = K+1;

    /** Path of the LMDB file holding the document vectors (opened with NOSUBDIR). */
    private static final String WIKI_VECTORS_LOCATION = "data/wiki-vectors.lmdb";
    /** Name of the database inside the environment. */
    private static final String WIKI_VECTORS_DB = "wiki-vectors";
    /** Engine name attached to every passage produced here. */
    private static final String ENGINE = "meandv";

    private final Env wiki_vectors_env = new Env();

    /**
     * Creates the searcher and opens the LMDB environment.
     *
     * @param env the shared environment, passed straight to the superclass
     */
    public MeanDVSearch(Environment env) {
        super(env);
        wiki_vectors_env.open(WIKI_VECTORS_LOCATION, NOSUBDIR);
    }

    /**
     * Offers one candidate to a running top-K list, for tiny K (20) and a
     * huge number of candidates (like 50 million).
     *
     * <p>This is one insertion step in the style of bubble/insertion sort:
     * nearly every candidate is expected to be worse than the current worst
     * entry, in which case the loop body never runs. The arrays are one slot
     * longer than K so that a losing candidate (or the entry pushed off the
     * end) simply lands in slot K, with no special case at the end.
     *
     * <p>Slots {@code 0..K-1} are assumed to already be in descending order
     * of similarity; a freshly allocated (all-zero) array satisfies this.
     * Ties keep the earlier entry ahead, because the comparison is strict.
     *
     * @param sims     similarities, parallel to {@code names}; needs at least K+1 slots
     * @param names    keys of the ranked entries; needs at least K+1 slots
     * @param this_sim similarity of the candidate
     * @param name     key of the candidate; stored by reference, not copied
     * @param K        number of ranked slots
     */
    public static void bubble(double[] sims, byte[][] names, double this_sim, byte[] name, int K) {
        int i = K-1;
        // Walk up from the worst ranked slot, shifting weaker entries down one
        // place (similarity and name together) until the candidate's spot is found.
        for (; i>=0 && this_sim > sims[i]; i--) {
            sims[i+1] = sims[i];
            names[i+1] = names[i];
        }
        // i now indexes the first entry that is at least as good (or is -1).
        sims[i+1] = this_sim;
        names[i+1] = name;
    }

    /**
     * Linear search for the best {@link #K} documents, ranked by
     * {@code DenseVectors.sim} against the question's mean token vector.
     * Be warned: this visits every entry in the database and will be slow.
     *
     * <p>Because the score array starts at 0.0 and {@link #bubble} compares
     * strictly, only entries with a finite similarity greater than zero can
     * be returned; fewer than K passages come back when fewer qualify.
     *
     * @param question the question to search for
     * @return the result of {@code fillFromSources} on the winning passages
     */
    public List<Passage> query(Question question) {
        float[] query_vector = meanVector(question);
        // Parallel arrays, one spare slot each, updated in place by bubble().
        byte[][] winners = new byte[LEN][];
        double[] sims = new double[LEN];
        scan(query_vector, sims, winners);
        return fillFromSources(toPassages(sims, winners));
    }

    /** Averages the dense vectors of those question tokens that have one. */
    private static float[] meanVector(Question question) {
        return DenseVectors.mean(
                question.memo(Phrase.simpleTokens)
                        .stream().map(DenseVectors::vectorFor)
                        .filter(v -> v.isPresent())
                        .map(v -> v.get())
                        .collect(Collectors.toList()));
    }

    /** Scores every stored vector against the query and keeps the best K in the given arrays. */
    private void scan(float[] query_vector, double[] sims, byte[][] winners) {
        try (Transaction tx = wiki_vectors_env.createReadTransaction();
                Database doc_vectors = wiki_vectors_env.openDatabase(tx, WIKI_VECTORS_DB, 0)) {
            // NOTE(review): the iterator returned by iterate(tx) is never closed
            // here; check whether lmdbjni expects it to be closed explicitly.
            for (Entry e : doc_vectors.iterate(tx).iterable()) {
                double this_sim = DenseVectors.sim(query_vector, KV.asVector(e.getValue()));
                // Skip NaN and infinite scores so they cannot enter the ranking.
                if (Double.isFinite(this_sim)) {
                    bubble(sims, winners, this_sim, e.getKey(), K);
                }
            }
        }
    }

    /** Wraps each filled slot among the first K as a passage whose last argument is the decoded key. */
    private static List<Passage> toPassages(double[] sims, byte[][] winners) {
        List<Passage> passages = new ArrayList<>(K);
        for (int i=0; i<K; i++) {
            if (winners[i] == null) {
                // Slot never filled: fewer than K entries qualified.
                continue;
            }
            String id = string(winners[i]);
            passages.add(new Passage(ENGINE, "", "", id));
            // Diagnostic output kept from the original implementation.
            System.out.println("value is : " + id + " sim: " + sims[i]);
        }
        return passages;
    }

    /**
     * Cosine similarity of two vectors:
     * <pre>
     *            A.T * B
     *   -------------------------
     *   sqrt(A.T*A) * sqrt(B.T*B)
     * </pre>
     * Returns 0 when either vector has zero magnitude.
     *
     * <p>Not called anywhere in this class; {@link #query} scores with
     * {@code DenseVectors.sim} instead.
     *
     * @param left  first vector; asserted to have {@link #N} entries
     * @param right second vector; only its first {@code left.length} entries
     *              are read, and its length is not checked
     */
    private static double sim(float[] left, float[] right) {
        assert left.length == N;
        double ab = 0.0, aa = 0.0, bb = 0.0;
        for (int i=0; i<left.length; i++) {
            float l = left[i];
            float r = right[i];
            ab += l * r;
            aa += l * l;
            bb += r * r;
        }
        if (aa == 0.0 || bb == 0.0) {
            return 0;
        }
        return ab / (Math.sqrt(aa) * Math.sqrt(bb));
    }
}