/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.tudarmstadt.ukp.dkpro.core.mallet.wordembeddings;

import de.tudarmstadt.ukp.dkpro.core.api.resources.CompressionMethod;
import de.tudarmstadt.ukp.dkpro.core.api.resources.CompressionUtils;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import de.tudarmstadt.ukp.dkpro.core.io.text.TextReader;
import de.tudarmstadt.ukp.dkpro.core.testing.DkproTestContext;
import de.tudarmstadt.ukp.dkpro.core.tokit.BreakIteratorSegmenter;
import org.apache.uima.UIMAException;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReaderDescription;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.resource.ResourceInitializationException;
import org.junit.Rule;
import org.junit.Test;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.List;

import static junit.framework.TestCase.assertEquals;
import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription;
import static org.apache.uima.fit.factory.CollectionReaderFactory.createReaderDescription;
import static org.junit.Assert.assertTrue;

/**
 * Tests for {@link MalletEmbeddingsTrainer}.
 * <p>
 * Every test runs a pipeline (text reader, {@link BreakIteratorSegmenter}, trainer) over the
 * files matching {@code src/test/resources/txt/*} and then inspects the embeddings file the
 * trainer wrote. Each output line is expected to hold a token followed by one
 * space-separated number per embedding dimension.
 */
public class MalletEmbeddingsTrainerTest {
    @Rule
    public DkproTestContext testContext = new DkproTestContext();

    /**
     * Trains word embeddings with sentences as covering annotation and checks the number of
     * output lines and the shape of every line.
     * <p>
     * The code between the {@code tag::example[]} and {@code end::example[]} markers is kept
     * self-contained because the markers delimit a snippet for external inclusion.
     */
    @Test(timeout = 60000)
    public void test() throws UIMAException, IOException {
        int expectedLength = 699;

        // tag::example[]
        File text = new File("src/test/resources/txt/*");
        File embeddingsFile = new File(testContext.getTestOutputFolder(), "dummy.vec");
        int dimensions = 50;
        String coveringType = Sentence.class.getCanonicalName();

        CollectionReaderDescription reader = createReaderDescription(TextReader.class,
                TextReader.PARAM_SOURCE_LOCATION, text,
                TextReader.PARAM_LANGUAGE, "en");
        AnalysisEngineDescription segmenter = createEngineDescription(BreakIteratorSegmenter.class);
        AnalysisEngineDescription embeddings = createEngineDescription(
                MalletEmbeddingsTrainer.class,
                MalletEmbeddingsTrainer.PARAM_TARGET_LOCATION, embeddingsFile,
                MalletEmbeddingsTrainer.PARAM_SINGULAR_TARGET, true,
                MalletEmbeddingsTrainer.PARAM_OVERWRITE, true,
                MalletEmbeddingsTrainer.PARAM_NUM_THREADS, 1,
                MalletEmbeddingsTrainer.PARAM_COVERING_ANNOTATION_TYPE, coveringType);
        SimplePipeline.runPipeline(reader, segmenter, embeddings);
        // end::example[]

        assertEmbeddingLines(Files.readAllLines(embeddingsFile.toPath()), expectedLength,
                dimensions);
    }

    /**
     * Runs the pipeline without setting {@link MalletEmbeddingsTrainer#PARAM_TARGET_LOCATION}
     * and expects a {@link ResourceInitializationException}.
     */
    @Test(timeout = 60000, expected = ResourceInitializationException.class)
    public void testNoTarget() throws IOException, UIMAException {
        CollectionReaderDescription reader = createTextReader();
        AnalysisEngineDescription segmenter = createEngineDescription(BreakIteratorSegmenter.class);
        AnalysisEngineDescription embeddings = createEngineDescription(
                MalletEmbeddingsTrainer.class,
                MalletEmbeddingsTrainer.PARAM_NUM_THREADS, 1);

        SimplePipeline.runPipeline(reader, segmenter, embeddings);
    }

