package org.deeplearning4j.graph.models.embeddings; import lombok.AllArgsConstructor; import lombok.NoArgsConstructor; import org.deeplearning4j.berkeley.Pair; import org.deeplearning4j.graph.api.IGraph; import org.deeplearning4j.graph.api.Vertex; import org.deeplearning4j.graph.models.GraphVectors; import org.nd4j.linalg.api.blas.Level1; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Comparator; import java.util.PriorityQueue; /** Base implementation for GraphVectors. Used in DeepWalk, and also when loading * graph vectors from file. */ @AllArgsConstructor @NoArgsConstructor public class GraphVectorsImpl<V, E> implements GraphVectors<V, E> { protected IGraph<V, E> graph; protected GraphVectorLookupTable lookupTable; @Override public IGraph<V, E> getGraph() { return graph; } @Override public int numVertices() { return lookupTable.getNumVertices(); } @Override public int getVectorSize() { return lookupTable.vectorSize(); } @Override public INDArray getVertexVector(Vertex<V> vertex) { return lookupTable.getVector(vertex.vertexID()); } @Override public INDArray getVertexVector(int vertexIdx) { return lookupTable.getVector(vertexIdx); } @Override public int[] verticesNearest(int vertexIdx, int top) { INDArray vec = lookupTable.getVector(vertexIdx).dup(); double norm2 = vec.norm2Number().doubleValue(); PriorityQueue<Pair<Double, Integer>> pq = new PriorityQueue<>(lookupTable.getNumVertices(), new PairComparator()); Level1 l1 = Nd4j.getBlasWrapper().level1(); for (int i = 0; i < numVertices(); i++) { if (i == vertexIdx) continue; INDArray other = lookupTable.getVector(i); double cosineSim = l1.dot(vec.length(), 1.0, vec, other) / (norm2 * other.norm2Number().doubleValue()); pq.add(new Pair<>(cosineSim, i)); } int[] out = new int[top]; for (int i = 0; i < top; i++) { out[i] = pq.remove().getSecond(); } return out; } private static class PairComparator implements Comparator<Pair<Double, Integer>> { @Override public int compare(Pair<Double, Integer> o1, Pair<Double, Integer> o2) { return -Double.compare(o1.getFirst(), o2.getFirst()); } } /**Returns the cosine similarity of the vector representations of two vertices in the graph * @return Cosine similarity of two vertices */ @Override public double similarity(Vertex<V> vertex1, Vertex<V> vertex2) { return similarity(vertex1.vertexID(), vertex2.vertexID()); } /**Returns the cosine similarity of the vector representations of two vertices in the graph, * given the indices of these verticies * @return Cosine similarity of two vertices */ @Override public double similarity(int vertexIdx1, int vertexIdx2) { if (vertexIdx1 == vertexIdx2) return 1.0; INDArray vector = Transforms.unitVec(getVertexVector(vertexIdx1)); INDArray vector2 = Transforms.unitVec(getVertexVector(vertexIdx2)); return Nd4j.getBlasWrapper().dot(vector, vector2); } }