/*
* Copyright 2016 Paul Dubs & Richard Eckart de Castilho
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.dkpro.core.api.embeddings.binary;
import org.dkpro.core.api.embeddings.VectorizerUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.Locale;
import java.util.Map;
import static org.dkpro.core.api.embeddings.binary.BinaryVectorizer.Header;
/**
* Utility Methods for working with binary dl4j word vector files.
* <p>
* The core of this code has been written in the context of <a href="https://deeplearning4j.org/">dl4j</a>,
* but provides a generic solution to efficiently storing and reading word embeddings with a memory-mapped file.
*
* @author Paul Dubs
* @author Richard Eckart de Castilho
* @see <a href="https://gist.github.com/treo/f5a346d53f89566b51bf88a9a42c67c7">Original source</a>
*/
public class BinaryWordVectorUtils
{
private static final Logger LOG = LoggerFactory.getLogger(BinaryWordVectorUtils.class);
private static final Locale DEFAULT_LOCALE = Locale.US;
/**
* Write a map of token embeddings into binary format. Uses the default locale {@link Locale#US}
* and assume case-sensitivity iff there is any token containing an uppercase letter.
*
* @param vectors a {@code Map<String, float[]>} holding all tokens with embeddings
* @param binaryTarget the target file {@link File}
* @throws IOException if an I/O error occurs
* @see #convertWordVectorsToBinary(Map, boolean, Locale, File)
*/
public static void convertWordVectorsToBinary(Map<String, float[]> vectors, File binaryTarget)
throws IOException
{
boolean caseless = vectors.keySet().stream()
.allMatch(token -> token.equals(token.toLowerCase()));
convertWordVectorsToBinary(vectors, caseless, DEFAULT_LOCALE, binaryTarget);
}
/**
* Write a map of token embeddings into binary format.
*
* @param vectors a {@code Map<String, float[]>} holding all tokens with embeddings
* @param aCaseless if true, tokens are expected to be caseless
* @param aLocale the {@link Locale}
* @param binaryTarget the target file {@link File}
* @throws IOException if an I/O error occurs
*/
public static void convertWordVectorsToBinary(Map<String, float[]> vectors, boolean aCaseless,
Locale aLocale, File binaryTarget)
throws IOException
{
if (vectors.isEmpty()) {
throw new IllegalArgumentException("Word embeddings map must not be empty.");
}
int vectorLength = vectors.values().iterator().next().length;
assert vectors.values().stream().allMatch(v -> v.length == vectorLength);
Header header = prepareHeader(aCaseless, aLocale, vectors.size(), vectorLength);
DataOutputStream output = new DataOutputStream(
new BufferedOutputStream(new FileOutputStream(binaryTarget)));
header.write(output);
LOG.info("Sorting data...");
String[] words = vectors.keySet().stream()
.sorted()
.toArray(String[]::new);
LOG.info("Writing strings...");
for (String word : words) {
output.writeUTF(word);
}
LOG.info("Writing UNK vector...");
{
float[] vector = VectorizerUtils.randomVector(header.getVectorLength());
writeVector(output, vector);
}
LOG.info("Writing vectors...");
for (String word : words) {
float[] vector = vectors.get(word);
writeVector(output, vector);
}
output.close();
}
private static void writeVector(DataOutputStream output, float[] vector)
throws IOException
{
ByteBuffer buffer = ByteBuffer.allocate(vector.length * Float.BYTES);
FloatBuffer floatBuffer = buffer.asFloatBuffer();
floatBuffer.put(vector);
output.write(buffer.array());
}
private static Header prepareHeader(boolean aCaseless,
Locale aLocale, int wordCount, int vectorLength)
{
Header header = new Header();
header.setVersion(1);
header.setWordCount(wordCount);
header.setVectorLength(vectorLength);
header.setCaseless(aCaseless);
header.setLocale(aLocale.toString());
return header;
}
}