package org.wikibrain.matrix.knn; /* * A min-heap that tracks the n closest neighbors. * Each element in the neighborhood has a score and an id. */ public class NeighborhoodAccumulator { private double[] similarities; private int[] keys; private int size; /** * Create a neighborhood accumulator that holds at most n elements. * @param n */ public NeighborhoodAccumulator(int n) { similarities = new double[n+1]; keys = new int[n+1]; size = 0 ; keys[0] = Integer.MIN_VALUE; similarities[0] = Double.NEGATIVE_INFINITY; } /** * Possibly add a neighbor to the neighborhood. * @param key * @param sim Similarity of the neighbor. */ public final void visit(int key, double sim) { if (size < similarities.length - 1) { insert(key, sim); } else if (sim > similarities[1]) { assert(size == similarities.length - 1); removeMin(); insert(key, sim); } } public Neighborhood get() { int ids[] = new int[size]; double scores[] = new double[size]; for (int i = 1; i <= size; i++) { ids[i - 1] = keys[i]; scores[i - 1] =similarities[i]; } quickSort(ids, scores, 0, ids.length - 1); return new Neighborhood(ids, scores); } private int leftChild(int pos) { return 2*pos; } private int rightChild(int pos) { return 2*pos + 1; } private int parent(int pos) { return pos / 2; } private boolean isLeaf(int pos) { return ((pos > size/2) && (pos <= size)); } private void swap(int pos1, int pos2) { double tmpVal; tmpVal = similarities[pos1]; similarities[pos1] = similarities[pos2]; similarities[pos2] = tmpVal; int tmpKey; tmpKey = keys[pos1]; keys[pos1] = keys[pos2]; keys[pos2] = tmpKey; } private void insert(int key, double value) { assert(size < similarities.length - 1); size++; keys[size] = key; similarities[size] = value; int current = size; while (similarities[current] < similarities[parent(current)]) { swap(current, parent(current)); current = parent(current); } } private int minKey() { return keys[1]; } private double minValue() { return similarities[1]; } private void removeMin() { swap(1,size); size--; if (size != 0) pushDown(1); } private void pushDown(int position) { int smallestChild; while (!isLeaf(position)) { smallestChild = leftChild(position); if ((smallestChild < size) && (similarities[smallestChild] > similarities[smallestChild+1])) smallestChild = smallestChild + 1; if (similarities[position] <= similarities[smallestChild]) return; swap(position,smallestChild); position = smallestChild; } } // Adapted from http://www.programcreek.com/2012/11/quicksort-array-in-java/ private void quickSort(int colIds[], double colVals[], int low, int high) { if (colIds.length == 0 || low >= high) return; // pick the pivot int middle = (low + high) / 2; double pivot = colVals[middle]; // partition around the pivot int i = low, j = high; while (i <= j) { while (colVals[i] > pivot) { i++; } while (colVals[j] < pivot) { j--; } if (i <= j) { int temp = colIds[i]; double tempV = colVals[i]; colIds[i] = colIds[j]; colVals[i] = colVals[j]; colIds[j] = temp; colVals[j] = tempV; i++; j--; } } //recursively sort two sub parts quickSort(colIds, colVals, low, j); quickSort(colIds, colVals, i, high); } }