/*
 * Copyright 2014 Radialpoint SafeCare Inc. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.radialpoint.word2vec;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.Map;

/**
 * Stores the mapping of String -&gt; array of float that constitutes each vector.
 *
 * <p>
 * The class can serialize itself to, and be rebuilt from, a stream. The stream layout is the one
 * produced by {@link #writeTo(OutputStream)}: an {@code int} word count, an {@code int} vector
 * size, and then, for every word, the word itself (as written by
 * {@link DataOutputStream#writeUTF(String)}) followed by its float components.
 *
 * <p>
 * The ConvertVectors program transforms the C binary vectors into instances of this class.
 */
public class Vectors {

    /**
     * The vectors themselves, one row per word.
     */
    protected float[][] vectors;

    /**
     * The words associated with the vectors; entry {@code i} names row {@code i} of
     * {@link #vectors}.
     */
    protected String[] vocabVects;

    /**
     * Size of each vector.
     */
    protected int size;

    /**
     * Inverse map, word -&gt; index.
     */
    protected Map<String, Integer> vocab;

    /**
     * Package-level constructor, used by the ConvertVectors program.
     *
     * <p>
     * The arrays are stored as given (they are not copied). The vector size is taken from the
     * first row.
     *
     * @param vectors
     *            the vectors, it cannot be empty
     * @param vocabVects
     *            the words, the length should match vectors
     * @throws VectorsException
     *             if {@code vectors} is empty or its length differs from that of
     *             {@code vocabVects}
     */
    Vectors(float[][] vectors, String[] vocabVects) throws VectorsException {
        // Validate before touching vectors[0] so an empty input is reported as a
        // VectorsException instead of an ArrayIndexOutOfBoundsException.
        if (vectors.length == 0) {
            throw new VectorsException("Vectors cannot be empty");
        }
        if (vectors.length != vocabVects.length) {
            throw new VectorsException("Vectors and vocabulary size mismatch");
        }
        this.vectors = vectors;
        this.size = vectors[0].length;
        this.vocabVects = vocabVects;
        this.vocab = buildIndex(vocabVects);
    }

    /**
     * Initialize a Vectors instance from an open input stream. This method closes the stream,
     * whether or not reading succeeds.
     *
     * @param is
     *            the open stream
     * @throws IOException
     *             if there are problems reading from the stream, or if the header holds a
     *             negative word count or vector size
     */
    public Vectors(InputStream is) throws IOException {
        DataInputStream dis = new DataInputStream(is);
        try {
            int words = dis.readInt();
            int size = dis.readInt();
            // A negative value can only come from a corrupt or foreign stream; report it as an
            // I/O problem rather than failing with NegativeArraySizeException.
            if (words < 0 || size < 0) {
                throw new IOException("Invalid vectors header: words=" + words + ", size=" + size);
            }

            float[][] loadedVectors = new float[words][];
            String[] loadedWords = new String[words];
            for (int i = 0; i < words; i++) {
                loadedWords[i] = dis.readUTF();
                float[] vector = new float[size];
                for (int j = 0; j < size; j++) {
                    vector[j] = dis.readFloat();
                }
                loadedVectors[i] = vector;
            }

            this.size = size;
            this.vectors = loadedVectors;
            this.vocabVects = loadedWords;
            this.vocab = buildIndex(loadedWords);
        } finally {
            // Honour the documented contract even when a read fails part-way through.
            dis.close();
        }
    }

    /**
     * Builds the word -&gt; index map for the given words. If a word occurs more than once, the
     * highest index wins.
     */
    private static Map<String, Integer> buildIndex(String[] words) {
        Map<String, Integer> index = new HashMap<String, Integer>();
        for (int i = 0; i < words.length; i++) {
            index.put(words[i], i);
        }
        return index;
    }

    /**
     * Writes this vector to an open output stream. This method closes the stream, whether or not
     * writing succeeds.
     *
     * @param os
     *            the stream to write to
     * @throws IOException
     *             if there are problems writing to the stream
     */
    public void writeTo(OutputStream os) throws IOException {
        DataOutputStream dos = new DataOutputStream(os);
        try {
            dos.writeInt(this.vectors.length);
            dos.writeInt(this.size);
            for (int i = 0; i < vectors.length; i++) {
                dos.writeUTF(this.vocabVects[i]);
                for (int j = 0; j < size; j++) {
                    dos.writeFloat(this.vectors[i][j]);
                }
            }
        } finally {
            // Honour the documented contract even when a write fails part-way through.
            dos.close();
        }
    }

    /**
     * Returns all the vectors.
     *
     * @return the internal array of vectors (not a copy)
     */
    public float[][] getVectors() {
        return vectors;
    }

    /**
     * Returns the vector stored at the given index.
     *
     * @param i
     *            the index of the vector
     * @return the internal vector at that index (not a copy)
     */
    public float[] getVector(int i) {
        return vectors[i];
    }

    /**
     * Returns the vector associated with a term.
     *
     * @param term
     *            the term to look up
     * @return the internal vector for the term (not a copy)
     * @throws OutOfVocabularyException
     *             if the term is not in the vocabulary
     */
    public float[] getVector(String term) throws OutOfVocabularyException {
        return vectors[requireIndex(term)];
    }

    /**
     * Returns the index associated with a term.
     *
     * @param term
     *            the term to look up
     * @return the index of the term
     * @throws OutOfVocabularyException
     *             if the term is not in the vocabulary
     */
    public int getIndex(String term) throws OutOfVocabularyException {
        return requireIndex(term);
    }

    /**
     * Looks a term up in the vocabulary, failing if it is absent.
     */
    private int requireIndex(String term) throws OutOfVocabularyException {
        Integer idx = vocab.get(term);
        if (idx == null) {
            throw new OutOfVocabularyException("Unknown term '" + term + "'");
        }
        return idx;
    }

    /**
     * Returns the index associated with a term, without throwing for unknown terms.
     *
     * @param term
     *            the term to look up
     * @return the index of the term, or {@code null} if the term is not in the vocabulary
     */
    public Integer getIndexOrNull(String term) {
        return vocab.get(term);
    }

    /**
     * Returns the term stored at the given index.
     *
     * @param index
     *            the index of the term
     * @return the term at that index
     */
    public String getTerm(int index) {
        return vocabVects[index];
    }

    /**
     * Returns the word -&gt; index map.
     *
     * @return the internal vocabulary map (not a copy)
     */
    public Map<String, Integer> getVocabulary() {
        return vocab;
    }

    /**
     * Tells whether a term is in the vocabulary.
     *
     * @param term
     *            the term to look up
     * @return {@code true} if the vocabulary contains the term
     */
    public boolean hasTerm(String term) {
        return vocab.containsKey(term);
    }

    /**
     * Returns the size of each vector.
     *
     * @return the number of components per vector
     */
    public int vectorSize() {
        return size;
    }

    /**
     * Returns the number of words.
     *
     * @return the number of vectors held by this instance
     */
    public int wordCount() {
        return vectors.length;
    }
}