package org.wikibrain.sr;

// NOTE(review): this backport import is retained only because imports are never dropped in review;
// the class below no longer references it (sortDescending uses java.util.Collections explicitly).
import edu.emory.mathcs.backport.java.util.Collections;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import org.apache.commons.collections.iterators.ArrayIterator;

import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;

/**
 * A fixed-capacity list of {@link SRResult} entries (id / score pairs) backed by an array.
 *
 * <p>The list is created with a maximum capacity and can later be shortened with
 * {@link #truncate(int)}. Only the first {@link #numDocs()} entries are considered
 * part of the list by the accessors, sorters and the iterator.
 *
 * <p>This class performs no synchronization.
 */
public class SRResultList implements Iterable<SRResult> {

    /** Backing storage; only indices {@code [0, numDocs)} are live. */
    private SRResult[] results;

    /** Number of live entries in {@link #results}. */
    private int numDocs;

    /** Score for missing documents. */
    private double missingScore;

    /** Rank for missing documents. */
    private int missingRank;

    /**
     * Performance optimization: lazily built cache for {@link #getScoresAsFloat()}.
     * Reset to {@code null} by every mutator of this class so it can never go stale
     * through this class's own API.
     */
    private float[] scores;

    /**
     * Creates a list holding {@code maxNumDocs} default-constructed results.
     *
     * @param maxNumDocs capacity of the list, and its initial {@link #numDocs()}
     */
    public SRResultList(int maxNumDocs) {
        this.results = new SRResult[maxNumDocs];
        for (int i = 0; i < this.results.length; i++) {
            results[i] = new SRResult();
        }
        numDocs = maxNumDocs;
    }

    /**
     * Returns the score of the last live entry, or 0.0 if the list is empty.
     * NOTE(review): this is only the minimum if the list is sorted in descending order.
     *
     * @return score at index {@code numDocs - 1}, or 0.0 for an empty list
     */
    public double minScore() {
        return (numDocs == 0) ? 0.0 : this.results[numDocs - 1].getScore();
    }

    /**
     * Returns the score of the first entry, or 0.0 if the list is empty.
     * NOTE(review): this is only the maximum if the list is sorted in descending order.
     *
     * @return score at index 0, or 0.0 for an empty list
     */
    public double maxScore() {
        return (numDocs == 0) ? 0.0 : this.results[0].getScore();
    }

    /**
     * Returns the number of docs in this list.
     * Unless a call to truncate has been made, this will be
     * the max number of documents specified in the constructor.
     *
     * @return the number of live entries
     */
    public int numDocs() {
        return numDocs;
    }

    /**
     * Truncates the list to the specified size.
     *
     * @param numDocs new number of live entries; must be between 0 and the capacity
     *                given to the constructor (checked by assertion only)
     */
    public void truncate(int numDocs) {
        assert (numDocs >= 0 && numDocs <= results.length);
        this.numDocs = numDocs;
        invalidateScoreCache(); // cached array length depends on numDocs
    }

