/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.util;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

import org.junit.Assert;
import org.junit.Test;

import opennlp.tools.ml.EventTrainer;
/**
 * Tests for {@link TrainingParameters}: construction, default values, namespace-scoped
 * lookups and the typed {@code put}/{@code get} accessors.
 */
public class TrainingParametersTest {

  /**
   * Checks that an instance copied from a map-built instance, an instance read from a
   * stream of {@code key=value} lines, and a copy of the stream-built instance all expose
   * the same settings.
   */
  @Test
  public void testConstructors() throws Exception {
    TrainingParameters tp1 =
        new TrainingParameters(build("key1=val1,key2=val2,key3=val3"));

    // Encode with an explicit charset so the fixture bytes do not depend on the
    // platform default; the content is plain ASCII.
    TrainingParameters tp2 = new TrainingParameters(
        new ByteArrayInputStream(
            "key1=val1\nkey2=val2\nkey3=val3\n".getBytes(StandardCharsets.UTF_8)));

    TrainingParameters tp3 = new TrainingParameters(tp2);

    assertEquals(tp1, tp2);
    assertEquals(tp2, tp3);
  }

  /**
   * Checks the values exposed by {@link TrainingParameters#defaultParams()}. Each getter is
   * called with a fallback that differs from the expected value, so a passing assertion
   * shows the value came from the defaults and not from the fallback.
   */
  @Test
  public void testDefault() {
    TrainingParameters tr = TrainingParameters.defaultParams();

    Assert.assertEquals(4, tr.getSettings().size());
    Assert.assertEquals("MAXENT", tr.algorithm());
    Assert.assertEquals(EventTrainer.EVENT_VALUE,
        tr.getStringParameter(TrainingParameters.TRAINER_TYPE_PARAM,
            "v11"));  // use different defaults
    Assert.assertEquals(100,
        tr.getIntParameter(TrainingParameters.ITERATIONS_PARAM,
            200));  // use different defaults
    Assert.assertEquals(5,
        tr.getIntParameter(TrainingParameters.CUTOFF_PARAM,
            200));  // use different defaults
  }

  /**
   * Checks that {@code algorithm()} reads the un-prefixed {@code Algorithm} key and that
   * {@code algorithm("n1")} reads the {@code n1.Algorithm} key.
   */
  @Test
  public void testGetAlgorithm() {
    TrainingParameters tp = build("Algorithm=Perceptron,n1.Algorithm=SVM");

    Assert.assertEquals("Perceptron", tp.algorithm());
    Assert.assertEquals("SVM", tp.algorithm("n1"));
  }

  /**
   * Checks that {@code getSettings()} returns only the un-prefixed entries, that
   * {@code getSettings(namespace)} returns that namespace's entries with the prefix
   * stripped, and that an unknown namespace yields an empty map.
   */
  @Test
  public void testGetSettings() {
    TrainingParameters tp = build("k1=v1,n1.k2=v2,n2.k3=v3,n1.k4=v4");

    assertEquals(buildMap("k1=v1"), tp.getSettings());
    assertEquals(buildMap("k2=v2,k4=v4"), tp.getSettings("n1"));
    assertEquals(buildMap("k3=v3"), tp.getSettings("n2"));
    Assert.assertTrue(tp.getSettings("n3").isEmpty());
  }

  /**
   * Same scenarios as {@link #testGetSettings()}, but through {@code getParameters}, which
   * returns a {@link TrainingParameters} instance; a {@code null} namespace selects the
   * un-prefixed entries.
   */
  @Test
  public void testGetParameters() {
    TrainingParameters tp = build("k1=v1,n1.k2=v2,n2.k3=v3,n1.k4=v4");

    assertEquals(build("k1=v1"), tp.getParameters(null));
    assertEquals(build("k2=v2,k4=v4"), tp.getParameters("n1"));
    assertEquals(build("k3=v3"), tp.getParameters("n2"));
    Assert.assertTrue(tp.getParameters("n3").getSettings().isEmpty());
  }

