package com.yahoo.glimmer.query;
/*
* Copyright (c) 2012 Yahoo! Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software distributed under the License is
* distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
* See accompanying LICENSE file.
*/
import it.unimi.di.big.mg4j.index.Index;
import it.unimi.di.big.mg4j.search.DocumentIterator;
import it.unimi.di.big.mg4j.search.score.AbstractWeightedScorer;
import it.unimi.di.big.mg4j.search.score.BM25FScorer;
import it.unimi.di.big.mg4j.search.score.DelegatingScorer;
import it.unimi.di.big.mg4j.search.visitor.CounterCollectionVisitor;
import it.unimi.di.big.mg4j.search.visitor.CounterSetupVisitor;
import it.unimi.di.big.mg4j.search.visitor.TermCollectionVisitor;
import it.unimi.dsi.big.util.StringMap;
import it.unimi.dsi.fastutil.doubles.DoubleArrays;
import it.unimi.dsi.fastutil.ints.IntBigList;
import it.unimi.dsi.fastutil.longs.LongBigList;
import it.unimi.dsi.fastutil.objects.Object2DoubleMap;
import it.unimi.dsi.fastutil.objects.Reference2DoubleMap;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import org.apache.log4j.Logger;
/**
 * A weighted, multi-index scorer whose formula is structured along the lines
 * of {@link BM25FScorer}: per-term "virtual counts" are accumulated over all
 * (term, index) pairs of the query, saturated with <var>k</var><sub>1</sub>,
 * and multiplied by an inverse-document-frequency factor.
 *
 * <p>
 * On top of that, the score is scaled by the fraction of query terms with a
 * positive virtual count (capped by a configured maximum) and, when a
 * document-prior map is supplied, by a per-category document weight.
 *
 * <p>
 * {@link #wrap(DocumentIterator)} must be called before {@link #score()}: it
 * builds all the per-query lookup arrays used during scoring.
 */
public class WOOScorer extends AbstractWeightedScorer implements DelegatingScorer {
    private static final Logger LOGGER = Logger.getLogger(WOOScorer.class);
    private static final boolean DEBUG = true;

    /** The default value used for the parameter <var>k</var><sub>1</sub>. */
    public final static double DEFAULT_K1 = 1.2;
    /** The default value used for the parameter <var>b</var>. */
    public final static double DEFAULT_B = 0.5;
    /**
     * The value of the document-frequency part for terms appearing in more than
     * half of the documents.
     */
    public final static double EPSILON_SCORE = 1E-6;

    /** Category index into {@link #documentWeights} used for documents without an explicit prior. */
    private static final int NEUTRAL = Integer.parseInt(SetDocumentPriors.NEUTRAL);

    /** The counter collection visitor used to estimate counts. */
    private final CounterCollectionVisitor counterCollectionVisitor;
    /** The counter setup visitor used to estimate counts. */
    private final CounterSetupVisitor setupVisitor;
    /** The term collection visitor used to estimate counts. */
    private final TermCollectionVisitor termVisitor;
    /** The parameter <var>k</var><sub>1</sub>. */
    public final double k1;
    /** The parameter <var>b</var>; you must provide one value for each index. */
    public final Reference2DoubleMap<Index> bByIndex;
    /** The parameter {@link #k1} plus one, precomputed. */
    private final double k1Plus1;
    /**
     * An array indexed by offsets that caches the inverse document-frequency
     * part of the formula.
     */
    private double[] idfPart;
    /**
     * An array indexed by term ids holding the inverse document-frequency part
     * used by {@link #score()}; derived from {@link #idfPart} in
     * {@link #wrap(DocumentIterator)}.
     */
    private double[] termIdf;
    /**
     * An array indexed by offsets that caches the weight corresponding to each
     * pair.
     */
    private double[] offset2Weight;
    /**
     * An array indexed by offsets that gives the unique id of each term in the
     * query.
     */
    private int[] offset2TermId;
    /** A term map to index {@link #frequencies}. */
    private final StringMap<? extends CharSequence> termMap;
    /**
     * The list of virtual frequencies (possibly approximated using just the
     * frequencies of the main field).
     */
    private final LongBigList frequencies;
    /**
     * An array indexed by offsets mapping each offset to the corresponding
     * index number.
     */
    private int[] offset2Index;
    /**
     * An array indexed by term ids used by {@link #score()} to compute virtual
     * counts.
     */
    private double[] virtualCount;
    /** The number of documents, used in the inverse document-frequency formula. */
    private final long N;
    /** The weight of each index, indexed by index number. */
    private double[] weight;
    /**
     * An array indexed by index number giving the parameter <var>b</var> of
     * the corresponding index.
     */
    private double[] index2B;
    /** Alternative source of <var>b</var> keyed by field name; the constructor always sets it to null. */
    private final Object2DoubleMap<String> bByName;
    /** Multiplier applied together with the matched-term fraction. */
    private final double w_numberOfFieldsMatched;
    /** Document lengths, indexed by document id. */
    private final IntBigList defaultSizes;
    /** The document length used as the normalisation reference. */
    private final double averageDocLength;
    /** Score multipliers indexed by document category. */
    private final double[] documentWeights;
    /** Documents shorter than this are scored as if they had this length. */
    private final double dl_cutoff;
    /** Maps a document id to its category; may be null, in which case no prior is applied. */
    private final HashMap<Integer, Integer> documentPriors;
    /** Upper bound (and divisor) for the number of matched terms. */
    private final int max_number_of_fields;

