package opennlp.maxent.io;
import java.io.StringReader;
import junit.framework.TestCase;
import opennlp.maxent.GIS;
import opennlp.maxent.PlainTextByLineDataStream;
import opennlp.maxent.RealBasicEventStream;
import opennlp.model.EventStream;
import opennlp.model.MaxentModel;
import opennlp.model.OnePassRealValueDataIndexer;
import opennlp.model.RealValueFileEventStream;
import opennlp.model.TwoPassRealValueDataIndexer;
public class TwoPassRealValueDataIndexerTest extends TestCase {

  /**
   * Checks that a GIS model trained through {@link TwoPassRealValueDataIndexer} assigns the same
   * outcome probabilities as one trained through {@link OnePassRealValueDataIndexer} on identical
   * real-valued training data and an identical test context.
   *
   * <p>Background note from the original author: the scale used for real-valued predicates should
   * not matter to the probability assigned to each outcome. Strangely, using (1,2) and (10,20)
   * shows no difference, while (0.1,0.2) and (10,20) do. This method does not itself run that
   * scale comparison; it only compares the two indexers.
   *
   * @throws Exception if training or evaluation fails
   */
  public void testDataIndexers() throws Exception {
    String smallValues = "predA=0.1 predB=0.2 A\n" +
        "predB=0.3 predA=0.1 B\n";
    String smallTest = "predA=0.2 predB=0.2";

    // Train and evaluate with the one-pass indexer (cutoff 0, 100 GIS iterations).
    StringReader smallReader = new StringReader(smallValues);
    EventStream smallEventStream =
        new RealBasicEventStream(new PlainTextByLineDataStream(smallReader));
    MaxentModel smallModel =
        GIS.trainModel(100, new OnePassRealValueDataIndexer(smallEventStream, 0), false);
    // Split the test string separately for each model so the two evaluations share no arrays.
    // NOTE(review): parseContexts may rewrite its argument in place; confirm before sharing.
    String[] contexts = smallTest.split(" ");
    float[] values = RealValueFileEventStream.parseContexts(contexts);
    double[] smallResults = smallModel.eval(contexts, values);
    String smallResultString = smallModel.getAllOutcomes(smallResults);
    System.out.println("smallResults: " + smallResultString);

    // Same training data and settings, but with the two-pass indexer.
    StringReader smallReaderTwoPass = new StringReader(smallValues);
    EventStream smallEventStreamTwoPass =
        new RealBasicEventStream(new PlainTextByLineDataStream(smallReaderTwoPass));
    MaxentModel smallModelTwoPass =
        GIS.trainModel(100, new TwoPassRealValueDataIndexer(smallEventStreamTwoPass, 0), false);
    String[] contextsTwoPass = smallTest.split(" ");
    float[] valuesTwoPass = RealValueFileEventStream.parseContexts(contextsTwoPass);
    double[] smallResultsTwoPass = smallModelTwoPass.eval(contextsTwoPass, valuesTwoPass);
    // Report the two-pass model's own results (previously this re-printed the one-pass ones).
    String smallResultTwoPassString = smallModelTwoPass.getAllOutcomes(smallResultsTwoPass);
    System.out.println("smallResults2: " + smallResultTwoPassString);

    assertEquals(smallModel.getNumOutcomes(), smallModelTwoPass.getNumOutcomes());
    assertEquals(smallResults.length, smallResultsTwoPass.length);
    for (int i = 0; i < smallResults.length; i++) {
      String outcome = smallModel.getOutcome(i);
      // Nothing here guarantees both indexers number outcomes in the same order, so match the
      // two-pass probability by outcome name rather than by array position.
      int twoPassIndex = smallModelTwoPass.getIndex(outcome);
      assertTrue("outcome missing from two-pass model: " + outcome, twoPassIndex >= 0);
      System.out.println(String.format("classify with smallModel: %1$s = %2$f",
          outcome, smallResults[i]));
      System.out.println(String.format("classify with smallModelTwoPass: %1$s = %2$f",
          smallModelTwoPass.getOutcome(twoPassIndex), smallResultsTwoPass[twoPassIndex]));
      assertEquals(smallResults[i], smallResultsTwoPass[twoPassIndex], 0.01f);
    }
  }
}