/** * Copyright (c) 2014, the LESK-WSD-DSM AUTHORS. * * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * Neither the name of the University of Bari nor the names of its contributors * may be used to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. 
* * GNU GENERAL PUBLIC LICENSE - Version 3, 29 June 2007 * */
package di.uniba.it.wsd.dsm;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.IndexInput;

/**
 * Reads word vectors from a Lucene-format vector file and keeps them in an
 * in-memory map.
 *
 * <p>The file is read as a header string followed by a sequence of entries,
 * each made of a term string and {@code ObjectVector.vecLength} floats stored
 * as int bits. The vector length is taken from the header: either as an
 * {@code int} following a header equal to {@code -dimensions}, or as the
 * number following {@code -dimension} inside the header string. It is written
 * to the static {@code ObjectVector.vecLength}.
 *
 * <p>Besides the in-memory lookups, some methods scan the file directly. They
 * all move the read position of a single shared {@link IndexInput}, so an
 * instance must not be used concurrently from several threads.
 *
 * @author pierpaolo
 */
public class LuceneVectorStore implements VectorStore {

    /** Header value after which the dimension is stored as a separate int. */
    private static final String DIMENSIONS_HEADER = "-dimensions";

    /** Flag that precedes the dimension inside a textual header. */
    private static final String DIMENSION_FLAG = "-dimension";

    /** Size in bytes of one vector component (a float stored as int bits). */
    private static final int BYTES_PER_COMPONENT = 4;

    private static final Logger LOG = Logger.getLogger(LuceneVectorStore.class.getName());

    private FSDirectory directory;

    private IndexInput indexInput;

    private final Map<String, float[]> memory = new HashMap<>();

    /**
     * Opens the given vector file and loads all of its vectors into memory.
     * Any input opened by a previous call is closed first.
     *
     * @param file the vector file; its parent directory is opened as an
     * {@link FSDirectory}
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if the header's dimension value is not an
     * integer
     */
    public void init(File file) throws IOException {
        close();
        directory = FSDirectory.open(file.getParentFile());
        try {
            indexInput = directory.openInput(file.getName());
        } catch (IOException e) {
            // do not leave the directory open when the file itself cannot be opened
            close();
            throw e;
        }
        loadInRam();
    }

    /**
     * Closes the underlying file input and directory. Vectors already loaded
     * in memory stay available through {@link #getVector(String)} and the
     * in-memory {@code findSimilar} methods. The file-scanning methods will
     * fail until {@link #init(File)} is called again.
     *
     * @throws IOException if closing the input or the directory fails
     */
    public void close() throws IOException {
        IndexInput input = indexInput;
        FSDirectory dir = directory;
        indexInput = null;
        directory = null;
        try {
            if (input != null) {
                input.close();
            }
        } finally {
            if (dir != null) {
                dir.close();
            }
        }
    }

    /**
     * Replaces the in-memory map with every (term, vector) entry of the file.
     *
     * @throws IOException if reading the file fails
     */
    private void loadInRam() throws IOException {
        memory.clear();
        seekPastHeader();
        while (hasMoreEntries()) {
            String term = indexInput.readString();
            memory.put(term, readVector());
        }
        LOG.log(Level.INFO, "Loaded {0} vectors", memory.size());
    }

    /**
     * Rewinds the input to the start of the file and consumes the header,
     * updating {@code ObjectVector.vecLength} from it. If the header carries
     * no dimension information, {@code ObjectVector.vecLength} is left
     * unchanged.
     *
     * @throws IOException if reading the header fails
     */
    private void seekPastHeader() throws IOException {
        indexInput.seek(0);
        String header = indexInput.readString();
        if (header.equalsIgnoreCase(DIMENSIONS_HEADER)) {
            ObjectVector.vecLength = indexInput.readInt();
        } else {
            int index = header.indexOf(DIMENSION_FLAG);
            if (index >= 0) {
                // Use only the token right after the flag, so a header with further
                // options after the dimension value still parses.
                String value = header.substring(index + DIMENSION_FLAG.length())
                        .trim()
                        .split("\\s+", 2)[0];
                ObjectVector.vecLength = Integer.parseInt(value);
            }
        }
    }

    /** Returns true while the read position has not reached the end of the file. */
    private boolean hasMoreEntries() throws IOException {
        return indexInput.getFilePointer() < indexInput.length();
    }

    /** Reads {@code ObjectVector.vecLength} floats (stored as int bits) at the current position. */
    private float[] readVector() throws IOException {
        float[] v = new float[ObjectVector.vecLength];
        for (int k = 0; k < v.length; k++) {
            v[k] = Float.intBitsToFloat(indexInput.readInt());
        }
        return v;
    }

    /** Moves the read position past one vector without decoding it. */
    private void skipVector() throws IOException {
        indexInput.seek(indexInput.getFilePointer() + (long) ObjectVector.vecLength * BYTES_PER_COMPONENT);
    }

    /**
     * Scans the file and returns the vectors of the requested terms only.
     *
     * @param set the terms to fetch
     * @return a new map from each requested term found in the file to its
     * vector
     * @throws IOException if reading the file fails
     */
    public Map<String, float[]> prefetch(Set<String> set) throws IOException {
        LOG.log(Level.INFO, "Prefetching for {0} vectors", set.size());
        seekPastHeader();
        Map<String, float[]> map = new HashMap<>();
        while (hasMoreEntries()) {
            String term = indexInput.readString();
            if (set.contains(term)) {
                map.put(term, readVector());
            } else {
                skipVector();
            }
        }
        LOG.log(Level.INFO, "Prefetched {0} vectors", map.size());
        return map;
    }

