/*
* Copyright 2016
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.dkpro.core.mallet.wordembeddings;
import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters;
import de.tudarmstadt.ukp.dkpro.core.mallet.MalletModelTrainer;
import de.tudarmstadt.ukp.dkpro.core.mallet.type.WordEmbedding;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.Type;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.TypeCapability;
import org.apache.uima.fit.util.CasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.FloatArray;
import org.apache.uima.resource.ResourceInitializationException;
import org.dkpro.core.api.embeddings.Vectorizer;
import org.dkpro.core.api.embeddings.binary.BinaryVectorizer;
import org.dkpro.core.api.embeddings.text.TextFormatVectorizer;
import java.io.File;
import java.io.IOException;
import java.util.Locale;
import java.util.Optional;
/**
 * Reads word embeddings from a file and adds {@link WordEmbedding} annotations to the annotations
 * of a configurable type (by default {@code Token}).
 * <p>
 * For every annotation of the configured type, its covered text (optionally lowercased) is looked
 * up in the embeddings model. If a vector is available, a {@link WordEmbedding} annotation with the
 * same begin/end offsets and a copy of the vector is added to the CAS.
 * </p>
 *
 * @since 1.9.0
 */
@TypeCapability(
        inputs = { "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token" },
        outputs = { "de.tudarmstadt.ukp.dkpro.core.mallet.type.WordEmbedding" }
)
public class MalletEmbeddingsAnnotator
    extends JCasAnnotator_ImplBase
{
    /**
     * The file containing the word embeddings.
     * <p>
     * The file is read as a text-format model unless {@link #PARAM_MODEL_IS_BINARY} is set.
     * </p>
     */
    public static final String PARAM_MODEL_LOCATION = ComponentParameters.PARAM_MODEL_LOCATION;
    @ConfigurationParameter(name = PARAM_MODEL_LOCATION, mandatory = true)
    private File modelLocation;

    /**
     * If set to true (default: false), the model file is loaded with {@link BinaryVectorizer};
     * otherwise it is loaded as a text-format model with {@link TextFormatVectorizer}.
     */
    public static final String PARAM_MODEL_IS_BINARY = "modelIsBinary";
    @ConfigurationParameter(name = PARAM_MODEL_IS_BINARY, mandatory = true, defaultValue = "false")
    private boolean modelIsBinary;

    /** The embeddings model loaded in {@link #initialize(UimaContext)}. */
    private Vectorizer vectorizer;

    /**
     * Specify how to handle tokens that are not contained in the model:
     * <ul>
     * <li>If set to false (default), unknown tokens are not annotated.</li>
     * <li>If set to true, every token is annotated. For unknown tokens, the vector returned by the
     * vectorizer is used; the vectorizer is expected to provide a stable (random) vector for
     * unknown tokens.</li>
     * </ul>
     */
    public static final String PARAM_ANNOTATE_UNKNOWN_TOKENS = "annotateUnknownTokens";
    @ConfigurationParameter(name = PARAM_ANNOTATE_UNKNOWN_TOKENS, mandatory = true, defaultValue = "false")
    private boolean annotateUnknownTokens;

    /**
     * If set to true (default: false), the first line is interpreted as a header line containing
     * the number of entries and the dimensionality. This should be set to true for models generated
     * with Word2Vec. Only valid for text-format models; combining it with
     * {@link #PARAM_MODEL_IS_BINARY} fails during initialization.
     */
    public static final String PARAM_MODEL_HAS_HEADER = "modelHasHeader";
    @ConfigurationParameter(name = PARAM_MODEL_HAS_HEADER, mandatory = true, defaultValue = "false")
    private boolean modelHasHeader;

    /**
     * The fully qualified name of the annotation type whose covered text is looked up in the model.
     * Default: {@code de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token}.
     * <p>
     * The value is resolved with {@code TypeSystem.getType(...)}, so it must be a plain type name.
     * Feature paths such as {@code de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token/lemma/value}
     * are not resolved by this annotator; a name that does not denote a type causes
     * {@link #process(JCas)} to fail.
     * </p>
     */
    public static final String PARAM_TOKEN_FEATURE_PATH = MalletModelTrainer.PARAM_TOKEN_FEATURE_PATH;
    @ConfigurationParameter(name = PARAM_TOKEN_FEATURE_PATH, mandatory = true, defaultValue = "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token")
    private String tokenFeaturePath;

    /**
     * If set to true (default: false), all tokens are lowercased (locale-independently) before
     * lookup. This setting must agree with whether the loaded model is caseless; otherwise
     * initialization fails.
     */
    public static final String PARAM_LOWERCASE = "lowercase";
    @ConfigurationParameter(name = PARAM_LOWERCASE, mandatory = true, defaultValue = "false")
    private boolean lowercase;

    /**
     * Validates the parameter combination and loads the embeddings model.
     *
     * @param context the UIMA context
     * @throws ResourceInitializationException if {@link #PARAM_MODEL_HAS_HEADER} is combined with a
     *         binary model, if the model cannot be read, or if {@link #PARAM_LOWERCASE} does not
     *         match the model's caselessness
     */
    @Override
    public void initialize(UimaContext context)
        throws ResourceInitializationException
    {
        super.initialize(context);

        if (modelHasHeader && modelIsBinary) {
            throw new ResourceInitializationException(new IllegalArgumentException(
                    "The parameter PARAM_MODEL_HAS_HEADER is only valid for text-format model files."));
        }

        // NOTE(review): modelHasHeader is only validated above; it is not passed to
        // TextFormatVectorizer.load(). Confirm whether the loader detects a header line itself.
        try {
            vectorizer = modelIsBinary
                    ? BinaryVectorizer.load(modelLocation)
                    : TextFormatVectorizer.load(modelLocation);
        }
        catch (IOException e) {
            throw new ResourceInitializationException(e);
        }

        // Require the lowercasing setting to agree with what the loaded model reports.
        if (lowercase != vectorizer.isCaseless()) {
            throw new ResourceInitializationException(new IllegalArgumentException(
                    "If PARAM_LOWERCASE is set, the model should be caseless and vice-versa."));
        }
    }

    /**
     * Adds a {@link WordEmbedding} annotation for every annotation of the configured type that has
     * a vector (see {@link #PARAM_ANNOTATE_UNKNOWN_TOKENS}).
     *
     * @param aJCas the CAS to process
     * @throws AnalysisEngineProcessException if the configured type does not exist in the type
     *         system, or if looking up a vector fails with an I/O error
     */
    @Override
    public void process(JCas aJCas)
        throws AnalysisEngineProcessException
    {
        Type type = aJCas.getTypeSystem().getType(tokenFeaturePath);
        if (type == null) {
            // Fail with a descriptive message instead of passing a null type to CasUtil.select().
            throw new AnalysisEngineProcessException(new IllegalArgumentException(
                    "Type [" + tokenFeaturePath + "] configured via PARAM_TOKEN_FEATURE_PATH "
                            + "was not found in the type system."));
        }

        for (AnnotationFS token : CasUtil.select(aJCas.getCas(), type)) {
            try {
                addAnnotation(aJCas, token.getCoveredText(), token.getBegin(), token.getEnd());
            }
            catch (IOException e) {
                throw new AnalysisEngineProcessException(e);
            }
        }
    }

    /**
     * Looks up {@code text} in the model and, if a vector is available, adds a
     * {@link WordEmbedding} spanning {@code begin}..{@code end} that holds a copy of the vector.
     *
     * @param aJCas the CAS to add the annotation to
     * @param text the token text; lowercased first if {@link #PARAM_LOWERCASE} is set
     * @param begin begin offset of the new annotation
     * @param end end offset of the new annotation
     * @throws IOException if the vectorizer fails to read a vector
     */
    private void addAnnotation(JCas aJCas, String text, int begin, int end)
        throws IOException
    {
        // Locale.ROOT makes lowercasing independent of the JVM default locale
        // (e.g. Turkish dotted/dotless i).
        String lookup = lowercase ? text.toLowerCase(Locale.ROOT) : text;

        Optional<float[]> vector = getVector(lookup);
        if (vector.isPresent()) {
            float[] values = vector.get();
            WordEmbedding embedding = new WordEmbedding(aJCas, begin, end);
            FloatArray array = new FloatArray(aJCas, values.length);
            array.copyFromArray(values, 0, 0, values.length);
            embedding.setWordEmbedding(array);
            embedding.addToIndexes(aJCas);
        }
        else {
            getLogger().debug(lookup + " not found in embeddings list.");
        }
    }

    /**
     * If {@link #PARAM_ANNOTATE_UNKNOWN_TOKENS} is set to true, always return a vector retrieved
     * from the vectorizer, which should hold a stable random vector for unknown tokens.
     * Otherwise, return a vector for known tokens, or none if the token is unknown.
     *
     * @param token a token for which to look up an embedding
     * @return an {@code Optional<float[]>} that holds the token embedding or is empty if no
     *         embedding is available for the token
     * @throws IOException if an I/O error occurs
     */
    private Optional<float[]> getVector(String token)
        throws IOException
    {
        if (annotateUnknownTokens || vectorizer.contains(token)) {
            return Optional.of(vectorizer.vectorize(token));
        }
        return Optional.empty();
    }
}