package edu.stanford.nlp.neural;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.math.ArrayMath;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.zip.GZIPOutputStream;
/**
* A serializer for reading / writing word vectors.
* This is used to read word2vec in hcoref, and is primarily here
* for its efficient serialization / deserialization protocol, which
* saves/loads the vectors as 16 bit floats.
*
* @author Gabor Angeli
*/
public class VectorMap extends HashMap<String, float[]>{

  /**
   * Lines of a word2vec text file with fewer whitespace-separated fields than this
   * are skipped by {@link #readWord2Vec(String)}.
   * NOTE(review): presumably this filters out the header line and malformed rows; it also
   * means vectors with fewer than 99 dimensions are never loaded -- confirm this is intended.
   */
  private static final int MIN_WORD2VEC_FIELDS = 100;

  /**
   * The integer type (i.e., number of bits per integer) used to store the length
   * prefix of each key in the binary format.
   */
  private enum itype {
    INT8,
    INT16,
    INT32;

    /**
     * Get the minimum integer type that will fit this number.
     * A width is only selected when the number is strictly below the signed maximum of
     * that width, so every length up to {@code num} reads back as a non-negative value.
     *
     * @param num The largest value that has to be representable.
     * @return The narrowest type able to store {@code num}.
     */
    static itype getType(int num) {
      if (num < Byte.MAX_VALUE) {
        return itype.INT8;
      }
      if (num < Short.MAX_VALUE) {
        return itype.INT16;
      }
      return itype.INT32;
    }

    /**
     * Read an integer of this type from the given input stream.
     * The value is read as a signed quantity of the corresponding width.
     *
     * @param in The stream to read from.
     * @return The integer that was read.
     * @throws IOException Thrown if the stream could not be read.
     */
    public int read(DataInputStream in) throws IOException {
      switch (this) {
        case INT8:
          return in.readByte();
        case INT16:
          return in.readShort();
        case INT32:
          return in.readInt();
        default:
          throw new RuntimeException("Unknown itype: " + this);
      }
    }

    /**
     * Write an integer of this type to the given output stream.
     * Only the low-order bits that fit into this type are written.
     *
     * @param out The stream to write to.
     * @param value The value to write.
     * @throws IOException Thrown if the stream could not be written to.
     */
    public void write(DataOutputStream out, int value) throws IOException {
      switch (this) {
        case INT8:
          out.writeByte(value);
          break;
        case INT16:
          out.writeShort(value);
          break;
        case INT32:
          out.writeInt(value);
          break;
        default:
          throw new RuntimeException("Unknown itype: " + this);
      }
    }
  }

  /**
   * Create an empty word vector storage.
   */
  public VectorMap() {
    super(1024);
  }

  /**
   * Initialize word vectors from a given map.
   * @param vectors The word vectors as a simple map.
   */
  public VectorMap(Map<String, float[]> vectors) {
    super(vectors);
  }

  /**
   * Write the word vectors to a file. If the file name ends with {@code .gz},
   * the output is additionally gzip-compressed.
   *
   * @param file The file to write to.
   * @throws IOException Thrown if the file could not be written to.
   * @throws IllegalStateException Thrown if the map cannot be represented in the binary
   *                               format; see {@link #serialize(OutputStream)}.
   */
  public void serialize(String file) throws IOException {
    try (OutputStream output = new BufferedOutputStream(new FileOutputStream(new File(file)))) {
      if (file.endsWith(".gz")) {
        // Closing the gzip stream writes the gzip trailer before the file is closed.
        try (GZIPOutputStream gzip = new GZIPOutputStream(output)) {
          serialize(gzip);
        }
      } else {
        serialize(output);
      }
    }
  }

