/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.ml.modelinput; import org.apache.lucene.index.Fields; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.Term; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.search.lookup.IndexField; import org.elasticsearch.search.lookup.IndexFieldTerm; import org.elasticsearch.search.lookup.LeafDocLookup; import org.elasticsearch.search.lookup.LeafIndexLookup; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; /** * Elasticsearch Data Source */ public abstract class EsDataSource implements DataSource { protected abstract LeafDocLookup getDocLookup(); protected abstract LeafIndexLookup getLeafIndexLookup(); @SuppressWarnings("unchecked") @Override public <T> List<T> getValues(String field) { return (List<T>) getDocLookup().get(field); } @Override public double[] getOccurrenceDense(String[] terms, String field) { return getDense(terms, field, (indexField, indexFieldTerm) -> indexFieldTerm.tf() > 0 ? 1 : 0); } @Override public double[] getTfIdfDense(String[] terms, String field) { return getDense(terms, field, (indexField, indexFieldTerm) -> { double tf = indexFieldTerm.tf(); double df = indexFieldTerm.df(); double numDocs = indexField.docCount(); return tf * Math.log((numDocs + 1) / (df + 1)); }); } @Override public double[] getTfDense(String[] terms, String field) { return getDense(terms, field, (indexField, indexFieldTerm) -> indexFieldTerm.tf()); } @Override public Tuple<int[], double[]> getTfSparse(Map<String, Integer> wordMap, String field) { return getSparse(wordMap, field, (docsEnum, term) -> (double) docsEnum.freq()); } @Override public Tuple<int[], double[]> getTfIdfSparse(Map<String, Integer> wordMap, String field) { return getSparse(wordMap, field, (docsEnum, term) -> { double docFreq = getLeafIndexLookup().getParentReader().docFreq(new Term(field, term)); double freq = docsEnum.freq(); double docCount = getLeafIndexLookup().getParentReader().numDocs(); return freq * Math.log((docCount + 1) / (docFreq + 1)); }); } private interface IndexFieldTermFunction { double apply(IndexField indexField, IndexFieldTerm indexFieldTerm) throws IOException; } private double[] getDense(String[] terms, String field, IndexFieldTermFunction f) { double[] values = new double[terms.length]; IndexField indexField = getLeafIndexLookup().get(field); for (int i = 0; i < terms.length; i++) { IndexFieldTerm indexTermField = indexField.get(terms[i]); try { values[i] = f.apply(indexField, indexTermField); } catch (IOException ex) { throw new IllegalArgumentException("cannot get dense vector for field " + field + " for term "+ terms[i], ex); } } return values; } private interface DocsEnumFunction { double apply(PostingsEnum docsEnum, BytesRef term) throws IOException; } private Tuple<int[], double[]> getSparse(Map<String, Integer> wordMap, String field, DocsEnumFunction function) { try { Fields fields = getLeafIndexLookup().termVectors(); if (fields == null) { return null; } else { List<Integer> indices = new ArrayList<>(); List<Double> values = new ArrayList<>(); Terms terms = fields.terms(field); TermsEnum termsEnum = terms.iterator(); BytesRef t; PostingsEnum docsEnum = null; int numTerms = 0; indices.clear(); values.clear(); while ((t = termsEnum.next()) != null) { Integer termIndex = wordMap.get(t.utf8ToString()); if (termIndex != null) { indices.add(termIndex); docsEnum = termsEnum.postings(docsEnum, PostingsEnum.FREQS); int nextDoc = docsEnum.nextDoc(); assert nextDoc != PostingsEnum.NO_MORE_DOCS; values.add(function.apply(docsEnum, t)); nextDoc = docsEnum.nextDoc(); assert nextDoc == PostingsEnum.NO_MORE_DOCS; numTerms++; } } int[] indicesArray = new int[numTerms]; double[] valuesArray = new double[numTerms]; for (int i = 0; i < numTerms; i++) { indicesArray[i] = indices.get(i); valuesArray[i] = values.get(i); } return new Tuple<>(indicesArray, valuesArray); } } catch (IOException ex) { throw new IllegalArgumentException("cannot get sparse tf/idf vector for field "+ field, ex); } } }