/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.StorableField;
import org.apache.lucene.index.StoredDocument;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.fst.Builder;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;
import java.io.IOException;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
/**
* A perceptron (see <code>http://en.wikipedia.org/wiki/Perceptron</code>) based
* <code>Boolean</code> {@link org.apache.lucene.classification.Classifier}. The
* weights are calculated using
* {@link org.apache.lucene.index.TermsEnum#totalTermFreq} both on a per field
* and a per document basis and then a corresponding
* {@link org.apache.lucene.util.fst.FST} is used for class assignment.
*
* @lucene.experimental
*/
public class BooleanPerceptronClassifier implements Classifier<Boolean> {
private Double threshold;
private final Integer batchSize;
private Terms textTerms;
private Analyzer analyzer;
private String textFieldName;
private FST<Long> fst;
/**
* Create a {@link BooleanPerceptronClassifier}
*
* @param threshold
* the binary threshold for perceptron output evaluation
*/
public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
this.threshold = threshold;
this.batchSize = batchSize;
}
/**
* Default constructor, no batch updates of FST, perceptron threshold is
* calculated via underlying index metrics during
* {@link #train(org.apache.lucene.index.AtomicReader, String, String, org.apache.lucene.analysis.Analyzer)
* training}
*/
public BooleanPerceptronClassifier() {
batchSize = 1;
}
/**
* {@inheritDoc}
*/
@Override
public ClassificationResult<Boolean> assignClass(String text)
throws IOException {
if (textTerms == null) {
throw new IOException("You must first call Classifier#train");
}
Long output = 0l;
try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
CharTermAttribute charTermAttribute = tokenStream
.addAttribute(CharTermAttribute.class);
tokenStream.reset();
while (tokenStream.incrementToken()) {
String s = charTermAttribute.toString();
Long d = Util.get(fst, new BytesRef(s));
if (d != null) {
output += d;
}
}
tokenStream.end();
}
return new ClassificationResult<>(output >= threshold, output.doubleValue());
}
/**
* {@inheritDoc}
*/
@Override
public void train(AtomicReader atomicReader, String textFieldName,
String classFieldName, Analyzer analyzer) throws IOException {
train(atomicReader, textFieldName, classFieldName, analyzer, null);
}
/**
* {@inheritDoc}
*/
@Override
public void train(AtomicReader atomicReader, String textFieldName,
String classFieldName, Analyzer analyzer, Query query) throws IOException {
this.textTerms = MultiFields.getTerms(atomicReader, textFieldName);
if (textTerms == null) {
throw new IOException(new StringBuilder(
"term vectors need to be available for field ").append(textFieldName)
.toString());
}
this.analyzer = analyzer;
this.textFieldName = textFieldName;
if (threshold == null || threshold == 0d) {
// automatic assign a threshold
long sumDocFreq = atomicReader.getSumDocFreq(textFieldName);
if (sumDocFreq != -1) {
this.threshold = (double) sumDocFreq / 2d;
} else {
throw new IOException(
"threshold cannot be assigned since term vectors for field "
+ textFieldName + " do not exist");
}
}
// TODO : remove this map as soon as we have a writable FST
SortedMap<String,Double> weights = new TreeMap<>();
TermsEnum reuse = textTerms.iterator(null);
BytesRef textTerm;
while ((textTerm = reuse.next()) != null) {
weights.put(textTerm.utf8ToString(), (double) reuse.totalTermFreq());
}
updateFST(weights);
IndexSearcher indexSearcher = new IndexSearcher(atomicReader);
int batchCount = 0;
BooleanQuery q = new BooleanQuery();
q.add(new BooleanClause(new WildcardQuery(new Term(classFieldName, "*")), BooleanClause.Occur.MUST));
if (query != null) {
q.add(new BooleanClause(query, BooleanClause.Occur.MUST));
}
// run the search and use stored field values
for (ScoreDoc scoreDoc : indexSearcher.search(q,
Integer.MAX_VALUE).scoreDocs) {
StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
// assign class to the doc
ClassificationResult<Boolean> classificationResult = assignClass(doc
.getField(textFieldName).stringValue());
Boolean assignedClass = classificationResult.getAssignedClass();
// get the expected result
StorableField field = doc.getField(classFieldName);
Boolean correctClass = Boolean.valueOf(field.stringValue());
long modifier = correctClass.compareTo(assignedClass);
if (modifier != 0) {
reuse = updateWeights(atomicReader, reuse, scoreDoc.doc, assignedClass,
weights, modifier, batchCount % batchSize == 0);
}
batchCount++;
}
weights.clear(); // free memory while waiting for GC
}
@Override
public void train(AtomicReader atomicReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
}
private TermsEnum updateWeights(AtomicReader atomicReader, TermsEnum reuse,
int docId, Boolean assignedClass, SortedMap<String,Double> weights,
double modifier, boolean updateFST) throws IOException {
TermsEnum cte = textTerms.iterator(reuse);
// get the doc term vectors
Terms terms = atomicReader.getTermVector(docId, textFieldName);
if (terms == null) {
throw new IOException("term vectors must be stored for field "
+ textFieldName);
}
TermsEnum termsEnum = terms.iterator(null);
BytesRef term;
while ((term = termsEnum.next()) != null) {
cte.seekExact(term);
if (assignedClass != null) {
long termFreqLocal = termsEnum.totalTermFreq();
// update weights
Long previousValue = Util.get(fst, term);
String termString = term.utf8ToString();
weights.put(termString, previousValue + modifier * termFreqLocal);
}
}
if (updateFST) {
updateFST(weights);
}
reuse = cte;
return reuse;
}
private void updateFST(SortedMap<String,Double> weights) throws IOException {
PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton();
Builder<Long> fstBuilder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);
BytesRef scratchBytes = new BytesRef();
IntsRef scratchInts = new IntsRef();
for (Map.Entry<String,Double> entry : weights.entrySet()) {
scratchBytes.copyChars(entry.getKey());
fstBuilder.add(Util.toIntsRef(scratchBytes, scratchInts), entry
.getValue().longValue());
}
fst = fstBuilder.finish();
}
}