/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.tudarmstadt.ukp.dkpro.core.mallet.wordembeddings;

import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import de.tudarmstadt.ukp.dkpro.core.io.text.TextReader;
import de.tudarmstadt.ukp.dkpro.core.mallet.type.WordEmbedding;
import de.tudarmstadt.ukp.dkpro.core.tokit.BreakIteratorSegmenter;
import org.apache.uima.UIMAException;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReaderDescription;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.dkpro.core.api.embeddings.VectorizerUtils;
import org.junit.Before;
import org.junit.Test;

import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.Arrays;

import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
import static org.apache.uima.fit.factory.CollectionReaderFactory.createReaderDescription;
import static org.apache.uima.fit.util.JCasUtil.select;
import static org.apache.uima.fit.util.JCasUtil.selectCovered;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class MalletEmbeddingsAnnotatorTest
{
    private File modelFile;
    private File binaryModelFile;

    private static final String TXT_DIR = "src/test/resources/txt";
    private static final String TXT_FILE_PATTERN = "[+]*.txt";

    @Before
    public void setUp() throws URISyntaxException
    {
        modelFile = new File(getClass().getResource("/dummy.vec").toURI());
        binaryModelFile = new File(getClass().getResource("/dummy.binary").toURI());
    }

    /**
     * Annotate tokens with embeddings from a text-format model.
     */
    @Test
    public void test() throws ResourceInitializationException
    {
        // tag::example[]
        CollectionReaderDescription reader = createReaderDescription(TextReader.class,
                TextReader.PARAM_SOURCE_LOCATION, TXT_DIR,
                TextReader.PARAM_PATTERNS, TXT_FILE_PATTERN,
                TextReader.PARAM_LANGUAGE, "en");

        AnalysisEngineDescription segmenter = createEngineDescription(
                BreakIteratorSegmenter.class);

        AnalysisEngineDescription inferencer = createEngineDescription(
                MalletEmbeddingsAnnotator.class,
                MalletEmbeddingsAnnotator.PARAM_MODEL_LOCATION, modelFile);
        // end::example[]

        testEmbeddingAnnotations(reader, segmenter, inferencer);
    }

    /**
     * Annotate tokens with embeddings from a binary-format model.
     */
    @Test
    public void testBinary() throws ResourceInitializationException
    {
        CollectionReaderDescription reader = createReaderDescription(TextReader.class,
                TextReader.PARAM_SOURCE_LOCATION, TXT_DIR,
                TextReader.PARAM_PATTERNS, TXT_FILE_PATTERN,
                TextReader.PARAM_LANGUAGE, "en");

        AnalysisEngineDescription segmenter = createEngineDescription(
                BreakIteratorSegmenter.class);

        AnalysisEngineDescription inferencer = createEngineDescription(
                MalletEmbeddingsAnnotator.class,
                MalletEmbeddingsAnnotator.PARAM_MODEL_LOCATION, binaryModelFile,
                MalletEmbeddingsAnnotator.PARAM_MODEL_IS_BINARY, true);

        testEmbeddingAnnotations(reader, segmenter, inferencer);
    }

    /**
     * With {@code PARAM_ANNOTATE_UNKNOWN_TOKENS}, out-of-vocabulary tokens are expected to
     * receive the vector produced by {@link VectorizerUtils#randomVector(int)}.
     */
    @Test
    public void testUnknownTokensText() throws ResourceInitializationException
    {
        int dim = 50;
        float[] unkVector = VectorizerUtils.randomVector(dim);
        int minTokenLength = 3; // minimum token length in test vector file

        CollectionReaderDescription reader = createReaderDescription(TextReader.class,
                TextReader.PARAM_SOURCE_LOCATION, TXT_DIR,
                TextReader.PARAM_PATTERNS, TXT_FILE_PATTERN,
                TextReader.PARAM_LANGUAGE, "en");

        AnalysisEngineDescription segmenter = createEngineDescription(
                BreakIteratorSegmenter.class);

        AnalysisEngineDescription inferencer = createEngineDescription(
                MalletEmbeddingsAnnotator.class,
                MalletEmbeddingsAnnotator.PARAM_MODEL_LOCATION, modelFile,
                MalletEmbeddingsAnnotator.PARAM_ANNOTATE_UNKNOWN_TOKENS, true);

        for (JCas jcas : SimplePipeline.iteratePipeline(reader, segmenter, inferencer)) {
            for (Token token : select(jcas, Token.class)) {
                if (token.getCoveredText().length() < minTokenLength) {
                    // token is not in the model, so it should carry the unknown-token vector
                    float[] vector = selectCovered(WordEmbedding.class, token).get(0)
                            .getWordEmbedding().toArray();
                    assertTrue(Arrays.equals(vector, unkVector));
                }
            }
        }
    }

    /**
     * All out-of-vocabulary tokens are expected to share the same vector of the right
     * dimensionality.
     */
    @Test
    public void testUnknownTokensTextRandom() throws ResourceInitializationException
    {
        int dim = 50;
        int minTokenLength = 3; // minimum token length in test vector file

        CollectionReaderDescription reader = createReaderDescription(TextReader.class,
                TextReader.PARAM_SOURCE_LOCATION, TXT_DIR,
                TextReader.PARAM_PATTERNS, TXT_FILE_PATTERN,
                TextReader.PARAM_LANGUAGE, "en");

        AnalysisEngineDescription segmenter = createEngineDescription(
                BreakIteratorSegmenter.class);

        AnalysisEngineDescription inferencer = createEngineDescription(
                MalletEmbeddingsAnnotator.class,
                MalletEmbeddingsAnnotator.PARAM_MODEL_LOCATION, modelFile,
                MalletEmbeddingsAnnotator.PARAM_ANNOTATE_UNKNOWN_TOKENS, true);

        float[] randomVector = null;
        boolean isFirst = true;

        for (JCas jcas : SimplePipeline.iteratePipeline(reader, segmenter, inferencer)) {
            for (Token token : select(jcas, Token.class)) {
                if (token.getCoveredText().length() < minTokenLength) {
                    /* token should be unknown */
                    float[] vector = selectCovered(WordEmbedding.class, token).get(0)
                            .getWordEmbedding().toArray();
                    assertEquals(dim, vector.length);

                    if (isFirst) {
                        // remember the first unknown-token vector ...
                        randomVector = vector.clone();
                        isFirst = false;
                    }
                    else {
                        // ... and expect every further unknown token to share it
                        assertTrue(Arrays.equals(vector, randomVector));
                    }
                }
            }
        }
    }

    /**
     * Requesting lowercasing for this model is expected to fail during initialization.
     */
    @Test(expected = ResourceInitializationException.class)
    public void testLowercaseCaseless() throws UIMAException, IOException
    {
        CollectionReaderDescription reader = createReaderDescription(TextReader.class,
                TextReader.PARAM_SOURCE_LOCATION, TXT_DIR,
                TextReader.PARAM_PATTERNS, TXT_FILE_PATTERN,
                TextReader.PARAM_LANGUAGE, "en");

        AnalysisEngineDescription inferencer = createEngineDescription(
                MalletEmbeddingsAnnotator.class,
                MalletEmbeddingsAnnotator.PARAM_MODEL_LOCATION, modelFile,
                MalletEmbeddingsAnnotator.PARAM_LOWERCASE, true);

        SimplePipeline.runPipeline(reader, inferencer);
    }

    /**
     * Run the pipeline and assert that every token covered by the model gets exactly one
     * {@link WordEmbedding} annotation and every other token gets none.
     */
    private static void testEmbeddingAnnotations(CollectionReaderDescription reader,
            AnalysisEngineDescription segmenter, AnalysisEngineDescription inferencer)
    {
        int expectedEmbeddingsPerToken = 1;
        int minTokenLength = 3; // minimum token length in test vector file

        for (JCas jcas : SimplePipeline.iteratePipeline(reader, segmenter, inferencer)) {
            for (Token token : select(jcas, Token.class)) {
                if (token.getCoveredText().length() >= minTokenLength) {
                    assertEquals(expectedEmbeddingsPerToken,
                            selectCovered(WordEmbedding.class, token).size());
                }
                else {
                    assertTrue(selectCovered(WordEmbedding.class, token).isEmpty());
                }
            }
        }
    }
}