    /**
     * Returns the index of the specified ID, or -1 if not found.
     *
     * @param id the id to look for
     * @return index of the first live entry with that id, or -1
     */
    public int getIndexForId(int id) {
        for (int i = 0; i < numDocs(); i++) {
            if (results[i].id == id) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Returns the score for the specified ID, or Double.NaN if not found.
     *
     * @param id the id to look for
     * @return score of the first live entry with that id, or {@code Double.NaN}
     */
    public double getScoreForId(int id) {
        int index = getIndexForId(id);
        return (index < 0) ? Double.NaN : results[index].getScore();
    }

    /**
     * Returns the ID of the specified index.
     *
     * @param i index of the entry; must be less than {@link #numDocs()} (checked by assertion only)
     * @return the id stored at that index
     */
    public int getId(int i) {
        assert (i < numDocs);
        return results[i].id;
    }

    /**
     * Sets the ID of the entry at the specified index, leaving its score untouched.
     *
     * @param i  index of the entry
     * @param id new id
     */
    public void setId(int i, int id) {
        results[i].id = id;
    }

    /**
     * Returns an array of the IDs in this list.
     *
     * @return a newly allocated array of length {@link #numDocs()}
     */
    public int[] getIds() {
        int[] ids = new int[numDocs];
        for (int i = 0; i < numDocs; i++) {
            ids[i] = results[i].id;
        }
        return ids;
    }

    /**
     * Returns the score of the specified index.
     *
     * @param i index of the entry; must be less than {@link #numDocs()} (checked by assertion only)
     * @return the score stored at that index
     */
    public double getScore(int i) {
        assert (i < numDocs);
        return results[i].getScore();
    }

    /**
     * Returns an array of scores in this list.
     *
     * @return a newly allocated array of length {@link #numDocs()}
     */
    public double[] getScores() {
        double[] out = new double[numDocs];
        for (int i = 0; i < numDocs; i++) {
            out[i] = results[i].score;
        }
        return out;
    }

    /**
     * Returns an array of scores in this list as float values.
     *
     * <p>The array is built on first use and cached; the same instance is returned on
     * subsequent calls until this list is modified through one of its mutators.
     * Callers must treat the returned array as read-only. Changes made directly to an
     * {@link SRResult} obtained from {@link #get(int)} or the iterator are not detected.
     *
     * @return array of length {@link #numDocs()} with each score narrowed to float
     */
    public float[] getScoresAsFloat() {
        if (scores == null) {
            scores = new float[numDocs];
            for (int i = 0; i < numDocs; i++) {
                scores[i] = (float) results[i].getScore();
            }
        }
        return scores;
    }

    /**
     * Sets the ID and score of the SRResult at the index.
     * Note that this does not affect that result's explanations.
     *
     * @param i     index of the entry; must be less than {@link #numDocs()} (checked by assertion only)
     * @param id    new id
     * @param score new score
     */
    public void set(int i, int id, double score) {
        assert (i < numDocs);
        results[i].id = id;
        results[i].score = score;
        invalidateScoreCache();
    }

    /**
     * Sets the SRResult at the index to the new SRResult.
     *
     * @param i      index of the entry; must be less than {@link #numDocs()} (checked by assertion only)
     * @param result the result to store (stored by reference, not copied)
     */
    public void set(int i, SRResult result) {
        assert (i < numDocs);
        results[i] = result;
        invalidateScoreCache();
    }

    /**
     * Sets the ID, score, and explanations of the SRResult at the index.
     *
     * @param i               index of the entry
     * @param id              new id
     * @param score           new score
     * @param explanationList explanations for the new result
     */
    public void set(int i, int id, double score, List<Explanation> explanationList) {
        set(i, new SRResult(id, score, explanationList));
    }

    /**
     * Returns this list as a TIntFloatMap.
     * Note that this does not maintain any order.
     *
     * @return a new map from id to score (narrowed to float); if an id occurs more than
     *         once, the entry with the highest index wins
     */
    public TIntFloatMap asTroveMap() {
        TIntFloatHashMap map = new TIntFloatHashMap();
        for (int i = 0; i < numDocs; i++) {
            map.put(results[i].id, (float) results[i].getScore());
        }
        return map;
    }

    /**
     * Normalizes the score vector of this list to a unit length.
     * If every score is zero the list is left unchanged.
     */
    public void makeUnitLength() {
        double length = 0.0;
        for (int i = 0; i < numDocs; i++) {
            double x = results[i].getScore();
            length += x * x;
        }
        if (length != 0) {
            length = Math.sqrt(length);
            for (int i = 0; i < numDocs; i++) {
                results[i].score /= length;
            }
            invalidateScoreCache();
        }
    }

    /**
     * Renders the live entries as space-separated {@code "rank. id=score"} items,
     * with ranks starting at 1 and scores printed to three decimals.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        for (int i = 0; i < numDocs(); i++) {
            if (i > 0) {
                builder.append(" ");
            }
            builder.append(
                    String.format("%d. %d=%.3f", (i + 1), results[i].getId(), results[i].getScore())
            );
        }
        return builder.toString();
    }

    /**
     * Sorts the SRResults in this list in ascending order
     * (according to the natural ordering of {@link SRResult}).
     */
    public void sortAscending() {
        Arrays.sort(results, 0, numDocs);
        invalidateScoreCache();
    }

    /**
     * Sorts the SRResults in this list in descending order
     * (the reverse of the natural ordering of {@link SRResult}).
     */
    public void sortDescending() {
        // Fully qualified on purpose: the simple name Collections resolves to the backport import.
        Arrays.sort(results, 0, numDocs, java.util.Collections.<SRResult>reverseOrder());
        invalidateScoreCache();
    }

    /**
     * Sorts by id, ascending.
     */
    public void sortById() {
        Arrays.sort(results, 0, numDocs, new Comparator<SRResult>() {
            @Override
            public int compare(SRResult o1, SRResult o2) {
                // Explicit three-way comparison: subtracting ids can overflow int and
                // yield the wrong sign for ids of opposite sign and large magnitude.
                int a = o1.getId();
                int b = o2.getId();
                return (a < b) ? -1 : ((a == b) ? 0 : 1);
            }
        });
        invalidateScoreCache();
    }

    /**
     * Returns the SRResult at the specified index.
     *
     * @param i index into the backing array (not checked against {@link #numDocs()})
     * @return the stored result object itself, not a copy
     */
    public SRResult get(int i) {
        return results[i];
    }

    /**
     * Returns the estimated similarity score for missing documents.
     *
     * @return the missing-document score
     */
    public double getMissingScore() {
        return missingScore;
    }

    /**
     * Sets the estimated similarity score for missing documents.
     *
     * @param missingScore the missing-document score
     */
    public void setMissingScore(double missingScore) {
        this.missingScore = missingScore;
    }

    /**
     * Returns the rank previously stored with {@link #setMissingRank(int)}.
     *
     * @return the missing-document rank (0 if never set)
     */
    public int getMissingRank() {
        return missingRank;
    }

    /**
     * Sets the rank for missing documents.
     *
     * @param missingRank the missing-document rank
     */
    public void setMissingRank(int missingRank) {
        this.missingRank = missingRank;
    }

    /**
     * Iterates over the live entries, indices {@code [0, numDocs)}, in array order.
     */
    @Override
    @SuppressWarnings("unchecked") // ArrayIterator is not generic; results holds only SRResult
    public Iterator<SRResult> iterator() {
        return new ArrayIterator(results, 0, numDocs);
    }

    /** Drops the cached float score array so the next request rebuilds it. */
    private void invalidateScoreCache() {
        scores = null;
    }
}