    /**
     * Creates a new scorer.
     *
     * @param k1
     *            the saturation parameter <var>k</var><sub>1</sub>.
     * @param b
     *            the length-normalisation parameter <var>b</var> for each
     *            index; a value of exactly 1 disables length normalisation
     *            for that index.
     * @param termMap
     *            maps each term to its position in <code>frequencies</code>
     *            (-1 is interpreted as "unknown term").
     * @param frequencies
     *            the frequency of each term, indexed through
     *            <code>termMap</code>.
     * @param defaultSizes
     *            the length of each document, indexed by document id.
     * @param averageDocLength
     *            the reference document length used for normalisation.
     * @param N
     *            number of documents
     * @param w_numberOfFieldsMatched
     *            multiplier applied to the score together with the fraction
     *            of matched terms.
     * @param documentWeights
     *            score multipliers indexed by document category; only read
     *            when <code>documentPriors</code> is not null.
     * @param dl_cutoff
     *            minimum document length used for normalisation.
     * @param documentPriors
     *            maps document ids to categories, or null for no priors.
     * @param max_number_of_fields
     *            cap on the number of matched terms; also the divisor of the
     *            matched-term fraction, so it should be positive.
     */
    public WOOScorer(final double k1, final Reference2DoubleMap<Index> b, final StringMap<? extends CharSequence> termMap, final LongBigList frequencies,
            final IntBigList defaultSizes, double averageDocLength, long N, double w_numberOfFieldsMatched, double[] documentWeights, double dl_cutoff,
            HashMap<Integer, Integer> documentPriors, int max_number_of_fields) {
        this.termMap = termMap;
        termVisitor = new TermCollectionVisitor();
        setupVisitor = new CounterSetupVisitor(termVisitor);
        counterCollectionVisitor = new CounterCollectionVisitor(setupVisitor);
        this.k1 = k1;
        this.bByIndex = b;
        this.frequencies = frequencies;
        this.k1Plus1 = k1 + 1;
        this.bByName = null;
        this.w_numberOfFieldsMatched = w_numberOfFieldsMatched;
        this.N = N;
        this.defaultSizes = defaultSizes;
        this.averageDocLength = averageDocLength;
        this.documentWeights = documentWeights;
        this.dl_cutoff = dl_cutoff;
        this.max_number_of_fields = max_number_of_fields;
        this.documentPriors = documentPriors;
    }

    /**
     * Returns a new scorer built from the same parameters and carrying the
     * same index weights. The copy must be wrapped around a document iterator
     * before use.
     */
    public DelegatingScorer copy() {
        final WOOScorer scorer = new WOOScorer(k1, bByIndex, termMap, frequencies, defaultSizes, averageDocLength, N, w_numberOfFieldsMatched, documentWeights,
                dl_cutoff, documentPriors, max_number_of_fields);
        scorer.setWeights(index2Weight);
        return scorer;
    }

    /**
     * Scores the current document of the wrapped iterator.
     *
     * @return the score of the current document.
     * @throws IOException
     *             if the underlying iterator does.
     */
    public double score() throws IOException {
        setupVisitor.clear();
        documentIterator.acceptOnTruePaths(counterCollectionVisitor);

        final long document = documentIterator.document();
        final int[] count = setupVisitor.count;
        final double[] offset2Weight = this.offset2Weight;
        final int[] offset2TermId = this.offset2TermId;
        final int[] offset2Index = this.offset2Index;
        final double[] virtualCount = this.virtualCount;
        final double[] termIdf = this.termIdf;
        final double[] index2B = this.index2B;

        // Short documents are scored as if they were dl_cutoff long.
        double docLen = defaultSizes.getInt(document);
        if (docLen < dl_cutoff) {
            docLen = dl_cutoff;
        }

        // Accumulate, for each term, the weighted and length-normalised counts
        // over all the indices in which the term appears.
        DoubleArrays.fill(virtualCount, 0);
        for (int i = offset2TermId.length; i-- != 0;) {
            final int termId = offset2TermId[i];
            if (termId == -1) {
                continue;
            }
            final double b = index2B[offset2Index[i]];
            final double weightedCount = count[i] * offset2Weight[i];
            if (b != 1) {
                virtualCount[termId] += weightedCount / ((1 - b) + b * docLen / averageDocLength);
            } else {
                // b == 1 is special-cased: the count is used without any
                // length normalisation.
                virtualCount[termId] += weightedCount;
            }
        }

        double score = 0;
        double numberOfFieldsMatched = 0;
        for (int i = virtualCount.length; i-- != 0;) {
            final double v = virtualCount[i];
            score += (k1Plus1 * v) / (v + k1) * termIdf[i];
            if (v > 0) {
                numberOfFieldsMatched++;
            }
        }

        // Scale by the (capped) fraction of terms that actually matched.
        if (numberOfFieldsMatched > max_number_of_fields) {
            numberOfFieldsMatched = max_number_of_fields;
        }
        score *= w_numberOfFieldsMatched * numberOfFieldsMatched / max_number_of_fields;

        if (documentPriors != null) {
            score *= documentWeights[priorCategory(document)];
        }
        return score;
    }

