/*
* Copyright 2014 Radialpoint SafeCare Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.radialpoint.word2vec;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.Map;
/**
* This class stores the mapping of String->array of float that constitutes each vector.
*
* The class can serialize to/from a stream.
*
* The ConvertVectors program converts the C binary vectors into instances of this class.
*/
public class Vectors {

    /**
     * The vectors themselves, one row per word.
     */
    protected float[][] vectors;

    /**
     * The words associated with the vectors; entry {@code i} names row {@code i} of {@link #vectors}.
     */
    protected String[] vocabVects;

    /**
     * Size (number of components) of each vector.
     */
    protected int size;

    /**
     * Inverse map, word -> index into {@link #vectors} / {@link #vocabVects}.
     */
    protected Map<String, Integer> vocab;

    /**
     * Package-level constructor, used by the ConvertVectors program.
     *
     * The arrays are stored as given (no copy is made). The vector size is taken from the length of the
     * first vector.
     *
     * @param vectors
     *            the vectors, it cannot be empty
     * @param vocabVects
     *            the words, its length should match vectors
     * @throws VectorsException
     *             if vectors is empty or the two arrays differ in length
     */
    Vectors(float[][] vectors, String[] vocabVects) throws VectorsException {
        // Validate before touching vectors[0], so an empty input is reported as a VectorsException
        // instead of an ArrayIndexOutOfBoundsException.
        if (vectors.length == 0) {
            throw new VectorsException("Vectors cannot be empty");
        }
        if (vectors.length != vocabVects.length) {
            throw new VectorsException("Vectors and vocabulary size mismatch");
        }
        this.vectors = vectors;
        this.size = vectors[0].length;
        this.vocabVects = vocabVects;
        this.vocab = indexVocabulary(vocabVects);
    }

    /**
     * Initialize a Vectors instance from an open input stream. This method closes the stream, whether or
     * not reading succeeds.
     *
     * The expected layout is the one produced by {@link #writeTo(OutputStream)}: the word count (int),
     * the vector size (int), and then for each word its UTF string followed by its float components.
     *
     * @param is
     *            the open stream
     * @throws IOException
     *             if there are problems reading from the stream, or if the header holds a negative word
     *             count or vector size
     */
    public Vectors(InputStream is) throws IOException {
        DataInputStream dis = new DataInputStream(is);
        try {
            int words = dis.readInt();
            int dimension = dis.readInt();
            // A negative header value would otherwise surface as an unchecked
            // NegativeArraySizeException from the array allocations below.
            if (words < 0 || dimension < 0) {
                throw new IOException("Invalid vectors header: words=" + words + ", size=" + dimension);
            }
            this.size = dimension;
            this.vectors = new float[words][];
            this.vocabVects = new String[words];
            for (int i = 0; i < words; i++) {
                this.vocabVects[i] = dis.readUTF();
                float[] vector = new float[dimension];
                for (int j = 0; j < dimension; j++) {
                    vector[j] = dis.readFloat();
                }
                this.vectors[i] = vector;
            }
            this.vocab = indexVocabulary(this.vocabVects);
        } finally {
            // Close on every path so a failed read does not leak the stream.
            dis.close();
        }
    }

    /**
     * Builds the word -> index map for the given words. If a word appears more than once, the highest
     * index is the one kept.
     */
    private static Map<String, Integer> indexVocabulary(String[] words) {
        Map<String, Integer> index = new HashMap<String, Integer>();
        for (int i = 0; i < words.length; i++) {
            index.put(words[i], i);
        }
        return index;
    }

    /**
     * Writes these vectors to an open output stream. This method closes the stream, whether or not
     * writing succeeds.
     *
     * The layout is: the word count (int), the vector size (int), and then for each word its UTF string
     * followed by its float components.
     *
     * @param os
     *            the stream to write to
     * @throws IOException
     *             if there are problems writing to the stream
     */
    public void writeTo(OutputStream os) throws IOException {
        DataOutputStream dos = new DataOutputStream(os);
        try {
            dos.writeInt(this.vectors.length);
            dos.writeInt(this.size);
            for (int i = 0; i < vectors.length; i++) {
                dos.writeUTF(this.vocabVects[i]);
                float[] vector = this.vectors[i];
                for (int j = 0; j < size; j++) {
                    dos.writeFloat(vector[j]);
                }
            }
        } finally {
            // Close on every path so a failed write does not leak the stream.
            dos.close();
        }
    }

    /**
     * Returns the internal array of vectors (not a copy).
     *
     * @return the vectors, one row per word
     */
    public float[][] getVectors() {
        return vectors;
    }

    /**
     * Returns the vector stored at the given index (the internal array, not a copy).
     *
     * @param i
     *            the index of the vector
     * @return the vector at that index
     */
    public float[] getVector(int i) {
        return vectors[i];
    }

    /**
     * Returns the vector associated with a term (the internal array, not a copy).
     *
     * @param term
     *            the term to look up
     * @return the vector for the term
     * @throws OutOfVocabularyException
     *             if the term is not in the vocabulary
     */
    public float[] getVector(String term) throws OutOfVocabularyException {
        return vectors[requireIndex(term)];
    }

    /**
     * Returns the index associated with a term.
     *
     * @param term
     *            the term to look up
     * @return the index of the term
     * @throws OutOfVocabularyException
     *             if the term is not in the vocabulary
     */
    public int getIndex(String term) throws OutOfVocabularyException {
        return requireIndex(term);
    }

    /**
     * Looks up the index of a term, failing if the term is not in the vocabulary map.
     */
    private int requireIndex(String term) throws OutOfVocabularyException {
        Integer idx = vocab.get(term);
        if (idx == null) {
            throw new OutOfVocabularyException("Unknown term '" + term + "'");
        }
        return idx;
    }

    /**
     * Returns the index associated with a term, without throwing for unknown terms.
     *
     * @param term
     *            the term to look up
     * @return the index of the term, or null if the term is not in the vocabulary
     */
    public Integer getIndexOrNull(String term) {
        return vocab.get(term);
    }

    /**
     * Returns the term stored at the given index.
     *
     * @param index
     *            the index of the term
     * @return the term at that index
     */
    public String getTerm(int index) {
        return vocabVects[index];
    }

    /**
     * Returns the internal word -> index map (not a copy).
     *
     * @return the vocabulary map
     */
    public Map<String, Integer> getVocabulary() {
        return vocab;
    }

    /**
     * Tells whether a term is in the vocabulary.
     *
     * @param term
     *            the term to check
     * @return true if the vocabulary map contains the term
     */
    public boolean hasTerm(String term) {
        return vocab.containsKey(term);
    }

    /**
     * Returns the size of each vector.
     *
     * @return the number of components per vector
     */
    public int vectorSize() {
        return size;
    }

    /**
     * Returns the number of vectors stored.
     *
     * @return the length of the vectors array
     */
    public int wordCount() {
        return vectors.length;
    }
}