package org.wikibrain.sr.word2vec; import com.typesafe.config.Config; import gnu.trove.list.TByteList; import gnu.trove.list.TCharList; import gnu.trove.list.array.TByteArrayList; import gnu.trove.list.array.TCharArrayList; import gnu.trove.map.TLongIntMap; import gnu.trove.map.hash.TLongIntHashMap; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.dao.LocalPageDao; import org.wikibrain.core.lang.Language; import org.wikibrain.matrix.DenseMatrix; import org.wikibrain.matrix.DenseMatrixRow; import org.wikibrain.matrix.DenseMatrixWriter; import org.wikibrain.matrix.ValueConf; import org.wikibrain.sr.Explanation; import org.wikibrain.sr.SRResult; import org.wikibrain.sr.vector.DenseVectorGenerator; import org.wikibrain.utils.WpIOUtils; import java.io.*; import java.util.List; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Reads in a word2vec model in the "standard" file format. * * Builds a disk * * This code is adapted from https://github.com/ansjsun/Word2VEC_java * * @author Shilad Sen */ public class Word2VecGenerator implements DenseVectorGenerator { private static final Logger LOG = LoggerFactory.getLogger(Word2VecGenerator.class); private final Language language; private final LocalPageDao localPageDao; private final File path; private TLongIntMap phraseIds; private DenseMatrix phraseMatrix; private DenseMatrix articleMatrix; public Word2VecGenerator(Language language, LocalPageDao localPageDao, File path) throws IOException { this.language = language; this.localPageDao = localPageDao; this.path = path; this.read(); } public void read() throws IOException { if (getArticleMatrixPath().exists() && getPhraseMatrixPath().exists() && getPhraseIdPath().exists() && getPhraseMatrixPath().lastModified() >= path.lastModified() && getArticleMatrixPath().lastModified() >= path.lastModified()) { LOG.info("phrase and article caches are up to date, loading them..."); phraseMatrix = new DenseMatrix(getPhraseMatrixPath()); articleMatrix = new DenseMatrix(getArticleMatrixPath()); readPhraseIds(); } else { createWikiBrainModel(); } } private void readPhraseIds() throws IOException { BufferedReader reader = WpIOUtils.openBufferedReader(getPhraseIdPath()); try { phraseIds = new TLongIntHashMap(); while (true) { String line = reader.readLine(); if (line == null) { break; } String tokens[] = line.split("\t", 2); int wpId = Integer.parseInt(tokens[0]); String phrase = tokens[1].trim(); phraseIds.put(hashWord(phrase), wpId); } } finally { IOUtils.closeQuietly(reader); } } private void createWikiBrainModel() throws IOException { FileUtils.deleteQuietly(getPhraseIdPath()); FileUtils.deleteQuietly(getPhraseMatrixPath()); FileUtils.deleteQuietly(getArticleMatrixPath()); ValueConf vconf = new ValueConf(); BufferedWriter phraseIdWriter = WpIOUtils.openWriter(getPhraseIdPath()); DenseMatrixWriter phraseWriter = new DenseMatrixWriter(getPhraseMatrixPath(), vconf); DenseMatrixWriter articleWriter = new DenseMatrixWriter(getArticleMatrixPath(), vconf); DataInputStream dis = null; InputStream bis = null; try { bis = WpIOUtils.openInputStream(path); dis = new DataInputStream(bis); String header = ""; while (true) { char c = (char) dis.read(); if (c == '\n') break; header += c; } String tokens[] = header.split(" "); int numEntities = Integer.parseInt(tokens[0]); int vlength = Integer.parseInt(tokens[1]); LOG.info("preparing to read " + numEntities + " with length " + vlength + " vectors"); int [] colIds = new int[vlength]; for (int i = 0; i < vlength; i++) { colIds[i] = i; } int numPhrases = 0; int numArticles = 0; for (int i = 0; i < numEntities; i++) { String word = readString(dis); if (i % 5000 == 0) { LOG.info("Read word vector " + word + " (" + i + " of " + numEntities + ")"); } float[] vector = new float[vlength]; double norm2 = 0.0; for (int j = 0; j < vlength; j++) { float val = readFloat(dis); norm2 += val * val; vector[j] = val; } norm2 = Math.sqrt(norm2); for (int j = 0; j < vlength; j++) { vector[j] /= norm2; } if (word.startsWith("/w/")) { String[] pieces = word.split("/", 5); int wpId = Integer.valueOf(pieces[3]); if (wpId >= 0) { DenseMatrixRow row = new DenseMatrixRow(vconf, wpId, colIds, vector); articleWriter.writeRow(row); numArticles++; } } else { word = word.replace('\t', ' ').replace('\n', ' '); DenseMatrixRow row = new DenseMatrixRow(vconf, numPhrases, colIds, vector); phraseWriter.writeRow(row); phraseIdWriter.write(numPhrases + "\t" + word + "\n"); numPhrases++; } } if (numPhrases == 0) { phraseWriter.writeRow(new DenseMatrixRow(vconf, 0, colIds, new float[vlength])); } if (numArticles == 0) { articleWriter.writeRow(new DenseMatrixRow(vconf, 0, colIds, new float[vlength])); } } finally { IOUtils.closeQuietly(bis); IOUtils.closeQuietly(dis); } IOUtils.closeQuietly(phraseIdWriter); phraseWriter.finish(); articleWriter.finish(); phraseMatrix = new DenseMatrix(getPhraseMatrixPath()); articleMatrix = new DenseMatrix(getArticleMatrixPath()); readPhraseIds(); } private File getPhraseMatrixPath() { return new File(path.getAbsolutePath() + ".phrases.matrix"); } private File getArticleMatrixPath() { return new File(path.getAbsolutePath() + ".articles.matrix"); } private File getPhraseIdPath() { return new File(path.getAbsolutePath() + ".phrases.txt"); } private static String readString(DataInputStream dis) throws IOException { TByteList bytes = new TByteArrayList(); while (true) { int i = dis.read(); if (i == -1) { break; } if (i < 0 || i > 255) { throw new IllegalStateException(); } char c = (char)i; if (c == ' ') { break; } if (c != '\n') { bytes.add((byte)i); } } return new String(bytes.toArray(), "UTF-8"); } private static float readFloat(InputStream is) throws IOException { byte[] bytes = new byte[4]; is.read(bytes); return getFloat(bytes); } private static float getFloat(byte[] b) { int accum = 0; accum = accum | (b[0] & 0xff) << 0; accum = accum | (b[1] & 0xff) << 8; accum = accum | (b[2] & 0xff) << 16; accum = accum | (b[3] & 0xff) << 24; return Float.intBitsToFloat(accum); } @Override public DenseMatrix getFeatureMatrix() { return articleMatrix; } @Override public float [] getVector(int pageId) throws DaoException { try { DenseMatrixRow row = articleMatrix.getRow(pageId); return row == null ? null : row.getValues(); } catch (IOException e) { throw new DaoException(e); } } @Override public float [] getVector(String phrase) { try { long hash = hashWord(phrase); if (phraseIds.containsKey(hash)) { int phraseId = phraseIds.get(hash); DenseMatrixRow row = phraseMatrix.getRow(phraseId); return row == null ? null : row.getValues(); } else { return null; } } catch (IOException e) { throw new RuntimeException(e); } } private static long hashWord(String word) { return Word2VecUtils.hashWord(normalize(word)); } @Override public List<Explanation> getExplanations(String phrase1, String phrase2, float [] vector1, float [] vector2, SRResult result) throws DaoException { throw new UnsupportedOperationException(); } @Override public List<Explanation> getExplanations(int pageID1, int pageID2, float [] vector1, float [] vector2, SRResult result) throws DaoException { return null; } private static String normalize(String s) { return s.replace('_', ' ').trim(); } public static class Provider extends org.wikibrain.conf.Provider<DenseVectorGenerator> { public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator, config); } @Override public Class<DenseVectorGenerator> getType() { return DenseVectorGenerator.class; } @Override public String getPath() { return "sr.metric.densegenerator"; } @Override public DenseVectorGenerator get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException { if (!config.getString("type").equals("word2vec")) { return null; } if (!runtimeParams.containsKey("language")) { throw new IllegalArgumentException("Monolingual SR Metric requires 'language' runtime parameter"); } Language language = Language.getByLangCode(runtimeParams.get("language")); File path = getModelFile(config.getString("modelDir"), language); if (!path.isFile()) { throw new ConfigurationException("Path to word2vec model " + path.getAbsolutePath() + " is not a file. Do you need to download or build the model?"); } try { return new Word2VecGenerator( language, getConfigurator().get(LocalPageDao.class), path ); } catch (IOException e) { throw new ConfigurationException(e); } } } public static File getModelFile(String dir, Language lang) { return getModelFile(new File(dir), lang); } public static File getModelFile(File dir, Language lang) { return new File(dir, lang.getLangCode() + ".bin"); } public static void main(String args[]) throws IOException { Word2VecGenerator gen = new Word2VecGenerator(null, null, new File("/Users/a558989/Projects/wikibrain/base-bh/dat/word2vecRaw/bh.bin")); } }