    /**
     * Returns the prior category of a document, or {@link #NEUTRAL} if the
     * document has no entry in {@link #documentPriors}.
     *
     * <p>
     * The map is keyed by <code>Integer</code>, so the <code>long</code>
     * document id must be narrowed explicitly: passing it as is would box it
     * to a <code>Long</code>, which never equals an <code>Integer</code> key.
     */
    private int priorCategory(final long document) {
        final int key = (int) document;
        // An id that does not fit in an int cannot be a key of the map.
        final Integer category = key == document ? documentPriors.get(Integer.valueOf(key)) : null;
        return category != null ? category.intValue() : NEUTRAL;
    }

    /** Not supported by this scorer. */
    public double score(final Index index) {
        throw new UnsupportedOperationException();
    }

    /**
     * Wraps a document iterator and precomputes the per-query lookup arrays
     * used by {@link #score()}.
     *
     * @throws IllegalStateException
     *             if the weight map contains a null key.
     * @throws IllegalArgumentException
     *             if some index involved in the query has no weight or no
     *             <var>b</var> parameter.
     */
    public void wrap(DocumentIterator documentIterator) throws IOException {
        super.wrap(documentIterator);

        if (index2Weight.keySet().contains(null)) {
            throw new IllegalStateException("index2Weights contains a null key!!");
        }

        termVisitor.prepare(index2Weight.keySet());
        if (DEBUG && LOGGER.isDebugEnabled()) {
            LOGGER.debug("Weight map: " + index2Weight);
        }

        documentIterator.accept(termVisitor);
        if (DEBUG && LOGGER.isDebugEnabled()) {
            LOGGER.debug("Term Visitor found " + termVisitor.numberOfPairs() + " leaves");
        }

        final Index[] index = termVisitor.indices();
        if (DEBUG && LOGGER.isDebugEnabled()) {
            LOGGER.debug("Indices: " + Arrays.toString(index));
        }

        if (!index2Weight.keySet().containsAll(Arrays.asList(index))) {
            throw new IllegalArgumentException("A WOOScorer scorer must have a weight for all indices involved in a query");
        }
        for (Index i : index) {
            if (bByIndex != null && !bByIndex.containsKey(i) || bByName != null && !bByName.containsKey(i.field)) {
                throw new IllegalArgumentException("A WOOScorer scorer must have a b parameter for all indices involved in a query " + i);
            }
        }

        setupVisitor.prepare();
        documentIterator.accept(setupVisitor);

        weight = new double[index.length];
        for (int i = weight.length; i-- != 0;) {
            weight[i] = index2Weight.getDouble(index[i]);
        }

        offset2TermId = setupVisitor.offset2TermId;
        offset2Index = setupVisitor.indexNumber;

        index2B = new double[index.length];
        for (int i = 0; i < index2B.length; i++) {
            index2B[i] = bByIndex != null ? bByIndex.getDouble(index[i]) : bByName.getDouble(index[i].field);
        }

        // Each pair weight is the weight of its index multiplied by the
        // number of weighted indices.
        offset2Weight = new double[offset2Index.length];
        for (int i = offset2Weight.length; i-- != 0;) {
            offset2Weight[i] = index2Weight.getDouble(index[offset2Index[i]]) * index2Weight.size();
        }

        // We do all logs here
        idfPart = new double[termVisitor.numberOfPairs()];
        for (int i = idfPart.length; i-- != 0;) {
            if (setupVisitor.offset2TermId[i] == -1) {
                continue;
            }
            // Keep the id as a long: the frequency list is indexed by longs,
            // and narrowing to int would silently corrupt large ids.
            final long id = termMap.getLong(setupVisitor.termId2Term[setupVisitor.offset2TermId[i]]);
            if (id == -1) {
                // Unknown term: since the final score is a product with the
                // idf part, the term contributes nothing.
                idfPart[i] = 0;
            } else {
                final long f = frequencies.getLong(id);
                idfPart[i] = Math.max(EPSILON_SCORE, Math.log((N - f + 0.5) / (f + 0.5)));
            }
        }

        virtualCount = new double[setupVisitor.termId2Term.length];

        // The per-term idf depends only on data fixed above, so it is built
        // once here rather than for every scored document. Offsets are scanned
        // downwards so that, when several offsets share a term, the value at
        // the smallest offset is the one retained.
        termIdf = new double[virtualCount.length];
        for (int i = offset2TermId.length; i-- != 0;) {
            if (offset2TermId[i] != -1) {
                termIdf[offset2TermId[i]] = idfPart[i];
            }
        }
    }

    /** Returns false: this scorer does not use intervals. */
    public boolean usesIntervals() {
        return false;
    }
}