/*
 * Copyright 2014 Radialpoint SafeCare Inc. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.radialpoint.word2vec;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

/**
 * Finds the vocabulary terms whose vectors score highest against a query
 * built from one or more tokens.
 */
public class Distance {

    /** A vocabulary term paired with the score it obtained against the query. */
    public static class ScoredTerm {
        private final String term;
        private final float score;

        /**
         * @param term  the vocabulary term
         * @param score the score of {@code term} against the query
         */
        public ScoredTerm(String term, float score) {
            this.term = term;
            this.score = score;
        }

        /** @return the vocabulary term */
        public String getTerm() {
            return term;
        }

        /** @return the score of the term against the query */
        public float getScore() {
            return score;
        }
    }

    /**
     * Returns the {@code wordsToReturn} best-scoring vocabulary terms for the
     * given query tokens, best first.
     *
     * <p>The query vector is the element-wise sum of the vectors of all
     * {@code tokens}, scaled to unit length. Every vocabulary entry that is
     * not itself one of the query tokens is scored by the dot product of its
     * stored vector with that query vector (this is the cosine similarity
     * only if the stored vectors are unit length, which cannot be confirmed
     * from this class).
     *
     * <p>The returned list always has exactly {@code wordsToReturn} entries.
     * When fewer candidates are available, the trailing entries are
     * placeholders with an empty term and a score of
     * {@link Float#NEGATIVE_INFINITY}. If the summed query vector has zero
     * length (for example when {@code tokens} is empty), normalisation yields
     * NaN components, no candidate can be ranked, and only placeholders are
     * returned.
     *
     * @param vectors       the vocabulary and its vectors
     * @param wordsToReturn number of entries in the returned list
     * @param tokens        the query terms; each must be in the vocabulary
     * @return the best-scoring terms in descending score order
     * @throws OutOfVocabularyException if a token is not in the vocabulary;
     *         reported for the first such token in {@code tokens}
     */
    public static List<ScoredTerm> measure(Vectors vectors, int wordsToReturn, String[] tokens)
            throws OutOfVocabularyException {
        int size = vectors.vectorSize();
        float[][] allVec = vectors.getVectors();

        // Sum the vectors of all query tokens and remember their indices so
        // the query terms themselves are excluded from the results.
        float[] query = new float[size];
        Set<Integer> queryIdx = new TreeSet<Integer>();
        for (String token : tokens) {
            Integer idx = vectors.getIndexOrNull(token);
            if (idx == null) {
                throw new OutOfVocabularyException(token);
            }
            queryIdx.add(idx);
            float[] tokenVec = allVec[idx];
            for (int j = 0; j < size; j++) {
                query[j] += tokenVec[j];
            }
        }
        normalize(query);

        // Float.MIN_VALUE is the smallest POSITIVE float, so using it as the
        // initial "worst" score would reject every candidate whose score is
        // zero or negative. Negative infinity is the true lower bound.
        float[] bestDistance = new float[wordsToReturn];
        String[] bestWords = new String[wordsToReturn];
        Arrays.fill(bestDistance, Float.NEGATIVE_INFINITY);
        Arrays.fill(bestWords, "");

        int wordCount = vectors.wordCount();
        for (int c = 0; c < wordCount; c++) {
            if (queryIdx.contains(c)) {
                continue;
            }
            float[] candidate = allVec[c];
            double distance = 0;
            for (int i = 0; i < size; i++) {
                distance += query[i] * candidate[i];
            }

            // Insert into the descending top-N arrays: find the first slot
            // this score beats, shift the tail down by one (dropping the
            // last entry), and store the new entry there.
            for (int i = 0; i < wordsToReturn; i++) {
                if (distance > bestDistance[i]) {
                    int tail = wordsToReturn - 1 - i;
                    System.arraycopy(bestDistance, i, bestDistance, i + 1, tail);
                    System.arraycopy(bestWords, i, bestWords, i + 1, tail);
                    bestDistance[i] = (float) distance;
                    bestWords[i] = vectors.getTerm(c);
                    break;
                }
            }
        }

        List<ScoredTerm> result = new ArrayList<ScoredTerm>(wordsToReturn);
        for (int i = 0; i < wordsToReturn; i++) {
            result.add(new ScoredTerm(bestWords[i], bestDistance[i]));
        }
        return result;
    }

    /**
     * Scales {@code vec} in place by the reciprocal of its Euclidean length.
     * A zero-length vector ends up with NaN components (0 / 0).
     */
    private static void normalize(float[] vec) {
        double length = 0;
        for (int i = 0; i < vec.length; i++) {
            length += vec[i] * vec[i];
        }
        // The norm is deliberately rounded to float precision before use so
        // that scores stay numerically identical to earlier versions.
        length = (float) Math.sqrt(length);
        for (int i = 0; i < vec.length; i++) {
            vec[i] /= length;
        }
    }
}