  /**
   * Exercises the typed getters (string, int, double, boolean) with and without a
   * namespace, the fallback value for missing keys, and the effect of {@code put} and
   * {@code putIfAbsent} on subsequent reads.
   */
  @Test
  public void testPutGet() {
    TrainingParameters tp =
        build("k1=v1,int.k2=123,str.k2=v3,str.k3=v4,boolean.k4=false,double.k5=123.45,k21=234.5");

    // String values: present, missing, namespaced, missing within a namespace.
    Assert.assertEquals("v1", tp.getStringParameter("k1", "def"));
    Assert.assertEquals("def", tp.getStringParameter("k2", "def"));
    Assert.assertEquals("v3", tp.getStringParameter("str", "k2", "def"));
    Assert.assertEquals("def", tp.getStringParameter("str", "k4", "def"));

    // Int values: the fallback is returned until the key is put.
    Assert.assertEquals(-100, tp.getIntParameter("k11", -100));
    tp.put("k11", 234);
    Assert.assertEquals(234, tp.getIntParameter("k11", -100));
    Assert.assertEquals(123, tp.getIntParameter("int", "k2", -100));
    Assert.assertEquals(-100, tp.getIntParameter("int", "k4", -100));

    // Double values: put overwrites, putIfAbsent leaves an existing value alone.
    Assert.assertEquals(234.5, tp.getDoubleParameter("k21", -100), 0.001);
    tp.put("k21", 345.6);
    Assert.assertEquals(345.6, tp.getDoubleParameter("k21", -100), 0.001); // should be changed
    tp.putIfAbsent("k21", 456.7);
    Assert.assertEquals(345.6, tp.getDoubleParameter("k21", -100), 0.001); // should be unchanged
    Assert.assertEquals(123.45, tp.getDoubleParameter("double", "k5", -100), 0.001);

    // Boolean values: the fallback is returned until the key is put.
    Assert.assertTrue(tp.getBooleanParameter("k31", true));
    tp.put("k31", false);
    Assert.assertFalse(tp.getBooleanParameter("k31", true));
    Assert.assertFalse(tp.getBooleanParameter("boolean", "k4", true));
  }

  /**
   * Parses a compact fixture string into a map.
   *
   * @param str comma-separated pairs in the form {@code k1=v1,k2=v2,...}
   * @return a mutable map holding one entry per pair
   * @throws IllegalArgumentException if a pair contains no {@code '='}
   */
  private static Map<String, String> buildMap(String str) {
    String[] pairs = str.split(",");
    Map<String, String> map = new HashMap<>(pairs.length);
    for (String pair : pairs) {
      // Limit of 2: split at the first '=' only, so values may themselves contain '='
      // and an empty value ("k=") is kept instead of being dropped.
      String[] keyValue = pair.split("=", 2);
      if (keyValue.length != 2) {
        throw new IllegalArgumentException("Expected 'key=value' but got: '" + pair + "'");
      }
      map.put(keyValue[0], keyValue[1]);
    }
    return map;
  }

  /**
   * Builds a {@link TrainingParameters} instance from a compact fixture string.
   *
   * @param str comma-separated pairs in the form {@code k1=v1,k2=v2,...}
   * @return parameters constructed from the parsed map
   */
  private static TrainingParameters build(String str) {
    return new TrainingParameters(buildMap(str));
  }

  /**
   * Asserts that both maps are non-null, have the same size, and that every key of
   * {@code map1} is present in {@code map2} with an equal value.
   */
  private static void assertEquals(Map<String, String> map1, Map<String, String> map2) {
    Assert.assertNotNull(map1);
    Assert.assertNotNull(map2);
    Assert.assertEquals(map1.size(), map2.size());
    for (Map.Entry<String, String> entry : map1.entrySet()) {
      String key = entry.getKey();
      // Without the containsKey check, a null expected value would match an absent key.
      Assert.assertTrue("Missing key: " + key, map2.containsKey(key));
      Assert.assertEquals("Value mismatch for key: " + key, entry.getValue(), map2.get(key));
    }
  }

  /**
   * Asserts that {@code actual} is non-null and that its settings match {@code map}.
   */
  private static void assertEquals(Map<String, String> map, TrainingParameters actual) {
    Assert.assertNotNull(actual);
    assertEquals(map, actual.getSettings());
  }

  /**
   * Asserts that two parameter sets expose the same settings; a {@code null} expectation
   * requires {@code actual} to be {@code null} as well.
   */
  private static void assertEquals(TrainingParameters expected, TrainingParameters actual) {
    if (expected == null) {
      Assert.assertNull(actual);
    } else {
      assertEquals(expected.getSettings(), actual);
    }
  }
}