/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.Random;
/**
* Base class for testing {@link Classifier}s
*/
public abstract class ClassificationTestBase<T> extends LuceneTestCase {
/** Sample input text on a political topic. */
public static final String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
    "If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
/** The "politics" label as a {@link BytesRef}; the sample index stores the same label in the category field. */
public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
/** Sample input text on a technology topic. */
public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
    " Truth is, Amazon may know more.";
/** The "technology" label as a {@link BytesRef}; the sample index stores the same label in the category field. */
public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
// Writer used to (re)build the test index; created in setUp() and closed in tearDown().
private RandomIndexWriter indexWriter;
// Directory backing the index writer; created in setUp() and closed in tearDown().
private Directory dir;
// Field type applied to the sample documents' fields; configured in setUp() as stored text with term vectors.
private FieldType ft;
/** Name of the field holding the document text; set to "text" in setUp(). */
String textFieldName;
/** Name of the field holding the category label; set to "cat" in setUp(). */
String categoryFieldName;
/** Name of the field holding the "true"/"false" label; set to "bool" in setUp(). */
String booleanFieldName;
/**
 * Prepares the per-test fixture: a new directory with a random index writer on top of it,
 * the names of the text / category / boolean fields, and the field type used when
 * indexing documents (stored text with term vectors, offsets and positions).
 */
@Override
@Before
public void setUp() throws Exception {
  super.setUp();

  // index resources; both are closed again in tearDown()
  this.dir = newDirectory();
  this.indexWriter = new RandomIndexWriter(random(), this.dir);

  // field type shared by the documents added to the index
  final FieldType vectorType = new FieldType(TextField.TYPE_STORED);
  vectorType.setStoreTermVectors(true);
  vectorType.setStoreTermVectorOffsets(true);
  vectorType.setStoreTermVectorPositions(true);
  this.ft = vectorType;

  // names of the indexed fields
  this.textFieldName = "text";
  this.categoryFieldName = "cat";
  this.booleanFieldName = "bool";
}
/**
 * Releases the index resources created in {@link #setUp()}.
 * <p>
 * The writer and the directory are closed in nested {@code finally} blocks so that a
 * failure in {@code super.tearDown()} or while closing the writer does not leave the
 * directory open. Either field may still be {@code null} if {@code setUp()} did not
 * run to completion, hence the guards.
 *
 * @throws Exception if the superclass tear down or closing a resource fails
 */
@Override
@After
public void tearDown() throws Exception {
  try {
    super.tearDown();
  } finally {
    try {
      if (indexWriter != null) {
        indexWriter.close();
      }
    } finally {
      if (dir != null) {
        dir.close();
      }
    }
  }
}
/**
 * Convenience overload of
 * {@link #checkCorrectClassification(Classifier, String, Object, Analyzer, String, String, Query)}
 * that passes {@code null} as the query.
 */
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
  final Query noQuery = null;
  checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, noQuery);
}

/**
 * Rebuilds the sample index, trains the classifier on it and asserts that classifying
 * {@code inputDoc} yields a non-null class equal to {@code expectedResult} with a
 * score greater than zero.
 *
 * @param classifier     the classifier under test
 * @param inputDoc       the text to classify
 * @param expectedResult the class the classifier is expected to assign
 * @param analyzer       used both for indexing the sample documents and for training
 * @param textFieldName  name of the text field handed to {@code train}
 * @param classFieldName name of the class field handed to {@code train}
 * @param query          query handed to {@code train}; may be {@code null}
 * @throws Exception if indexing, training or classification fails
 */
protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
  AtomicReader trainingReader = null;
  try {
    populateSampleIndex(analyzer);
    trainingReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
    classifier.train(trainingReader, textFieldName, classFieldName, analyzer, query);

    final ClassificationResult<T> outcome = classifier.assignClass(inputDoc);
    final T assigned = outcome.getAssignedClass();
    assertNotNull(assigned);
    assertEquals("got an assigned class of " + assigned, expectedResult, assigned);
    assertTrue("got a not positive score " + outcome.getScore(), outcome.getScore() > 0);
  } finally {
    // the reader is only non-null once it was successfully opened
    if (trainingReader != null) {
      trainingReader.close();
    }
  }
}
/**
 * Convenience overload of
 * {@link #checkOnlineClassification(Classifier, String, Object, Analyzer, String, String, Query)}
 * that passes {@code null} as the query.
 */
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
  final Query noQuery = null;
  checkOnlineClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, noQuery);
}