    /**
     * Scans the file for the vector of a single term.
     *
     * @param term the term to look up
     * @return the vector of the first entry whose term equals {@code term}
     * @throws IOException if reading fails or the term is not in the file
     */
    public float[] getFileVector(String term) throws IOException {
        seekPastHeader();
        while (hasMoreEntries()) {
            String key = indexInput.readString();
            if (key.equals(term)) {
                return readVector();
            }
            skipVector();
        }
        throw new IOException("Vector for " + term + " not found");
    }

    /**
     * Scans the file and ranks every entry by its scalar product with
     * {@code vector}.
     *
     * @param vector the query vector
     * @param n the number of best-scoring candidates to retain
     * @return the retained results (see {@code TopScores#toSortedList} for the
     * exact count), sorted by {@code SpaceResult}'s natural order
     * @throws IOException if reading the file fails
     */
    public List<SpaceResult> findFileSimilar(float[] vector, int n) throws IOException {
        seekPastHeader();
        TopScores top = new TopScores(n);
        while (hasMoreEntries()) {
            String key = indexInput.readString();
            top.add(key, VectorUtils.scalarProduct(vector, readVector()));
        }
        return top.toSortedList();
    }

    /**
     * Returns the in-memory vector of {@code term}. The returned array is the
     * stored instance, not a copy.
     *
     * @param term the term to look up
     * @return the vector, or {@code null} if the term was not loaded
     */
    @Override
    public float[] getVector(String term) {
        return memory.get(term);
    }

    /**
     * Ranks every in-memory vector by its scalar product with the vector of
     * {@code word}.
     *
     * @param word the query term
     * @param n the number of best-scoring candidates to retain
     * @return the retained results sorted by {@code SpaceResult}'s natural
     * order, or an empty list when {@code word} has no vector
     */
    @Override
    public List<SpaceResult> findSimilar(String word, int n) {
        return findSimilar(memory, word, n);
    }

    /**
     * Ranks every in-memory vector by its scalar product with {@code vector}.
     *
     * @param vector the query vector
     * @param n the number of best-scoring candidates to retain
     * @return the retained results sorted by {@code SpaceResult}'s natural
     * order
     */
    public List<SpaceResult> findSimilar(float[] vector, int n) {
        return findSimilar(memory, vector, n);
    }

    /**
     * Ranks every vector of {@code map} by its scalar product with the vector
     * of {@code word} in the same map.
     *
     * @param map the candidate vectors
     * @param word the query term, looked up in {@code map}
     * @param n the number of best-scoring candidates to retain
     * @return the retained results sorted by {@code SpaceResult}'s natural
     * order, or an empty list when {@code word} is not in {@code map}
     */
    public List<SpaceResult> findSimilar(Map<String, float[]> map, String word, int n) {
        float[] v1 = map.get(word);
        if (v1 == null) {
            LOG.log(Level.WARNING, "No vector for term: {0}", word);
            return new ArrayList<>();
        }
        return findSimilar(map, v1, n);
    }

    /**
     * Ranks every vector of {@code map} by its scalar product with
     * {@code vector}.
     *
     * @param map the candidate vectors
     * @param vector the query vector
     * @param n the number of best-scoring candidates to retain
     * @return the retained results sorted by {@code SpaceResult}'s natural
     * order
     */
    public List<SpaceResult> findSimilar(Map<String, float[]> map, float[] vector, int n) {
        TopScores top = new TopScores(n);
        for (Map.Entry<String, float[]> entry : map.entrySet()) {
            top.add(entry.getKey(), VectorUtils.scalarProduct(vector, entry.getValue()));
        }
        return top.toSortedList();
    }

    /**
     * A (term, score) pair ordered by ascending score. Ranking uses this
     * instead of {@code SpaceResult} so the selection does not depend on how
     * {@code SpaceResult} orders itself.
     */
    private static final class Candidate implements Comparable<Candidate> {

        private final String term;

        private final float score;

        Candidate(String term, float score) {
            this.term = term;
            this.score = score;
        }

        @Override
        public int compareTo(Candidate other) {
            return Float.compare(score, other.score);
        }
    }

    /**
     * Bounded min-heap that keeps the {@code n} highest-scoring candidates
     * seen so far.
     */
    private static final class TopScores {

        private final int n;

        private final PriorityQueue<Candidate> heap = new PriorityQueue<>();

        TopScores(int n) {
            this.n = n;
        }

        /** Offers a candidate, evicting the weakest one only when over capacity. */
        void add(String term, float score) {
            heap.offer(new Candidate(term, score));
            // Evicting after the offer means a low-scoring newcomer is the one
            // removed, instead of displacing a better candidate already retained.
            if (heap.size() > n) {
                heap.poll();
            }
        }

        /**
         * Converts the retained candidates into {@code SpaceResult}s sorted by
         * their natural order.
         */
        List<SpaceResult> toSortedList() {
            // NOTE(review): the original implementation always dropped the weakest
            // retained candidate here, so at most n-1 results are returned. This is
            // kept for compatibility with existing callers; confirm it is intended.
            heap.poll();
            List<SpaceResult> list = new ArrayList<>(heap.size());
            for (Candidate candidate : heap) {
                list.add(new SpaceResult(candidate.term, candidate.score));
            }
            Collections.sort(list);
            return list;
        }
    }
}