  /**
   * Write the word vectors to an output stream. The stream is flushed, but not closed,
   * on finishing the function.
   *
   * <p>Layout: three 32-bit integers (maximum key length in bytes, vector dimension,
   * number of entries), followed for each entry by the key length (in the narrowest
   * integer type that fits the maximum key length), the key bytes encoded as UTF-8,
   * and one 16-bit half-precision float per vector component.</p>
   *
   * @param out The stream to write to.
   * @throws IOException Thrown if the stream could not be written to.
   * @throws IllegalStateException Thrown before anything is written if the map contains a
   *                               null key, a null vector, or vectors of differing lengths;
   *                               none of these can be represented in the format.
   */
  public void serialize(OutputStream out) throws IOException {
    DataOutputStream dataOut = new DataOutputStream(out);
    // First pass: collect the header statistics and validate the entries, so that a
    // map which cannot be represented never produces a partially written stream.
    int maxKeyLength = 0;
    int vectorLength = -1;
    for (Map.Entry<String, float[]> entry : this.entrySet()) {
      String key = entry.getKey();
      float[] vector = entry.getValue();
      if (key == null || vector == null) {
        throw new IllegalStateException("Cannot serialize a null key or a null vector (key: " + key + ")");
      }
      if (vectorLength >= 0 && vector.length != vectorLength) {
        // The format stores a single dimension in the header; mixed lengths would corrupt the file.
        throw new IllegalStateException("Vector for key '" + key + "' has length " + vector.length
            + " but other vectors have length " + vectorLength);
      }
      vectorLength = vector.length;
      maxKeyLength = Math.max(key.getBytes(StandardCharsets.UTF_8).length, maxKeyLength);
    }
    if (vectorLength < 0) {
      vectorLength = 0;  // empty map
    }
    itype keyIntType = itype.getType(maxKeyLength);
    // Write the key length
    dataOut.writeInt(maxKeyLength);
    // Write the vector dim
    dataOut.writeInt(vectorLength);
    // Write the size of the dataset
    dataOut.writeInt(this.size());
    for (Map.Entry<String, float[]> entry : this.entrySet()) {
      // Write the length of the key, then the key itself
      byte[] key = entry.getKey().getBytes(StandardCharsets.UTF_8);
      keyIntType.write(dataOut, key.length);
      dataOut.write(key);
      // Write the vector as half-precision floats
      for (float v : entry.getValue()) {
        dataOut.writeShort(fromFloat(v));
      }
    }
    dataOut.flush();
  }

  /**
   * Read word vectors from a file or classpath or url.
   *
   * @param file The file to read from.
   * @return The vectors in the file.
   * @throws IOException Thrown if we could not read from the resource
   */
  public static VectorMap deserialize(String file) throws IOException {
    try (InputStream input = IOUtils.getInputStreamFromURLOrClasspathOrFileSystem(file)) {
      return deserialize(input);
    }
  }

  /**
   * Read word vectors from an input stream, in the format produced by
   * {@link #serialize(OutputStream)}. The stream is not closed on finishing the function.
   * Keys are decoded as UTF-8.
   *
   * @param in The stream to read from. This is not closed.
   * @return The word vectors encoded on the stream.
   * @throws IOException Thrown if we could not read from the stream, if the stream ends
   *                     prematurely, or if the header or a key length is invalid.
   */
  public static VectorMap deserialize(InputStream in) throws IOException {
    DataInputStream dataIn = new DataInputStream(in);
    // Read the max key length
    int maxKeyLength = dataIn.readInt();
    // Read the vector dimensionality
    int dim = dataIn.readInt();
    // Read the size of the dataset
    int size = dataIn.readInt();
    if (maxKeyLength < 0 || dim < 0 || size < 0) {
      throw new IOException("Invalid word vector header: maxKeyLength=" + maxKeyLength
          + ", dim=" + dim + ", size=" + size);
    }
    itype keyIntType = itype.getType(maxKeyLength);
    // Read the vectors
    VectorMap vectors = new VectorMap();
    for (int i = 0; i < size; ++i) {
      // Read the key
      int strlen = keyIntType.read(dataIn);
      if (strlen < 0 || strlen > maxKeyLength) {
        throw new IOException("Invalid key length " + strlen + " (maximum is " + maxKeyLength + ")");
      }
      byte[] buffer = new byte[strlen];
      // A single read() may legally return fewer bytes than requested (e.g. on gzip or
      // network streams); readFully blocks until the buffer is filled or throws EOFException.
      dataIn.readFully(buffer);
      String key = new String(buffer, StandardCharsets.UTF_8);
      // Read the vector
      float[] vector = new float[dim];
      for (int k = 0; k < vector.length; ++k) {
        vector[k] = toFloat(dataIn.readShort());
      }
      // Add the key/value
      vectors.put(key, vector);
    }
    return vectors;
  }

