package com.cyc.tool.distributedrepresentations; /* * #%L * DistributedRepresentations * %% * Copyright (C) 2015 Cycorp, Inc * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.ConcurrentNavigableMap; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.mapdb.DB; /** * A space of words from Google Word2Vec * */ public abstract class Word2VecSpace { private int size; DB db; Map<String, float[]> vectors; long words; /** * * @param terms * @return a List of Strings containing nGrams for terms */ public static List<String> nGramsFor(List<String> terms) { final List<String> grams = new ArrayList<String>(); IntStream.rangeClosed(1, terms.size()).forEach(length -> { IntStream.rangeClosed(0, terms.size() - length).forEach(start -> { List<String> l = terms.subList(start, start + length); grams.add(String.join(" ", l)); }); }); return grams; } private static String norm(String term) { return term.replaceAll("\\s+", "_"); } private double cosineSimilarity(float[] v1, float[] v2) { return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2)); } /** * * @param t1 * @param t2 * @return the cosine similarity */ public double cosineSimilarity(String t1, String t2) { return cosineSimilarity(getVector(t1), getVector(t2)); } private double dotProduct(float[] v1, float[] v2) { return IntStream.range(0, v1.length) .mapToDouble(i -> (double) v1[i] * (double) v2[i]) .sum(); } private double euclidianDistance(float[] v1, float[] v2) { double dist = Math.sqrt(IntStream.range(0, v1.length) .mapToDouble(i -> Math.pow((double) v1[i] - (double) v2[i], 2)) .sum()); return dist; } private double euclidianDistance(String t1, String t2) { return euclidianDistance(getVector(t1), getVector(t2)); } private float[] getAverageVector(List<String> terms) { final float sum[] = new float[size]; final double mult = 1.0 / terms.size(); terms.forEach(s -> { float v[] = getVector(s); IntStream.range(0, size) .forEach(i -> { sum[i] += mult * v[i]; }); }); return sum; } /** * * @return the db */ public DB getDb() { return db; } /** * Set up the DB. * * @param db */ public void setDb(DB db) { this.db = db; } /** * * @param terms * @return the sum of term vectors divided by vector length * @throws NoWordToVecVectorForTerm */ public float[] getGoogleNormedVector(List<String> terms) throws NoWordToVecVectorForTerm { // Sum of term vectors divided by vector length // Note that this will miss multi-word exact matches, so prefer getMaximalNormedVector //except for exact code comparison tests final float sum[] = new float[size]; if (terms.stream().allMatch(s -> !knownTerm(s))) { throw new NoWordToVecVectorForTerm("Can't find vector for:" + String.join(", ", terms)); } terms.stream() .filter(s -> knownTerm(s)) .forEach(s -> { float v[] = getVector(s); IntStream.range(0, size) .forEach(i -> { sum[i] += v[i]; }); }); return normVector(sum); } /** * * @param interms * @return the maximal normed vector * @throws NoWordToVecVectorForTerm */ public float[] getMaximalNormedVector(List<String> interms) throws NoWordToVecVectorForTerm { // Sum of term ngram vectors divided by vector length List<String> terms = nGramsFor(interms); final float sum[] = new float[size]; if (terms.stream().allMatch(s -> !knownTerm(s))) { throw new NoWordToVecVectorForTerm("Can't find vector for:" + String.join(", ", terms)); } terms.stream() .filter(s -> knownTerm(s)) .forEach(s -> { float v[] = getVector(s); IntStream.range(0, size) .forEach(i -> { sum[i] += v[i]; }); }); return normVector(sum); } /** * * @return size of vectors */ public int getNVectors() { return vectors.size(); } /** * * @return size of the Word2VecSpace */ public int getSize() { return size; } /** * * @param size */ public void setSize(int size) { this.size = size; } /** * * @param term * @return the vector for term */ public float[] getVector(String term) { return vectors.get(norm(term)); } /** * * @return the vectors */ public Map<String, float[]> getVectors() { return vectors; } /** * * @param vectors */ public void setVectors(ConcurrentNavigableMap<String, float[]> vectors) { this.vectors = vectors; } /** * * @return the words */ public long getWords() { return words; } /** * * @param words */ public void setWords(long words) { this.words = words; } /** * * @param v1 * @param v2 * @return the similarity between v1 and v2 */ public double googleSimilarity(float[] v1, float[] v2) { return dotProduct(v1, v2); } private double googleSimilarity(String t1, String t2) { return googleSimilarity(getVector(t1), getVector(t2)); } /** * * @param terms * @param term * @return the similarity * @throws NoWordToVecVectorForTerm */ public double googleSimilarity(List<String> terms, String term) throws NoWordToVecVectorForTerm { return googleSimilarity(getGoogleNormedVector(terms), getVector(term)); } /** * * @param term * @return true if term is in vectors */ public boolean knownTerm(String term) { return vectors.containsKey(norm(term)); } private double magnitude(float[] v) { return Math.sqrt(IntStream.range(0, v.length).mapToDouble(i -> v[i] * v[i]).sum()); } private double magnitude(List<Float> v) { return Math.sqrt(v.stream().mapToDouble(i -> i * i).sum()); } /** * * @param v * @return normalized vector for v */ public float[] normVector(float[] v) { final float normed[] = new float[size]; double len = magnitude(v); IntStream.range(0, size) .forEach(i -> { normed[i] = v[i] / (float) len; }); return normed; } /** * * @param v * @return normalized vector for v */ public float[] normVector(List<Float> v) { final float normed[] = new float[v.size()]; double len = magnitude(v); IntStream.range(0, v.size()) .forEach(i -> { normed[i] = v.get(i) / (float) len; }); return normed; } /** * * @param s * @return List of Strings */ public List<String> stringToList(String s) { return Arrays.asList(s.split("\\s+")); } /** * * @param includeIf the predicate that is applied to the strings (the keys or embedded strings) * of the word to vec space to determine whether they should be retained in the output vector list * @return filtered vectors Map */ protected Map<String, float[]> filterVectors(Predicate<String> includeIf) { return vectors.entrySet().stream().filter(entry -> { return includeIf.test(entry.getKey()); }).collect(Collectors.toMap(Entry::getKey, Entry::getValue)); } /** * No Vector for Term * <p> * Exception to use check when a term looked up in the space has no known position */ public static class NoWordToVecVectorForTerm extends Exception { /** * * @param message */ public NoWordToVecVectorForTerm(String message) { super(message); } } }