/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.ml.maxent;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
import opennlp.tools.ml.AbstractEventTrainer;
import opennlp.tools.ml.AbstractTrainer;
import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.TrainerFactory;
import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
import opennlp.tools.ml.model.AbstractDataIndexer;
import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.DataIndexerFactory;
import opennlp.tools.ml.model.Event;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.ObjectStreamUtils;
import opennlp.tools.util.TrainingParameters;
import opennlp.tools.util.model.ModelUtil;
/**
 * Exercises event-trainer creation through {@link TrainerFactory} and the selection of
 * {@link DataIndexer} implementations through training parameters, using a small
 * in-memory fixture of three events.
 */
public class GISIndexingTest {

  /** Context (predicate) strings of the three fixture events; six distinct strings overall. */
  private static final String[][] CONTEXTS = new String[][]{
      {"dog", "cat", "mouse"},
      {"text", "print", "mouse"},
      {"dog", "pig", "cat", "mouse"}
  };

  /** Outcome of each fixture event, aligned by index with {@link #CONTEXTS}; two distinct outcomes. */
  private static final String[] OUTCOMES = new String[]{"A", "B", "A"};

  /**
   * Builds a fresh stream over the fixture events.
   *
   * @return a stream yielding one {@link Event} per row of {@link #CONTEXTS}, in order
   */
  private ObjectStream<Event> createEventStream() {
    List<Event> events = new ArrayList<>();
    for (int i = 0; i < CONTEXTS.length; i++) {
      events.add(new Event(OUTCOMES[i], CONTEXTS[i]));
    }
    return ObjectStreamUtils.createObjectStream(events);
  }

  /**
   * Trains on the fixture with the default training parameters and a cutoff of 1,
   * and checks that a model is returned.
   */
  @Test
  public void testGISTrainSignature1() throws IOException {
    try (ObjectStream<Event> eventStream = createEventStream()) {
      TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
      params.put(AbstractTrainer.CUTOFF_PARAM, 1);
      EventTrainer trainer = TrainerFactory.getEventTrainer(params, null);
      Assert.assertNotNull(trainer.train(eventStream));
    }
  }

  /**
   * Same as {@link #testGISTrainSignature1()} but with the {@code "smoothing"}
   * parameter set to {@code true}.
   */
  @Test
  public void testGISTrainSignature2() throws IOException {
    try (ObjectStream<Event> eventStream = createEventStream()) {
      TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
      params.put(AbstractTrainer.CUTOFF_PARAM, 1);
      params.put("smoothing", true);
      EventTrainer trainer = TrainerFactory.getEventTrainer(params, null);
      Assert.assertNotNull(trainer.train(eventStream));
    }
  }

  /**
   * Trains on the fixture with explicit iteration (10) and cutoff (1) parameters,
   * and checks that a model is returned.
   */
  @Test
  public void testGISTrainSignature3() throws IOException {
    try (ObjectStream<Event> eventStream = createEventStream()) {
      TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
      params.put(AbstractTrainer.ITERATIONS_PARAM, 10);
      params.put(AbstractTrainer.CUTOFF_PARAM, 1);
      EventTrainer trainer = TrainerFactory.getEventTrainer(params, null);
      Assert.assertNotNull(trainer.train(eventStream));
    }
  }

  /**
   * Trains with explicit iteration (10) and cutoff (1) parameters after setting a
   * Gaussian sigma of 0.01 directly on the {@link GISTrainer}, and checks that
   * {@code trainModel} returns a model.
   */
  @Test
  public void testGISTrainSignature4() throws IOException {
    try (ObjectStream<Event> eventStream = createEventStream()) {
      TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
      params.put(AbstractTrainer.ITERATIONS_PARAM, 10);
      params.put(AbstractTrainer.CUTOFF_PARAM, 1);
      // The sigma setter is only reachable on the concrete trainer type, hence the cast.
      GISTrainer trainer = (GISTrainer) TrainerFactory.getEventTrainer(params, null);
      trainer.setGaussianSigma(0.01);
      Assert.assertNotNull(trainer.trainModel(eventStream));
    }
  }

  /**
   * Trains with explicit iteration (10) and cutoff (1) parameters, with the
   * {@code "smoothing"} and verbose parameters both set to {@code false},
   * and checks that a model is returned.
   */
  @Test
  public void testGISTrainSignature5() throws IOException {
    try (ObjectStream<Event> eventStream = createEventStream()) {
      TrainingParameters params = ModelUtil.createDefaultTrainingParameters();
      params.put(AbstractTrainer.ITERATIONS_PARAM, 10);
      params.put(AbstractTrainer.CUTOFF_PARAM, 1);
      params.put("smoothing", false);
      params.put(AbstractTrainer.VERBOSE_PARAM, false);
      EventTrainer trainer = TrainerFactory.getEventTrainer(params, null);
      Assert.assertNotNull(trainer.train(eventStream));
    }
  }