/**
 * Performs the same checks as {@code checkCorrectClassification} and then adds further
 * documents to the index, without calling {@code train} again, asserting that a second
 * classification of {@code inputDoc} returns the same class and the same score.
 *
 * @param classifier     the classifier under test
 * @param inputDoc       the text to classify
 * @param expectedResult the class the classifier is expected to assign
 * @param analyzer       used both for indexing the sample documents and for training
 * @param textFieldName  name of the text field handed to {@code train}
 * @param classFieldName name of the class field handed to {@code train}
 * @param query          query handed to {@code train}; may be {@code null}
 * @throws Exception if indexing, training or classification fails
 */
protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
  AtomicReader trainingReader = null;
  try {
    populateSampleIndex(analyzer);
    trainingReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
    classifier.train(trainingReader, textFieldName, classFieldName, analyzer, query);

    // first classification, right after training
    final ClassificationResult<T> first = classifier.assignClass(inputDoc);
    final T firstClass = first.getAssignedClass();
    assertNotNull(firstClass);
    assertEquals("got an assigned class of " + firstClass, expectedResult, firstClass);
    assertTrue("got a not positive score " + first.getScore(), first.getScore() > 0);

    // second classification, after the index has grown but without retraining
    updateSampleIndex(analyzer);
    final ClassificationResult<T> second = classifier.assignClass(inputDoc);
    assertEquals(firstClass, second.getAssignedClass());
    assertEquals(Double.valueOf(first.getScore()), Double.valueOf(second.getScore()));
  } finally {
    // the reader is only non-null once it was successfully opened
    if (trainingReader != null) {
      trainingReader.close();
    }
  }
}
/**
 * Empties the index and fills it with the sample corpus: four documents labelled
 * "politics"/"true", three labelled "technology"/"false" and one document that only
 * has the text field. The index is committed once it is emptied and again at the end.
 *
 * @param analyzer the analyzer used when adding each document
 * @throws IOException if writing to the index fails
 */
private void populateSampleIndex(Analyzer analyzer) throws IOException {
  indexWriter.deleteAll();
  indexWriter.commit();

  // one row per labelled document: { text, category, boolean }
  final String[][] labelledSamples = {
    {
      "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
          "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
          "the Unknown Soldier in Warsaw Tuesday.",
      "politics", "true"
    },
    {
      "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
          " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.",
      "politics", "true"
    },
    {
      "And there's a threshold question that he has to answer for the American people and " +
          "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
          "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"",
      "politics", "true"
    },
    {
      "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
          "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
          "Albany's School of Criminal Justice.",
      "politics", "true"
    },
    {
      "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
          "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
          "world through the Internet.",
      "technology", "false"
    },
    {
      "So, about all those experts and analysts who've spent the past year or so saying " +
          "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.",
      "technology", "false"
    },
    {
      "More than 400 million people trust Google with their e-mail, and 50 million store files" +
          " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
          "generally transfer or store huge volumes of personal data online.",
      "technology", "false"
    }
  };

  for (String[] sample : labelledSamples) {
    Document labelled = new Document();
    labelled.add(new Field(textFieldName, sample[0], ft));
    labelled.add(new Field(categoryFieldName, sample[1], ft));
    labelled.add(new Field(booleanFieldName, sample[2], ft));
    indexWriter.addDocument(labelled, analyzer);
  }

  // a document carrying text only, without category or boolean label
  Document unlabelled = new Document();
  unlabelled.add(new Field(textFieldName, "unlabeled doc", ft));
  indexWriter.addDocument(unlabelled, analyzer);

  indexWriter.commit();
}
/**
 * Builds the performance index, trains the classifier on it and asserts that the
 * call to {@code train} completes in less than two minutes.
 * <p>
 * The clock is started immediately before {@code train} so that the time spent
 * populating the index and opening the reader is not counted as training time.
 * {@link System#nanoTime()} is used because it is meant for measuring elapsed time
 * and is not affected by wall-clock adjustments.
 *
 * @param classifier     the classifier under test
 * @param analyzer       used both for indexing the random documents and for training
 * @param classFieldName name of the class field handed to {@code train}
 * @throws Exception if indexing or training fails
 */
protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
  AtomicReader atomicReader = null;
  try {
    populatePerformanceIndex(analyzer);
    atomicReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());

    // measure the training call only
    long trainStart = System.nanoTime();
    classifier.train(atomicReader, textFieldName, classFieldName, analyzer);
    long trainTime = (System.nanoTime() - trainStart) / 1000000L; // nanoseconds -> milliseconds

    assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
  } finally {
    // the reader is only non-null once it was successfully opened
    if (atomicReader != null) {
      atomicReader.close();
    }
  }
}
/**
 * Empties the index and fills it with 1000 documents of random text. Each document is
 * randomly labelled "technology"/"true" or "politics"/"false". The index is committed
 * once it is emptied and again after all documents are added.
 * <p>
 * Documents use the shared field type configured in {@code setUp()}, the same one the
 * sample-index methods use, instead of a local copy that shadowed the field.
 *
 * @param analyzer the analyzer used when adding each document
 * @throws IOException if writing to the index fails
 */
private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
  indexWriter.deleteAll();
  indexWriter.commit();

  final int docs = 1000;
  final Random random = random();
  for (int i = 0; i < docs; i++) {
    // draw the label first, then the text, so the random sequence is consumed in a fixed order
    boolean b = random.nextBoolean();
    Document doc = new Document();
    doc.add(new Field(textFieldName, createRandomString(random), ft));
    doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
    doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
    indexWriter.addDocument(doc, analyzer);
  }

  indexWriter.commit();
}
/**
 * Builds a text made of 20 random tokens, each followed by a single space
 * (so the result ends with a trailing space).
 *
 * @param random the source of randomness handed to {@code TestUtil.randomSimpleString}
 * @return the concatenated random tokens
 */
private String createRandomString(Random random) {
  final int tokenCount = 20;
  // second argument of TestUtil.randomSimpleString; presumably an upper bound on token length - confirm against TestUtil
  final int tokenLengthArg = 5;

  final StringBuilder text = new StringBuilder();
  int remaining = tokenCount;
  while (remaining-- > 0) {
    text.append(TestUtil.randomSimpleString(random, tokenLengthArg)).append(" ");
  }
  return text.toString();
}
/**
 * Adds a second batch of documents to the existing index without deleting anything:
 * four labelled "politics"/"true", three labelled "technology"/"false" and one document
 * that only has the text field. The index is committed at the end.
 *
 * @param analyzer the analyzer used when adding each document
 * @throws Exception if writing to the index fails
 */
private void updateSampleIndex(Analyzer analyzer) throws Exception {
  // one row per labelled document: { text, category, boolean }
  final String[][] labelledSamples = {
    {
      "Warren Bennis says John F. Kennedy grasped a key lesson about the presidency that few have followed.",
      "politics", "true"
    },
    {
      "Julian Zelizer says Bill Clinton is still trying to shape his party, years after the White House, while George W. Bush opts for a much more passive role.",
      "politics", "true"
    },
    {
      "Crossfire: Sen. Tim Scott passes on Sen. Lindsey Graham endorsement",
      "politics", "true"
    },
    {
      "Illinois becomes 16th state to allow same-sex marriage.",
      "politics", "true"
    },
    {
      "Apple is developing iPhones with curved-glass screens and enhanced sensors that detect different levels of pressure, according to a new report.",
      "technology", "false"
    },
    {
      "The Xbox One is Microsoft's first new gaming console in eight years. It's a quality piece of hardware but it's also noteworthy because Microsoft is using it to make a statement.",
      "technology", "false"
    },
    {
      "Google says it will replace a Google Maps image after a California father complained it shows the body of his teen-age son, who was shot to death in 2009.",
      "technology", "false"
    }
  };

  for (String[] sample : labelledSamples) {
    Document labelled = new Document();
    labelled.add(new Field(textFieldName, sample[0], ft));
    labelled.add(new Field(categoryFieldName, sample[1], ft));
    labelled.add(new Field(booleanFieldName, sample[2], ft));
    indexWriter.addDocument(labelled, analyzer);
  }

  // a document carrying text only, without category or boolean label
  Document unlabelled = new Document();
  unlabelled.add(new Field(textFieldName, "second unlabeled doc", ft));
  indexWriter.addDocument(unlabelled, analyzer);

  indexWriter.commit();
}
}