    /**
     * Trains with a filter regex and checks that no token in the output matches that regex.
     */
    @Test(timeout = 60000)
    public void testFilterRegex() throws UIMAException, IOException {
        File embeddingsFile = new File(testContext.getTestOutputFolder(), "dummy.vec");
        int expectedLength = 629;
        String coveringType = Sentence.class.getCanonicalName();
        String filterRegex = ".*y"; // tokens ending with "y"

        CollectionReaderDescription reader = createTextReader();
        AnalysisEngineDescription segmenter = createEngineDescription(BreakIteratorSegmenter.class);
        AnalysisEngineDescription embeddings = createEngineDescription(
                MalletEmbeddingsTrainer.class,
                MalletEmbeddingsTrainer.PARAM_TARGET_LOCATION, embeddingsFile,
                MalletEmbeddingsTrainer.PARAM_OVERWRITE, true,
                MalletEmbeddingsTrainer.PARAM_NUM_THREADS, 1,
                MalletEmbeddingsTrainer.PARAM_COVERING_ANNOTATION_TYPE, coveringType,
                MalletEmbeddingsTrainer.PARAM_FILTER_REGEX, filterRegex);
        SimplePipeline.runPipeline(reader, segmenter, embeddings);

        List<String> output = Files.readAllLines(embeddingsFile.toPath());
        assertEquals(expectedLength, output.size());

        /* assert that no token matches filter regex */
        assertTrue(output.stream()
                .map(line -> line.split(" "))
                .map(tokens -> tokens[0])
                .noneMatch(token -> token.matches(filterRegex)));
    }

