package com.hkorte.elasticsearch.significance; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.hkorte.elasticsearch.significance.measures.*; import com.hkorte.elasticsearch.significance.model.ScoredTerm; import org.apache.lucene.util.PriorityQueue; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.client.Client; import org.elasticsearch.common.logging.ESLogger; import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.search.facet.FacetBuilders; import org.elasticsearch.search.facet.terms.TermsFacet; import java.io.IOException; import java.util.*; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import static org.elasticsearch.rest.RestStatus.OK; /** * Created by hkorte on 25.04.14. */ public class SignificantTermsProvider { private static final String FACET_NAME = "terms"; private static final int NUM_TERMS = 100000; private final ESLogger logger; private final Client client; private final LoadingCache<FieldIdentifier, Map<String, Integer>> globalDocFreqCache; private final LoadingCache<TypeIdentifier, Long> globalDocCountCache; public SignificantTermsProvider(Settings settings, Client client) { this.logger = Loggers.getLogger(getClass(), settings); this.client = client; this.globalDocFreqCache = CacheBuilder.newBuilder().maximumSize(5).expireAfterWrite(5, TimeUnit.MINUTES).build(new CacheLoader<FieldIdentifier, Map<String, Integer>>() { @Override public Map<String, Integer> load(FieldIdentifier fieldIdentifier) { SearchResponse response = SignificantTermsProvider.this.client.prepareSearch(fieldIdentifier.indices) .setTypes(fieldIdentifier.types).setSize(0).setQuery(QueryBuilders.matchAllQuery()).addFacet (FacetBuilders.termsFacet(FACET_NAME).order(TermsFacet.ComparatorType.COUNT).field (fieldIdentifier.field).size(NUM_TERMS)).get(); TermsFacet termsFacet = response.getFacets().facet(FACET_NAME); Map<String, Integer> map = new HashMap<String, Integer>(); for (TermsFacet.Entry entry : termsFacet) { map.put(entry.getTerm().string(), entry.getCount()); } return map; } }); this.globalDocCountCache = CacheBuilder.newBuilder().maximumSize(5).expireAfterWrite(5, TimeUnit.MINUTES).build(new CacheLoader<TypeIdentifier, Long>() { @Override public Long load(TypeIdentifier fieldIdentifier) { SearchResponse response = SignificantTermsProvider.this.client.prepareSearch(fieldIdentifier.indices) .setTypes(fieldIdentifier.types).setSize(0).setQuery(QueryBuilders.matchAllQuery()).get(); return response.getHits().getTotalHits(); } }); } public void writeSignificantTerms(RestChannel channel, String[] indices, String[] types, String field, int size, String query) throws IOException, ExecutionException { Map<String, Integer> dfMap = this.globalDocFreqCache.get(new FieldIdentifier(indices, types, field)); long numDocs = this.globalDocCountCache.get(new TypeIdentifier(indices, types)); SearchResponse response = client.prepareSearch(indices).setTypes(types).setSize(0).setQuery(query).addFacet (FacetBuilders.termsFacet(FACET_NAME).order(TermsFacet.ComparatorType.COUNT).field(field).size (NUM_TERMS)).get(); long numHits = response.getHits().getTotalHits(); if (numHits > numDocs) { // obviously the numDocs were outdated // -> we simply adjust the value to be valid numDocs = numHits; } // we only have sensible results, if there is a negative set if (numHits < numDocs) { TermsFacet termsFacet = response.getFacets().facet(FACET_NAME); List<SignificanceMeasureResults> resultsList = new ArrayList<>(); resultsList.add(new SignificanceMeasureResults(new DefaultHeuristic(), size)); resultsList.add(new SignificanceMeasureResults(new MutualInformation(), size)); resultsList.add(new SignificanceMeasureResults(new ChiSquared(), size)); resultsList.add(new SignificanceMeasureResults(new KullbackLeiblerDivergence(), size)); for (TermsFacet.Entry entry : termsFacet) { String term = entry.getTerm().string(); if (dfMap.containsKey(term)) { int termDf = dfMap.get(term); // #docs which contain word with positive class long n11 = entry.getCount(); // #docs which contain word with any class long n1X = termDf; // if the dfMap does not contain the word or is outdated, we give it at least the search results // frequency if (n1X < n11) { n1X = n11; } // #docs which contain word with negative class long n10 = n1X - n11; // #docs which do not contain word with positive class long n01 = numHits - n11; // #docs which do not contain word with any class long n0X = numDocs - n1X; // #docs which do not contain word with negative class long n00 = n0X - n01; // if the stats are outdated and the amount of recently added documents containing the current // word // is above average, the value of numDocs may be too small, so that n00 becomes negative if (n00 < 0) { n00 = 0; } double relativePositiveFreq = (1.0 + n11) / (1.0 + numHits); double relativeGlobalFreq = (1.0 + n1X) / (1.0 + numDocs); // Only consider positive deviation if ((relativePositiveFreq / relativeGlobalFreq) > 1.0) { for (SignificanceMeasureResults significanceMeasureResults : resultsList) { significanceMeasureResults.update(term, n00, n01, n10, n11); } } } } XContentBuilder builder = channel.newBuilder(); builder.startObject(); for (SignificanceMeasureResults significanceMeasureResults : resultsList) { builder.startArray(significanceMeasureResults.significanceMeasure.shortName()); for (ScoredTerm scoredTerm : reverse(significanceMeasureResults.priorityQueue)) { builder.startObject(); builder.field("term", scoredTerm.getTerm()); builder.field("score", scoredTerm.getScore()); builder.field("n00", scoredTerm.getN00()); builder.field("n01", scoredTerm.getN01()); builder.field("n10", scoredTerm.getN10()); builder.field("n11", scoredTerm.getN11()); scoredTerm.addCustomFields(builder); builder.endObject(); } builder.endArray(); } builder.endObject(); channel.sendResponse(new BytesRestResponse(OK, builder)); } } //TODO: I'm sure there is a more elegant way to get desc ordering private <T> List<T> reverse(PriorityQueue<T> queue) { LinkedList<T> list = new LinkedList<>(); T obj; while ((obj = queue.pop()) != null) { list.addFirst(obj); } return list; } private class SignificanceMeasureResults { private SignificanceMeasure significanceMeasure; private ScoredTermPriorityQueue priorityQueue; private SignificanceMeasureResults(SignificanceMeasure significanceMeasure, int size) { this.significanceMeasure = significanceMeasure; this.priorityQueue = new ScoredTermPriorityQueue(size); } public void update(String term, long n00, long n01, long n10, long n11) { ScoredTerm scoredTerm = this.significanceMeasure.apply(n00, n01, n10, n11); scoredTerm.setTerm(term); this.priorityQueue.insertWithOverflow(scoredTerm); } } private class ScoredTermPriorityQueue extends PriorityQueue<ScoredTerm> { public ScoredTermPriorityQueue(int maxSize) { super(maxSize); } @Override protected boolean lessThan(ScoredTerm a, ScoredTerm b) { return a.getScore() - b.getScore() < 0.0; } } private class TypeIdentifier { private String[] indices; private String[] types; private TypeIdentifier(String[] indices, String[] types) { this.indices = indices; this.types = types; } @Override public boolean equals(Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } TypeIdentifier that = (TypeIdentifier) o; if (!Arrays.equals(indices, that.indices)) { return false; } if (!Arrays.equals(types, that.types)) { return false; } return true; } @Override public int hashCode() { int result = indices != null ? Arrays.hashCode(indices) : 0; result = 31 * result + (types != null ? Arrays.hashCode(types) : 0); return result; } } private class FieldIdentifier { private String[] indices; private String[] types; private String field; private FieldIdentifier(String[] indices, String[] types, String field) { this.indices = indices; this.types = types; this.field = field; } @Override public boolean equals(Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } FieldIdentifier that = (FieldIdentifier) o; if (field != null ? !field.equals(that.field) : that.field != null) { return false; } if (!Arrays.equals(indices, that.indices)) { return false; } if (!Arrays.equals(types, that.types)) { return false; } return true; } @Override public int hashCode() { int result = indices != null ? Arrays.hashCode(indices) : 0; result = 31 * result + (types != null ? Arrays.hashCode(types) : 0); result = 31 * result + (field != null ? field.hashCode() : 0); return result; } } }