  /**
   * Checks that the trainer and data indexer types follow the training parameters:
   * first the default algorithm with a one-pass indexer, then the quasi-newton
   * algorithm with a two-pass indexer.
   */
  @Test
  public void testIndexingWithTrainingParameters() throws IOException {
    // try-with-resources so the stream is closed even when an assertion fails
    try (ObjectStream<Event> eventStream = createEventStream()) {
      TrainingParameters parameters = TrainingParameters.defaultParams();
      // by default we are using GIS/EventTrainer/Cutoff of 5/100 iterations
      parameters.put(TrainingParameters.ITERATIONS_PARAM, 10);
      parameters.put(AbstractEventTrainer.DATA_INDEXER_PARAM,
          AbstractEventTrainer.DATA_INDEXER_ONE_PASS_VALUE);
      parameters.put(AbstractEventTrainer.CUTOFF_PARAM, 1);
      // note: setting the SORT_PARAM to true is the default, so it is not really needed
      parameters.put(AbstractDataIndexer.SORT_PARAM, true);

      // guarantee that you have a GIS trainer...
      EventTrainer trainer = TrainerFactory.getEventTrainer(parameters, new HashMap<>());
      Assert.assertEquals("opennlp.tools.ml.maxent.GISTrainer", trainer.getClass().getName());
      AbstractEventTrainer aeTrainer = (AbstractEventTrainer) trainer;

      // guarantee that you have a OnePassDataIndexer ...
      DataIndexer di = aeTrainer.getDataIndexer(eventStream);
      Assert.assertEquals("opennlp.tools.ml.model.OnePassDataIndexer", di.getClass().getName());
      // the fixture holds 3 events, 2 distinct outcomes and 6 distinct context strings
      Assert.assertEquals(3, di.getNumEvents());
      Assert.assertEquals(2, di.getOutcomeLabels().length);
      Assert.assertEquals(6, di.getPredLabels().length);

      // change the parameters and try again...
      eventStream.reset();
      parameters.put(TrainingParameters.ALGORITHM_PARAM, QNTrainer.MAXENT_QN_VALUE);
      parameters.put(AbstractEventTrainer.DATA_INDEXER_PARAM,
          AbstractEventTrainer.DATA_INDEXER_TWO_PASS_VALUE);
      parameters.put(AbstractEventTrainer.CUTOFF_PARAM, 2);

      trainer = TrainerFactory.getEventTrainer(parameters, new HashMap<>());
      Assert.assertEquals("opennlp.tools.ml.maxent.quasinewton.QNTrainer", trainer.getClass().getName());
      aeTrainer = (AbstractEventTrainer) trainer;
      di = aeTrainer.getDataIndexer(eventStream);
      Assert.assertEquals("opennlp.tools.ml.model.TwoPassDataIndexer", di.getClass().getName());
    }
  }

  /**
   * Checks that {@link DataIndexerFactory} returns the indexer implementation named by
   * the data-indexer parameter (one-pass, two-pass, one-pass real-value, and an indexer
   * given by class name), and that the one-pass and two-pass indexers report the
   * expected counts after indexing the fixture.
   */
  @Test
  public void testIndexingFactory() throws IOException {
    Map<String, String> myReportMap = new HashMap<>();

    // set the cutoff to 1 for this test.
    TrainingParameters parameters = new TrainingParameters();
    parameters.put(AbstractDataIndexer.CUTOFF_PARAM, 1);

    // Only the one-pass and two-pass checks actually index events, so the stream is
    // scoped to them; try-with-resources closes it even when an assertion fails.
    try (ObjectStream<Event> eventStream = createEventStream()) {
      // test with a 1 pass data indexer...
      parameters.put(AbstractEventTrainer.DATA_INDEXER_PARAM,
          AbstractEventTrainer.DATA_INDEXER_ONE_PASS_VALUE);
      DataIndexer di = DataIndexerFactory.getDataIndexer(parameters, myReportMap);
      Assert.assertEquals("opennlp.tools.ml.model.OnePassDataIndexer", di.getClass().getName());
      di.index(eventStream);
      Assert.assertEquals(3, di.getNumEvents());
      Assert.assertEquals(2, di.getOutcomeLabels().length);
      Assert.assertEquals(6, di.getPredLabels().length);

      eventStream.reset();

      // test with a 2-pass data indexer...
      parameters.put(AbstractEventTrainer.DATA_INDEXER_PARAM,
          AbstractEventTrainer.DATA_INDEXER_TWO_PASS_VALUE);
      di = DataIndexerFactory.getDataIndexer(parameters, myReportMap);
      Assert.assertEquals("opennlp.tools.ml.model.TwoPassDataIndexer", di.getClass().getName());
      di.index(eventStream);
      Assert.assertEquals(3, di.getNumEvents());
      Assert.assertEquals(2, di.getOutcomeLabels().length);
      Assert.assertEquals(6, di.getPredLabels().length);
    }

    // test with a 1-pass Real value dataIndexer
    parameters.put(AbstractEventTrainer.DATA_INDEXER_PARAM,
        AbstractEventTrainer.DATA_INDEXER_ONE_PASS_REAL_VALUE);
    DataIndexer realValueIndexer = DataIndexerFactory.getDataIndexer(parameters, myReportMap);
    Assert.assertEquals("opennlp.tools.ml.model.OnePassRealValueDataIndexer",
        realValueIndexer.getClass().getName());

    // test with an UNRegistered MockIndexer, selected by its class name
    parameters.put(AbstractEventTrainer.DATA_INDEXER_PARAM, "opennlp.tools.ml.maxent.MockDataIndexer");
    DataIndexer mockIndexer = DataIndexerFactory.getDataIndexer(parameters, myReportMap);
    Assert.assertEquals("opennlp.tools.ml.maxent.MockDataIndexer", mockIndexer.getClass().getName());
  }
}