/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.tudarmstadt.ukp.dkpro.core.mallet.wordembeddings;

import cc.mallet.topics.WordEmbeddings;
import cc.mallet.types.Alphabet;
import cc.mallet.types.InstanceList;
import de.tudarmstadt.ukp.dkpro.core.api.resources.CompressionUtils;
import de.tudarmstadt.ukp.dkpro.core.mallet.MalletModelTrainer;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.descriptor.ConfigurationParameter;

import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintWriter;

/**
 * Compute word embeddings from the given collection using skip-grams.
 * <p>
 * Set {@link #PARAM_TOKEN_FEATURE_PATH} to define what is considered as a token (Tokens, Lemmas, etc.).
 * <p>
 * Set {@link #PARAM_COVERING_ANNOTATION_TYPE} to define what is considered a document (sentences, paragraphs, etc.).
 *
 * @since 1.9.0
 */
public class MalletEmbeddingsTrainer
        extends MalletModelTrainer
{
    /**
     * The number of negative samples to be generated for each token (default: 5).
     */
    public static final String PARAM_NUM_NEGATIVE_SAMPLES = "numNegativeSamples";
    @ConfigurationParameter(name = PARAM_NUM_NEGATIVE_SAMPLES, mandatory = true, defaultValue = "5")
    private int numNegativeSamples;

    /**
     * The dimensionality of the output word embeddings (default: 50).
     */
    public static final String PARAM_DIMENSIONS = "dimensions";
    @ConfigurationParameter(name = PARAM_DIMENSIONS, mandatory = true, defaultValue = "50")
    private int dimensions;

    /**
     * The context size when generating embeddings (default: 5).
     */
    public static final String PARAM_WINDOW_SIZE = "windowSize";
    @ConfigurationParameter(name = PARAM_WINDOW_SIZE, mandatory = true, defaultValue = "5")
    private int windowSize;

    /**
     * An example word that is output with its nearest neighbours once in a while (default: null, i.e. none).
     */
    public static final String PARAM_EXAMPLE_WORD = "exampleWord";
    @ConfigurationParameter(name = PARAM_EXAMPLE_WORD, mandatory = false)
    private String exampleWord;

    /**
     * Ignore documents with fewer tokens than this value (default: 10).
     */
    public static final String PARAM_MIN_DOCUMENT_LENGTH = "minDocumentLength";
    @ConfigurationParameter(name = PARAM_MIN_DOCUMENT_LENGTH, mandatory = true, defaultValue = "10")
    private int minDocumentLength;

    /**
     * Trains a {@link WordEmbeddings} model on the instance list returned by
     * {@code getInstanceList()} and writes it to the file named by {@code getTargetLocation()}.
     * Missing parent directories of the target file are created first.
     *
     * @throws AnalysisEngineProcessException
     *             if {@code vocabulary size * dimensions * 2} exceeds
     *             {@code Integer.MAX_VALUE - 12} (wrapping an {@link IllegalStateException}), or
     *             if writing the model fails (wrapping the {@link IOException}).
     */
    @Override
    public void collectionProcessComplete()
            throws AnalysisEngineProcessException
    {
        InstanceList instanceList = getInstanceList();
        Alphabet alphabet = instanceList.getDataAlphabet();
        int vocabSize = alphabet.size();

        getLogger().info(String.format(
                "Computing word embeddings with %d dimensions for %d tokens...",
                dimensions, vocabSize));

        // Compute the product as a long: in int arithmetic it would silently wrap around for
        // exactly those oversized inputs this guard is supposed to reject.
        long matrixSize = (long) vocabSize * dimensions * 2;
        if (matrixSize > Integer.MAX_VALUE - 12) {
            throw new AnalysisEngineProcessException(new IllegalStateException(String.format(
                    "Maximum matrix size (number of words * number of columns/dimensions * 2) exceeded: %d * %d * 2 = %d",
                    vocabSize, dimensions, matrixSize)));
        }

        WordEmbeddings matrix = new WordEmbeddings(alphabet, dimensions, windowSize);
        matrix.setQueryWord(exampleWord);
        matrix.setMinDocumentLength(minDocumentLength);
        matrix.countWords(instanceList);
        matrix.train(instanceList, getNumThreads(), numNegativeSamples);

        assert (getTargetLocation() != null);
        getLogger().info("Writing output to " + getTargetLocation());
        File targetFile = new File(getTargetLocation());
        if (targetFile.getParentFile() != null) {
            targetFile.getParentFile().mkdirs();
        }

        // try-with-resources closes the writer and the underlying stream even if writing fails.
        try (OutputStream outputStream = CompressionUtils.getOutputStream(targetFile);
                PrintWriter printWriter = new PrintWriter(outputStream)) {
            matrix.write(printWriter);
            // PrintWriter never throws IOException itself; checkError() flushes and reports
            // whether any write failed, so a truncated model file does not go unnoticed.
            if (printWriter.checkError()) {
                throw new IOException("Error writing word embeddings to " + targetFile);
            }
        }
        catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }
    }
}