  /**
   * Read the Word2Vec word vector flat txt file.
   * Every line is lower-cased and split on whitespace; the first field is the word and
   * the remaining fields are the vector components. Lines with fewer than
   * {@link #MIN_WORD2VEC_FIELDS} fields are skipped. Each vector is L2-normalized
   * before it is stored.
   *
   * @param file The word2vec text file.
   * @return The word vectors in the file.
   */
  public static VectorMap readWord2Vec(String file) {
    VectorMap vectors = new VectorMap();
    int dim = -1;
    for (String line : IOUtils.readLines(file)) {
      // Locale.ROOT keeps the lower-casing independent of the JVM's default locale.
      String[] split = line.toLowerCase(Locale.ROOT).split("\\s+");
      if (split.length < MIN_WORD2VEC_FIELDS) {
        continue;
      }
      float[] vector = new float[split.length - 1];
      if (dim == -1) {
        dim = vector.length;
      }
      assert dim == vector.length;
      for (int i = 1; i < split.length; i++) {
        vector[i - 1] = Float.parseFloat(split[i]);
      }
      ArrayMath.L2normalize(vector);
      vectors.put(split[0], vector);
    }
    return vectors;
  }

  /**
   * Compare this map to another map, treating vectors as equal if they have the same
   * length and all of their components are "close enough" (see {@code sameFloat}). This allows
   * a map to compare equal to its lossy 16-bit round trip.
   * Note that, being approximate, this relation is not transitive.
   *
   * @param other The object to compare against.
   * @return True if the other object is a map with the same keys and approximately the same vectors.
   */
  @Override
  public boolean equals(Object other) {
    if (other == this) {
      return true;
    }
    if (!(other instanceof Map)) {
      return false;
    }
    Map<?, ?> otherMap = (Map<?, ?>) other;
    // Key sets have the same size
    if (this.size() != otherMap.size()) {
      return false;
    }
    try {
      for (Map.Entry<String, float[]> entry : this.entrySet()) {
        float[] value = entry.getValue();
        Object otherObject = otherMap.get(entry.getKey());
        if (otherObject == null) {
          // Only equal if we also map to null, and the key is really present on the other side
          if (value != null || !otherMap.containsKey(entry.getKey())) {
            return false;
          }
          continue;
        }
        if (value == null || !(otherObject instanceof float[])) {
          return false;
        }
        float[] otherValue = (float[]) otherObject;
        // Vectors are the same length
        if (value.length != otherValue.length) {
          return false;
        }
        // Vectors are the same value
        for (int i = 0; i < otherValue.length; ++i) {
          if (!sameFloat(value[i], otherValue[i])) {
            return false;
          }
        }
      }
    } catch (ClassCastException | NullPointerException ignored) {
      // The other map rejects our key type or null keys, so it cannot hold the same entries.
      return false;
    }
    return true;
  }

  /**
   * Hash on the key set only, so that maps whose vectors differ slightly (and are
   * therefore still equal under {@link #equals(Object)}) share a hash code.
   */
  @Override
  public int hashCode() {
    return keySet().hashCode();
  }

  @Override
  public String toString() {
    return "VectorMap[" + this.size() + "]";
  }

