/** * Copyright 2014, Emory University * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package edu.emory.clir.clearnlp.classification.model; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.junit.Test; import edu.emory.clir.clearnlp.classification.instance.IntInstance; import edu.emory.clir.clearnlp.classification.instance.StringInstance; import edu.emory.clir.clearnlp.classification.instance.StringInstanceReader; import edu.emory.clir.clearnlp.classification.prediction.StringPrediction; import edu.emory.clir.clearnlp.classification.vector.AbstractWeightVector; import edu.emory.clir.clearnlp.classification.vector.StringFeatureVector; import edu.emory.clir.clearnlp.util.IOUtils; /** * @since 3.0.0 * @author Jinho D. Choi ({@code jinho.choi@emory.edu}) */ public class StringModelTest { @Test public void testBinary() throws Exception { StringInstanceReader reader = new StringInstanceReader(IOUtils.createFileInputStream("src/test/resources/classification/model/binary-string.train")); List<StringInstance> instances = new ArrayList<>(); StringInstance instance; while ((instance = reader.next()) != null) instances.add(instance); reader.close(); StringModel model = new StringModel(true); AbstractWeightVector vector = model.getWeightVector(); for (StringInstance inst : instances) model.addInstance(inst); model.initializeForTraining(2, 1); assertEquals( 1, model.getLabelSize()); assertEquals( 4, model.getFeatureSize()); assertEquals( 4, vector.size()); assertTrue(vector.isBinaryLabel()); for (StringInstance inst : instances) model.addInstance(inst); List<IntInstance> list = model.initializeForTraining(0, 0); assertEquals( 2, model.getLabelSize()); assertEquals( 13, model.getFeatureSize()); assertEquals( 13, vector.size()); assertTrue(vector.isBinaryLabel()); String[] sparse = {"1 5 2 11", "0 6 2 10", "1 4 7 3", "0 1 9 12", "0 1 8 3"}; int i, size = sparse.length; for (i=0; i<size; i++) assertEquals(sparse[i], list.get(i).toString()); vector.set( 0, 0); vector.set( 1, 2); vector.set( 2, 0); vector.set( 3, -1); vector.set( 4, -1); vector.set( 5, -1); vector.set( 6, 1); vector.set( 7, -1); vector.set( 8, 1); vector.set( 9, 1); vector.set(10, 1); vector.set(11, -1); vector.set(12, 1); testBinaryAux(model); ByteArrayOutputStream bout = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(bout)); model.save(out); out.close(); ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(new ByteArrayInputStream(bout.toByteArray()))); model.load(in); in.close(); testBinaryAux(model); } private void testBinaryAux(StringModel model) { StringFeatureVector x0; StringPrediction p; double[] scores; x0 = new StringFeatureVector(); x0.addFeature(0, "jinho"); x0.addFeature(1, "martin"); x0.addFeature(2, "s"); scores = model.getScores(x0); assertEquals(scores[0], -1, 0); assertEquals(scores[1], 1, 0); p = model.predictBest(x0); assertEquals("male", p.getLabel()); assertEquals(1, p.getScore(), 0); x0 = new StringFeatureVector(true); x0.addFeature(0, "jinho", 2); x0.addFeature(1, "martin", 2); x0.addFeature(2, "s", 5); scores = model.getScores(x0); assertEquals(scores[0], 1, 0); assertEquals(scores[1], -1, 0); p = model.predictBest(x0); assertEquals("female", p.getLabel()); assertEquals(1, p.getScore(), 0); x0 = new StringFeatureVector(); x0.addFeature(1, "jinho"); x0.addFeature(0, "martin"); x0.addFeature(2, "s"); scores = model.getScores(x0); assertEquals(scores[0], 1, 0); assertEquals(scores[1], -1, 0); x0 = new StringFeatureVector(); x0.addFeature(1, "jinho"); x0.addFeature(0, "martin"); x0.addFeature(2, "h"); p = model.predictBest(x0); assertEquals("male", p.getLabel()); assertEquals(1, p.getScore(), 0); } @Test public void testMulti() throws Exception { StringInstanceReader reader = new StringInstanceReader(IOUtils.createFileInputStream("src/test/resources/classification/model/multi-string.train")); List<StringInstance> instances = new ArrayList<>(); StringInstance instance; while ((instance = reader.next()) != null) instances.add(instance); reader.close(); StringModel model = new StringModel(false); AbstractWeightVector vector = model.getWeightVector(); for (StringInstance inst : instances) model.addInstance(inst); model.initializeForTraining(1, 1); assertEquals(2, model.getLabelSize()); assertEquals(4, model.getFeatureSize()); assertEquals(8, vector.size()); assertFalse(vector.isBinaryLabel()); for (StringInstance inst : instances) model.addInstance(inst); List<IntInstance> list = model.initializeForTraining(0, 0); assertEquals( 3, model.getLabelSize()); assertEquals( 7, model.getFeatureSize()); assertEquals(21, vector.size()); String[] sparse = {"2 4 2 3", "0 1 5", "1 1 2", "0 3", "1 6"}; int i, size = sparse.length; for (i=0; i<size; i++) assertEquals(sparse[i], list.get(i).toString()); vector.set(vector.getWeightIndex(0, 1), 1); vector.set(vector.getWeightIndex(1, 1), 1); vector.set(vector.getWeightIndex(2, 1), -1); vector.set(vector.getWeightIndex(0, 2), -1); vector.set(vector.getWeightIndex(1, 2), 0); vector.set(vector.getWeightIndex(2, 2), 1); vector.set(vector.getWeightIndex(0, 3), 1); vector.set(vector.getWeightIndex(1, 3), -1); vector.set(vector.getWeightIndex(2, 3), 1); vector.set(vector.getWeightIndex(0, 4), -1); vector.set(vector.getWeightIndex(1, 4), -1); vector.set(vector.getWeightIndex(2, 4), 1); vector.set(vector.getWeightIndex(0, 5), 1); vector.set(vector.getWeightIndex(1, 5), -1); vector.set(vector.getWeightIndex(2, 5), -1); vector.set(vector.getWeightIndex(0, 6), -1); vector.set(vector.getWeightIndex(1, 6), 1); vector.set(vector.getWeightIndex(2, 6), -1); assertEquals("[0.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]", Arrays.toString(vector.getWeights(0))); assertEquals("[0.0, 1.0, 0.0, -1.0, -1.0, -1.0, 1.0]", Arrays.toString(vector.getWeights(1))); assertEquals("[0.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0]", Arrays.toString(vector.getWeights(2))); testMultiAux(model); vector.set(vector.getWeightIndex(0, 0), 0); vector.set(vector.getWeightIndex(1, 0), 0); vector.set(vector.getWeightIndex(2, 0), 0); ByteArrayOutputStream bout = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(bout)); model.save(out); out.close(); ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(new ByteArrayInputStream(bout.toByteArray()))); model.load(in); in.close(); testMultiAux(model); } private void testMultiAux(StringModel model) { AbstractWeightVector vector = model.getWeightVector(); StringPrediction[] p2; StringPrediction[] pl; StringFeatureVector x0; StringPrediction p; double[] scores; x0 = new StringFeatureVector(); x0.addFeature(0, "bright"); x0.addFeature(1, "dry"); x0.addFeature(2, "dark"); scores = model.getScores(x0); assertEquals(scores[0], -3, 0); assertEquals(scores[1], 0, 0); assertEquals(scores[2], 1, 0); p = model.predictBest(x0); assertEquals("sunny", p.getLabel()); assertEquals(1, p.getScore(), 0); p2 = model.predictTop2(x0); p = p2[0]; assertEquals("sunny", p.getLabel()); assertEquals(1, p.getScore(), 0); p = p2[1]; assertEquals("cloudy", p.getLabel()); assertEquals(0, p.getScore(), 0); pl = model.predictAll(x0); p = pl[0]; assertEquals("sunny", p.getLabel()); assertEquals(1, p.getScore(), 0); p = pl[1]; assertEquals("cloudy", p.getLabel()); assertEquals(0, p.getScore(), 0); p = pl[2]; assertEquals("rainy", p.getLabel()); assertEquals(-3, p.getScore(), 0); vector.add(vector.getWeightIndex(0, 0), 5); vector.add(vector.getWeightIndex(1, 0), 4); vector.add(vector.getWeightIndex(2, 0), 3); vector.add(vector.getWeightIndex(0, 0), 5); vector.add(vector.getWeightIndex(1, 0), 6); vector.add(vector.getWeightIndex(2, 0), 7); x0 = new StringFeatureVector(true); x0.addFeature(1, "bright", 2d); x0.addFeature(2, "dry" , 2d); x0.addFeature(0, "dark" , 2d); scores = model.getScores(x0); assertEquals(scores[0], 12, 0); assertEquals(scores[1], 12, 0); assertEquals(scores[2], 8, 0); p = model.predictBest(x0); assertEquals("rainy", p.getLabel()); assertEquals(12, p.getScore(), 0); p2 = model.predictTop2(x0); p = p2[0]; assertEquals("rainy", p.getLabel()); assertEquals(12, p.getScore(), 0); p = p2[1]; assertEquals("cloudy", p.getLabel()); assertEquals(12, p.getScore(), 0); pl = model.predictAll(x0); p = pl[0]; assertEquals("rainy", p.getLabel()); assertEquals(12, p.getScore(), 0); p = pl[1]; assertEquals("cloudy", p.getLabel()); assertEquals(12, p.getScore(), 0); p = pl[2]; assertEquals("sunny", p.getLabel()); assertEquals(8, p.getScore(), 0); } }