package org.wikibrain.matrix.knn; import gnu.trove.list.TIntList; import gnu.trove.list.array.TIntArrayList; import gnu.trove.set.TIntSet; import gnu.trove.set.hash.TIntHashSet; import org.wikibrain.matrix.DenseMatrix; import org.wikibrain.matrix.DenseMatrixRow; import java.io.File; import java.io.IOException; import java.util.*; /** * A fast neighborhood finder for dense vectors. * * @author Shilad Sen */ public class KDTreeKNN implements KNNFinder { private final DenseMatrix matrix; private final int[] allIds; private final int dimensions; private int maxSampleSize = 5000; private int maxLeaf = 100; List<float []> centroids; List<int []> members; public KDTreeKNN(DenseMatrix matrix) throws IOException { this.matrix = matrix; this.allIds = matrix.getRowIds(); this.dimensions = matrix.getRow(allIds[0]).getNumCols(); } @Override public void build() throws IOException { Node root = new Node("R"); // shuffle ids to ensure random partition intialization forever root.memberIds = new int[allIds.length]; System.arraycopy(allIds, 0, root.memberIds, 0, allIds.length); shuffle(root.memberIds); centroids = new ArrayList<float[]>(); members = new ArrayList<int[]>(); build(root); } private void build(Node node) throws IOException { if (node.memberIds.length < maxLeaf) { centroids.add(node.centroid); members.add(node.memberIds); return; } double [] laccum = new double[dimensions]; double [] raccum = new double[dimensions]; node.left = new Node(node.path + "L"); node.right = new Node(node.path + "R"); node.left.centroid = new float[dimensions]; node.right.centroid = new float[dimensions]; int n = Math.min(node.memberIds.length, maxSampleSize); // Calculate centroids int lcount = 0; int rcount = 0; for (int iter = 0; iter < 5; iter++) { lcount = 0; rcount = 0; Arrays.fill(laccum, 0.0); Arrays.fill(raccum, 0.0); double obj = 0.0; for (int m = 0; m < n; m++) { DenseMatrixRow row = matrix.getRow(node.memberIds[m]); double lsim; double rsim; if (iter == 0) { lsim = (m < n/2) ? 1.0 : 0.0; rsim = 1.0 - lsim; } else { lsim = row.dot(node.left.centroid); rsim = row.dot(node.right.centroid); } if (lsim >= rsim) { for (int j = 0; j < dimensions; j++) { laccum[j] += row.getColValue(j); } lcount++; } else { for (int j = 0; j < dimensions; j++) { raccum[j] += row.getColValue(j); } rcount++; } obj += Math.max(lsim, rsim); } obj = (iter == 0) ? 0.0 : obj / n; // System.out.format("Node %s iter=%d obj=%.3f left-size=%d right-size=%d\n", // node.path, iter, obj, lcount, rcount); normalize(laccum); normalize(raccum); for (int i = 0; i < dimensions; i++) node.left.centroid[i] = (float) laccum[i]; for (int i = 0; i < dimensions; i++) node.right.centroid[i] = (float) raccum[i]; } // Final placement TIntList leftIds = new TIntArrayList(); TIntList rightIds = new TIntArrayList(); for (int id : node.memberIds) { DenseMatrixRow row = matrix.getRow(id); double lsim = row.dot(node.left.centroid); double rsim = row.dot(node.right.centroid); if (lsim >= rsim) { leftIds.add(id); } else { rightIds.add(id); } } node.left.memberIds = leftIds.toArray(); node.right.memberIds = rightIds.toArray(); if (node.left.memberIds.length + node.right.memberIds.length != node.memberIds.length) { throw new IllegalStateException(); } // Recurse build(node.left); build(node.right); } private static class Candidate implements Comparable<Candidate> { final int clusterNum; final double score; public Candidate(int clusterNum, double score) { this.clusterNum = clusterNum; this.score = score; } @Override public int compareTo(Candidate o) { return Double.compare(score, o.score); } } @Override public Neighborhood query(float[] vector, int k, int maxTraversal, TIntSet validIds) { TreeSet<Candidate> clusters = new TreeSet<Candidate>(); for (int i = 0; i < centroids.size(); i++) { clusters.add(new Candidate(i, dot(centroids.get(i), vector))); } NeighborhoodAccumulator accum = new NeighborhoodAccumulator(k); int traversed = 0; while (!clusters.isEmpty()) { int clusterNum = clusters.pollLast().clusterNum; for (int rowId : members.get(clusterNum)) { if (validIds != null && !validIds.contains(rowId)) continue; DenseMatrixRow row = null; try { row = matrix.getRow(rowId); } catch (IOException e) { throw new IllegalStateException(e); } double sim = cosine(vector, row); accum.visit(row.getRowIndex(), sim); traversed++; } if (traversed >= maxTraversal) { break; } } return accum.get(); } @Override public void save(File path) throws IOException { throw new UnsupportedOperationException(); } @Override public boolean load(File path) throws IOException { throw new UnsupportedOperationException(); } public void setMaxSampleSize(int sampleSize) { this.maxSampleSize = sampleSize; } public void setMaxLeaf(int maxLeaf) { this.maxLeaf = maxLeaf; } static class Node { String path; float [] centroid; Node left; Node right; int [] memberIds; public Node(String path) { this.path = path; } } static double cosine(DenseMatrixRow X, DenseMatrixRow Y) { if (X == null || Y == null) { return 0; } return X.dot(Y); } static double cosine(float [] X, DenseMatrixRow Y) { if (X == null || Y == null) { return 0; } return Y.dot(X); } private static void shuffle(int [] array) { Random rand = new Random(); for (int i = array.length - 1; i > 0; i--) { int index = rand.nextInt(i + 1); // Simple swap int a = array[index]; array[index] = array[i]; array[i] = a; } } private double dot(float [] v1, float [] v2) { double sum = 0.0; for (int i = 0; i < v1.length; i++) { sum += v1[i] * v2[i]; } return sum; } private static void normalize(double [] X) { double norm = 0.0; for (int i = 0; i < X.length; i++) norm += X[i] * X[i]; norm = Math.sqrt(norm) + 0.00001; for (int i = 0; i < X.length; i++) X[i] /= norm; } }