    /**
     * Trains into a GZIP-compressed target file, reads it back through
     * {@link CompressionUtils#getInputStream}, and checks the line count, the number of fields
     * per line, and that every value lies strictly between -1 and 1.
     */
    @Test(timeout = 60000)
    public void testCompressed() throws UIMAException, IOException {
        CompressionMethod compressionMethod = CompressionMethod.GZIP;
        File targetFile = new File(testContext.getTestOutputFolder(),
                "embeddings" + compressionMethod.getExtension());
        int expectedLength = 699;
        int dimensions = 50;
        String covering = Sentence.class.getCanonicalName();

        CollectionReaderDescription reader = createTextReader();
        AnalysisEngineDescription segmenter = createEngineDescription(BreakIteratorSegmenter.class);
        AnalysisEngineDescription embeddings = createEngineDescription(
                MalletEmbeddingsTrainer.class,
                MalletEmbeddingsTrainer.PARAM_TARGET_LOCATION, targetFile,
                MalletEmbeddingsTrainer.PARAM_COVERING_ANNOTATION_TYPE, covering,
                MalletEmbeddingsTrainer.PARAM_NUM_THREADS, 1,
                MalletEmbeddingsTrainer.PARAM_COMPRESSION, compressionMethod);
        SimplePipeline.runPipeline(reader, segmenter, embeddings);

        int lineCounter = 0;
        // try-with-resources closes the stream even when an assertion fails mid-file; UTF-8 is
        // given explicitly so decoding matches Files.readAllLines used by the other tests.
        try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
                CompressionUtils.getInputStream(targetFile.getAbsolutePath(),
                        Files.newInputStream(targetFile.toPath())),
                StandardCharsets.UTF_8))) {
            String line;
            while ((line = bufferedReader.readLine()) != null) {
                lineCounter++;
                double[] values = parseEmbeddingLine(line, dimensions);
                assertTrue(Arrays.stream(values).allMatch(value -> -1 < value && value < 1));
            }
        }
        assertEquals(expectedLength, lineCounter);
    }

    /**
     * Trains character embeddings (no covering annotation type set) and checks the number of
     * output lines and the shape of every line.
     */
    @Test(timeout = 60000)
    public void testCharacterEmbeddings() throws IOException, UIMAException {
        File embeddingsFile = new File(testContext.getTestOutputFolder(), "embeddings.vec");
        int expectedLength = 47;
        int dimensions = 50;

        CollectionReaderDescription reader = createTextReader();
        AnalysisEngineDescription segmenter = createEngineDescription(BreakIteratorSegmenter.class);
        AnalysisEngineDescription embeddings = createEngineDescription(
                MalletEmbeddingsTrainer.class,
                MalletEmbeddingsTrainer.PARAM_TARGET_LOCATION, embeddingsFile,
                MalletEmbeddingsTrainer.PARAM_USE_CHARACTERS, true,
                MalletEmbeddingsTrainer.PARAM_EXAMPLE_WORD, "a",
                MalletEmbeddingsTrainer.PARAM_NUM_THREADS, 1,
                MalletEmbeddingsTrainer.PARAM_OVERWRITE, true);
        SimplePipeline.runPipeline(reader, segmenter, embeddings);

        assertEmbeddingLines(Files.readAllLines(embeddingsFile.toPath()), expectedLength,
                dimensions);
    }

    /**
     * Trains character embeddings with {@link Token} as covering annotation type and checks the
     * number of output lines and the shape of every line.
     */
    @Test(timeout = 60000)
    public void testCharacterEmbeddingsTokens() throws IOException, UIMAException {
        File embeddingsFile = new File(testContext.getTestOutputFolder(), "embeddings.vec");
        int expectedLength = 46;
        int dimensions = 50;
        String covering = Token.class.getTypeName();

        CollectionReaderDescription reader = createTextReader();
        AnalysisEngineDescription segmenter = createEngineDescription(BreakIteratorSegmenter.class);
        AnalysisEngineDescription embeddings = createEngineDescription(
                MalletEmbeddingsTrainer.class,
                MalletEmbeddingsTrainer.PARAM_TARGET_LOCATION, embeddingsFile,
                MalletEmbeddingsTrainer.PARAM_USE_CHARACTERS, true,
                MalletEmbeddingsTrainer.PARAM_EXAMPLE_WORD, "a",
                MalletEmbeddingsTrainer.PARAM_NUM_THREADS, 1,
                MalletEmbeddingsTrainer.PARAM_OVERWRITE, true,
                MalletEmbeddingsTrainer.PARAM_COVERING_ANNOTATION_TYPE, covering);
        SimplePipeline.runPipeline(reader, segmenter, embeddings);

        assertEmbeddingLines(Files.readAllLines(embeddingsFile.toPath()), expectedLength,
                dimensions);
    }

    /**
     * Creates the English-language text reader over {@code src/test/resources/txt/*}.
     *
     * @return the reader description
     * @throws ResourceInitializationException if the description cannot be created
     */
    private static CollectionReaderDescription createTextReader()
            throws ResourceInitializationException {
        File text = new File("src/test/resources/txt/*");
        return createReaderDescription(TextReader.class,
                TextReader.PARAM_SOURCE_LOCATION, text,
                TextReader.PARAM_LANGUAGE, "en");
    }

    /**
     * Asserts that the output has the expected number of lines and that every line is a
     * well-formed embedding entry.
     *
     * @param lines          the lines of the embeddings file
     * @param expectedLength the expected number of lines
     * @param dimensions     the expected number of vector components per line
     */
    private static void assertEmbeddingLines(List<String> lines, int expectedLength,
            int dimensions) {
        assertEquals(expectedLength, lines.size());
        for (String line : lines) {
            parseEmbeddingLine(line, dimensions);
        }
    }

    /**
     * Splits a line on single spaces, asserts it has {@code 1 + dimensions} fields, and parses
     * every field after the first as a double.
     * <p>
     * The parsed range ends at {@code fields.length} (exclusive) so that the last component is
     * validated as well.
     *
     * @param line       one line of the embeddings file
     * @param dimensions the expected number of vector components
     * @return the parsed vector components
     * @throws NumberFormatException if a component is not parsable as a double
     */
    private static double[] parseEmbeddingLine(String line, int dimensions) {
        String[] fields = line.split(" ");
        /* each line should have 1 + <#dimensions> fields */
        assertEquals(dimensions + 1, fields.length);
        /* each value must be parsable to a double */
        return Arrays.stream(fields, 1, fields.length)
                .mapToDouble(Double::parseDouble)
                .toArray();
    }
}