  /**
   * The check to see if two floats are "close enough."
   * Two values match if they are identical (including NaN and infinities), if they differ
   * by less than 1e-10, if they differ by less than 1% of the larger magnitude, or if
   * both magnitudes are below 1e-5.
   */
  private static boolean sameFloat(float a, float b) {
    if (Float.compare(a, b) == 0) {
      // Identical values; the arithmetic below yields NaN for NaN and infinite inputs.
      return true;
    }
    float absDiff = Math.abs(a - b);
    float absA = Math.abs(a);
    float absB = Math.abs(b);
    return absDiff < 1e-10 ||
        absDiff < Math.max(absA, absB) / 100.0f ||
        (absA < 1e-5 && absB < 1e-5);
  }

  /**
   * Decode a 16-bit half-precision value (1 sign bit, 5 exponent bits, 10 mantissa bits)
   * into a float.
   * Note that for normalized inputs with a zero mantissa (other than the smallest normal
   * exponent) the low 10 mantissa bits of the result are set, so the returned float is
   * marginally larger than the exact value.
   *
   * From http://stackoverflow.com/questions/6162651/half-precision-floating-point-in-java
   *
   * @param hbits The half-precision bit pattern.
   * @return The decoded float.
   */
  private static float toFloat( short hbits ) {
    int mant = hbits & 0x03ff; // 10 bits mantissa
    int exp = hbits & 0x7c00; // 5 bits exponent
    if( exp == 0x7c00 ) // NaN/Inf
      exp = 0x3fc00; // -> NaN/Inf
    else if( exp != 0 ) // normalized value
    {
      exp += 0x1c000; // exp - 15 + 127
      if( mant == 0 && exp > 0x1c400 ) // smooth transition
        return Float.intBitsToFloat( ( hbits & 0x8000 ) << 16
            | exp << 13 | 0x3ff );
    }
    else if( mant != 0 ) // && exp==0 -> subnormal
    {
      exp = 0x1c400; // make it normal
      do {
        mant <<= 1; // mantissa * 2
        exp -= 0x400; // decrease exp by 1
      } while( ( mant & 0x400 ) == 0 ); // while not normal
      mant &= 0x3ff; // discard subnormal bit
    } // else +/-0 -> +/-0
    return Float.intBitsToFloat( // combine all parts
        ( hbits & 0x8000 ) << 16 // sign << ( 31 - 15 )
            | ( exp | mant ) << 13 ); // value << ( 23 - 10 )
  }

  /**
   * Encode a float as a 16-bit half-precision value, with rounding.
   * Magnitudes too large for half precision become +/-Inf, magnitudes too small for a
   * half-precision subnormal become +/-0, and NaN/Inf inputs remain NaN/Inf.
   *
   * From http://stackoverflow.com/questions/6162651/half-precision-floating-point-in-java
   *
   * @param fval The float to encode.
   * @return The half-precision bit pattern.
   */
  private static short fromFloat( float fval ) {
    int fbits = Float.floatToIntBits( fval );
    int sign = fbits >>> 16 & 0x8000; // sign only
    int val = ( fbits & 0x7fffffff ) + 0x1000; // rounded value
    if( val >= 0x47800000 ) // might be or become NaN/Inf
    { // avoid Inf due to rounding
      if( ( fbits & 0x7fffffff ) >= 0x47800000 )
      { // is or must become NaN/Inf
        if( val < 0x7f800000 ) // was value but too large
          return (short) (sign | 0x7c00); // make it +/-Inf
        return (short) (sign | 0x7c00 | // remains +/-Inf or NaN
            ( fbits & 0x007fffff ) >>> 13); // keep NaN (and Inf) bits
      }
      return (short) (sign | 0x7bff); // unrounded not quite Inf
    }
    if( val >= 0x38800000 ) // remains normalized value
      return (short) (sign | val - 0x38000000 >>> 13); // exp - 127 + 15
    if( val < 0x33000000 ) // too small for subnormal
      return (short) sign; // becomes +/-0
    val = ( fbits & 0x7fffffff ) >>> 23; // tmp exp for subnormal calc
    return (short) (sign | ( ( fbits & 0x7fffff | 0x800000 ) // add subnormal bit
        + ( 0x800000 >>> val - 102 ) // round depending on cut off
        >>> 126 - val )); // div by 2^(1-(exp-127+15)) and >> 13 | exp=0
  }
}