/*
* Copyright 2016 Paul Dubs & Richard Eckart de Castilho
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.dkpro.core.api.embeddings.binary;
import org.dkpro.core.api.embeddings.Vectorizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Locale;
/**
 * A {@link Vectorizer} for a binary file. Initialize with {@link #load(File)}.
 * <p>
 * The file starts with a {@link Header}, followed by the vocabulary, the vector used for
 * unknown words and finally one vector per vocabulary entry, in vocabulary order. The
 * per-word vectors are memory-mapped instead of being loaded onto the heap. Word lookups use
 * binary search, so the vocabulary is expected to be stored in sorted order.
 *
 * @see BinaryWordVectorUtils
 */
public class BinaryVectorizer
    implements Vectorizer
{
    private static final Logger LOG = LoggerFactory.getLogger(BinaryVectorizer.class);

    /** Vocabulary in file order; searched with {@link Arrays#binarySearch(Object[], Object)}. */
    private final String[] words;
    private final Header header;
    /** Read-only mapped views on the word vectors; each covers at most 2 GB of the file. */
    private final FloatBuffer[] parts;
    /** Number of vectors held by every partition except possibly the last one. */
    private final int maxVectorsPerPartition;
    /** Locale used to lower-case lookup keys when the embeddings are caseless. */
    private final Locale locale;
    private final float[] unknownVector;

    private BinaryVectorizer(Header aHeader, RandomAccessFile file, String[] aWords,
            long vectorStartOffset, float[] aUnk)
        throws IOException
    {
        header = aHeader;
        words = aWords;
        unknownVector = aUnk;
        locale = Locale.forLanguageTag(header.getLocale());

        int vectorSizeBytes = header.getVectorLength() * Float.BYTES;

        // Refuse truncated files up front - mapping a region that extends past the end of the
        // file yields unspecified behavior.
        long expectedEnd = vectorStartOffset + (long) aWords.length * vectorSizeBytes;
        if (file.length() < expectedEnd) {
            throw new IOException("Embeddings file is truncated: expected at least "
                    + expectedEnd + " bytes but found " + file.length() + ".");
        }

        // A mapped buffer is addressed with int offsets and can therefore cover at most
        // Integer.MAX_VALUE bytes (2 GB). To handle large embeddings files, the vectors are
        // split over several partitions, each holding a whole number of vectors.
        maxVectorsPerPartition = Integer.MAX_VALUE / vectorSizeBytes;

        // Ceiling division; written without an addition to avoid int overflow.
        int neededPartitions = aWords.length / maxVectorsPerPartition;
        if (aWords.length % maxVectorsPerPartition != 0) {
            neededPartitions += 1;
        }

        parts = new FloatBuffer[neededPartitions];
        FileChannel channel = file.getChannel();
        for (int i = 0; i < neededPartitions; i++) {
            long firstVector = (long) i * maxVectorsPerPartition;
            // Only the last partition may hold fewer vectors; when the word count is an exact
            // multiple of the capacity, the last partition is full as well.
            long vectorsInPartition = Math.min(maxVectorsPerPartition,
                    aWords.length - firstVector);
            long start = vectorStartOffset + firstVector * vectorSizeBytes;
            long length = vectorsInPartition * vectorSizeBytes;
            parts[i] = channel.map(FileChannel.MapMode.READ_ONLY, start, length)
                    .asFloatBuffer();
        }
    }

    /**
     * Load a binary embeddings file and return a new {@link BinaryVectorizer} object.
     * <p>
     * The file is opened read-only and closed before this method returns. The word vectors
     * stay accessible because a file mapping does not depend on the channel that created it.
     *
     * @param f a {@link File}
     * @return a new {@link BinaryVectorizer}
     * @throws IOException if an I/O error occurs or the file is not a valid embeddings file
     */
    public static BinaryVectorizer load(File f)
        throws IOException
    {
        try (RandomAccessFile file = new RandomAccessFile(f, "r")) {
            Header header = Header.read(file);

            String[] words = new String[header.getWordCount()];
            for (int i = 0; i < words.length; i++) {
                words[i] = file.readUTF();
            }
            LOG.info("Loaded {} word embeddings.", words.length);

            // The UNK vector is read in bulk onto the heap (big-endian, like the mapped part).
            byte[] buffer = new byte[header.getVectorLength() * Float.BYTES];
            file.readFully(buffer);
            float[] unk = new float[header.getVectorLength()];
            ByteBuffer.wrap(buffer).asFloatBuffer().get(unk);

            // Everything after the UNK vector is memory-mapped
            long offset = file.getFilePointer();
            return new BinaryVectorizer(header, file, words, offset, unk);
        }
    }

    /**
     * Return the vector for the given word, or the unknown-word vector if the word is not in
     * the vocabulary. For caseless embeddings the word is lower-cased before the lookup.
     *
     * @param aWord the word to look up
     * @return a fresh copy of the word vector, or the shared unknown-word vector
     */
    @Override
    public float[] vectorize(String aWord)
        throws IOException
    {
        int vectorIdx = Arrays.binarySearch(words, normalize(aWord));
        if (vectorIdx < 0) {
            return unknownVector;
        }

        // Locate the partition holding the vector and the vector's offset within it
        FloatBuffer part = parts[vectorIdx / maxVectorsPerPartition];
        int offset = (vectorIdx % maxVectorsPerPartition) * header.getVectorLength();

        // Read through a duplicate: it shares the mapped content but has its own position, so
        // lookups never move the shared buffer's position and cannot interfere with each other.
        FloatBuffer view = part.duplicate();
        view.position(offset);
        float[] vector = new float[header.getVectorLength()];
        view.get(vector);
        return vector;
    }

    @Override
    public boolean contains(String aWord)
    {
        return Arrays.binarySearch(words, normalize(aWord)) >= 0;
    }

    @Override
    public float[] unknownVector()
    {
        return unknownVector;
    }

    @Override
    public int dimensions()
    {
        return header.getVectorLength();
    }

    @Override
    public int size()
    {
        return header.getWordCount();
    }

    @Override
    public boolean isCaseless()
    {
        return header.isCaseless();
    }

    /** Turn a word into a lookup key: lower-cased with the file's locale if caseless. */
    private String normalize(String aWord)
    {
        return header.isCaseless() ? aWord.toLowerCase(locale) : aWord;
    }

    /**
     * Header of a binary embeddings file: magic string, format version (one byte), word
     * count, vector length, caseless flag and locale language tag.
     */
    static class Header
    {
        private static final String MAGIC = "dl4jw2v";

        private int version = 1;
        private int wordCount;
        private int vectorLength;
        private boolean caseless;
        private String locale;

        /**
         * Read and validate a header from the current position of the given input.
         *
         * @param aInput the input positioned at the start of the file
         * @return the parsed header
         * @throws IOException if the magic string, version or header values are invalid, or
         *             an I/O error occurs
         */
        public static Header read(DataInput aInput)
            throws IOException
        {
            byte[] magicBytes = new byte[MAGIC.length()];
            aInput.readFully(magicBytes);
            if (!MAGIC.equals(new String(magicBytes, StandardCharsets.US_ASCII))) {
                throw new IOException(
                        "The file you provided is either not a DL4J binary word vectors file or corrupted.");
            }

            Header header = new Header();
            header.version = aInput.readByte();
            if (1 != header.version) {
                throw new IOException("Not supported file format version.");
            }
            header.wordCount = aInput.readInt();
            header.vectorLength = aInput.readInt();
            header.caseless = aInput.readBoolean();
            header.locale = aInput.readUTF();

            // Reject values that would otherwise surface later as NegativeArraySizeException,
            // division by zero or int overflow when sizing buffers.
            if (header.wordCount < 0) {
                throw new IOException("Invalid word count in header: " + header.wordCount);
            }
            if (header.vectorLength <= 0
                    || header.vectorLength > Integer.MAX_VALUE / Float.BYTES) {
                throw new IOException(
                        "Invalid vector length in header: " + header.vectorLength);
            }
            return header;
        }

        public int getVersion()
        {
            return version;
        }

        public void setVersion(int version)
        {
            this.version = version;
        }

        public int getWordCount()
        {
            return wordCount;
        }

        public void setWordCount(int wordCount)
        {
            this.wordCount = wordCount;
        }

        public boolean isCaseless()
        {
            return caseless;
        }

        public void setCaseless(boolean caseless)
        {
            this.caseless = caseless;
        }

        public String getLocale()
        {
            return locale;
        }

        public void setLocale(String locale)
        {
            this.locale = locale;
        }

        public int getVectorLength()
        {
            return vectorLength;
        }

        public void setVectorLength(int vectorLength)
        {
            this.vectorLength = vectorLength;
        }

        /**
         * Write this header in the format expected by {@link #read(DataInput)}. The wrapping
         * stream is only flushed, not closed, so the caller's stream remains open.
         *
         * @param aOutput the stream to write to
         * @throws IOException if an I/O error occurs
         */
        public void write(OutputStream aOutput)
            throws IOException
        {
            DataOutputStream out = new DataOutputStream(aOutput);
            // Magic String to make file recognition easier
            out.write(MAGIC.getBytes(StandardCharsets.US_ASCII));
            out.writeByte(version);
            out.writeInt(wordCount);
            out.writeInt(vectorLength);
            out.writeBoolean(caseless);
            out.writeUTF(locale);
            out.flush();
        }
    }
}