/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.lucene.classification; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.index.AtomicReader; import org.apache.lucene.index.Term; import org.apache.lucene.queries.mlt.MoreLikeThis; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.WildcardQuery; import org.apache.lucene.util.BytesRef; import java.io.IOException; import java.io.StringReader; import java.util.HashMap; import java.util.Map; /** * A k-Nearest Neighbor classifier (see <code>http://en.wikipedia.org/wiki/K-nearest_neighbors</code>) based * on {@link MoreLikeThis} * * @lucene.experimental */ public class KNearestNeighborClassifier implements Classifier<BytesRef> { private MoreLikeThis mlt; private String[] textFieldNames; private String classFieldName; private IndexSearcher indexSearcher; private final int k; private Query query; private int minDocsFreq; private int minTermFreq; /** * Create a {@link Classifier} using kNN algorithm * * @param k the number of neighbors to analyze as an <code>int</code> */ public KNearestNeighborClassifier(int k) { this.k = k; } /** * Create a {@link Classifier} using kNN algorithm * * @param k the number of neighbors to analyze as an <code>int</code> * @param minDocsFreq the minimum number of docs frequency for MLT to be set with {@link MoreLikeThis#setMinDocFreq(int)} * @param minTermFreq the minimum number of term frequency for MLT to be set with {@link MoreLikeThis#setMinTermFreq(int)} */ public KNearestNeighborClassifier(int k, int minDocsFreq, int minTermFreq) { this.k = k; this.minDocsFreq = minDocsFreq; this.minTermFreq = minTermFreq; } /** * {@inheritDoc} */ @Override public ClassificationResult<BytesRef> assignClass(String text) throws IOException { if (mlt == null) { throw new IOException("You must first call Classifier#train"); } BooleanQuery mltQuery = new BooleanQuery(); for (String textFieldName : textFieldNames) { mltQuery.add(new BooleanClause(mlt.like(new StringReader(text), textFieldName), BooleanClause.Occur.SHOULD)); } Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*")); mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST)); if (query != null) { mltQuery.add(query, BooleanClause.Occur.MUST); } TopDocs topDocs = indexSearcher.search(mltQuery, k); return selectClassFromNeighbors(topDocs); } private ClassificationResult<BytesRef> selectClassFromNeighbors(TopDocs topDocs) throws IOException { // TODO : improve the nearest neighbor selection Map<BytesRef, Integer> classCounts = new HashMap<>(); for (ScoreDoc scoreDoc : topDocs.scoreDocs) { BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue()); Integer count = classCounts.get(cl); if (count != null) { classCounts.put(cl, count + 1); } else { classCounts.put(cl, 1); } } double max = 0; BytesRef assignedClass = new BytesRef(); for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) { Integer count = entry.getValue(); if (count > max) { max = count; assignedClass = entry.getKey().clone(); } } double score = max / (double) k; return new ClassificationResult<>(assignedClass, score); } /** * {@inheritDoc} */ @Override public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException { train(atomicReader, textFieldName, classFieldName, analyzer, null); } /** * {@inheritDoc} */ @Override public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException { train(atomicReader, new String[]{textFieldName}, classFieldName, analyzer, query); } /** * {@inheritDoc} */ @Override public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException { this.textFieldNames = textFieldNames; this.classFieldName = classFieldName; mlt = new MoreLikeThis(atomicReader); mlt.setAnalyzer(analyzer); mlt.setFieldNames(textFieldNames); indexSearcher = new IndexSearcher(atomicReader); if (minDocsFreq > 0) { mlt.setMinDocFreq(minDocsFreq); } if (minTermFreq > 0) { mlt.setMinTermFreq(minTermFreq); } this.query = query; } }