package org.deeplearning4j.graph.models.deepwalk;

import lombok.AllArgsConstructor;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.IVertexSequence;
import org.deeplearning4j.graph.api.NoEdgeHandling;
import org.deeplearning4j.graph.iterator.GraphWalkIterator;
import org.deeplearning4j.graph.iterator.parallel.GraphWalkIteratorProvider;
import org.deeplearning4j.graph.iterator.parallel.RandomWalkGraphIteratorProvider;
import org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable;
import org.deeplearning4j.graph.models.embeddings.GraphVectorsImpl;
import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;

/** Implementation of the DeepWalk graph vectorization model, based on the paper
 * <i>DeepWalk: Online Learning of Social Representations</i> by Perozzi, Al-Rfou & Skiena (2014),
 * <a href="http://arxiv.org/abs/1403.6652">http://arxiv.org/abs/1403.6652</a><br>
 * Similar to word2vec in nature, DeepWalk is an unsupervised learning algorithm that learns a vector representation
 * of each vertex in a graph. Vector representations are learned using walks (usually random walks) on the vertices in
 * the graph.<br>
 * Once learned, these vector representations can then be used for purposes such as classification, clustering,
 * similarity search, etc. on the graph<br>
 * @author Alex Black
 */
public class DeepWalk<V, E> extends GraphVectorsImpl<V, E> {
    public static final int STATUS_UPDATE_FREQUENCY = 1000;
    private Logger log = LoggerFactory.getLogger(DeepWalk.class);

    private int vectorSize;
    private int windowSize;
    private double learningRate;
    private boolean initCalled = false;
    private long seed;
    private ExecutorService executorService;
    private int nThreads = Runtime.getRuntime().availableProcessors();
    private transient AtomicLong walkCounter = new AtomicLong(0);

    public DeepWalk() {}

    public int getVectorSize() {
        return vectorSize;
    }

    public int getWindowSize() {
        return windowSize;
    }

    public double getLearningRate() {
        return learningRate;
    }

    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
        if (lookupTable != null)
            lookupTable.setLearningRate(learningRate);
    }

    /** Initialize the DeepWalk model with a given graph. */
    public void initialize(IGraph<V, E> graph) {
        int nVertices = graph.numVertices();
        int[] degrees = new int[nVertices];
        for (int i = 0; i < nVertices; i++)
            degrees[i] = graph.getVertexDegree(i);
        initialize(degrees);
    }

    /** Initialize the DeepWalk model with a list of vertex degrees for a graph.<br>
     * Specifically, graphVertexDegrees[i] represents the vertex degree of the ith vertex.<br>
     * Vertex degrees are used to construct a binary (Huffman) tree, which is in turn used in
     * the hierarchical softmax implementation
     * @param graphVertexDegrees degrees of each vertex
     */
    public void initialize(int[] graphVertexDegrees) {
        log.info("Initializing: Creating Huffman tree and lookup table...");
        GraphHuffman gh = new GraphHuffman(graphVertexDegrees.length);
        gh.buildTree(graphVertexDegrees);
        lookupTable = new InMemoryGraphLookupTable(graphVertexDegrees.length, vectorSize, gh, learningRate);
        initCalled = true;
        log.info("Initialization complete");
    }
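    /* Illustrative usage sketch, showing the intended build -> initialize -> fit flow.
     * This is not part of the original source; "graph" is assumed to be some concrete
     * IGraph<String, String> implementation, and the hyperparameter values are arbitrary:
     *
     *   DeepWalk<String, String> deepWalk = new DeepWalk.Builder<String, String>()
     *           .vectorSize(50)       //dimensionality of the learned vertex vectors
     *           .windowSize(4)        //skip-gram window size
     *           .learningRate(0.01)
     *           .seed(12345)
     *           .build();
     *   deepWalk.initialize(graph);   //builds the Huffman tree and lookup table
     *   deepWalk.fit(graph, 10);      //random walks of length 10, fitted in parallel
     */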
    /** Fit the model, in parallel.
     * This creates a set of GraphWalkIterators, which are then distributed one to each thread
     * @param graph Graph to fit
     * @param walkLength Length of random walks to generate
     */
    public void fit(IGraph<V, E> graph, int walkLength) {
        if (!initCalled)
            initialize(graph);

        //First: create iterators, one for each thread
        GraphWalkIteratorProvider<V> iteratorProvider = new RandomWalkGraphIteratorProvider<>(graph, walkLength,
                        seed, NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED);

        fit(iteratorProvider);
    }

    /** Fit the model, in parallel, using a given GraphWalkIteratorProvider.<br>
     * This object is used to generate multiple GraphWalkIterators, which can then be distributed to each thread
     * to do in parallel<br>
     * Note that {@link #fit(IGraph, int)} will be more convenient in many cases<br>
     * Note that {@link #initialize(IGraph)} or {@link #initialize(int[])} <em>must</em> be called first.
     * @param iteratorProvider GraphWalkIteratorProvider
     * @see #fit(IGraph, int)
     */
    public void fit(GraphWalkIteratorProvider<V> iteratorProvider) {
        if (!initCalled)
            throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)");

        List<GraphWalkIterator<V>> iteratorList = iteratorProvider.getGraphWalkIterators(nThreads);

        executorService = Executors.newFixedThreadPool(nThreads, new ThreadFactory() {
            @Override
            public Thread newThread(Runnable r) {
                Thread t = new Thread(r);
                t.setDaemon(true);
                return t;
            }
        });

        List<Future<Void>> list = new ArrayList<>(iteratorList.size());
        //log.info("Fitting Graph with {} threads", Math.max(nThreads,iteratorList.size()));
        for (GraphWalkIterator<V> iter : iteratorList) {
            LearningCallable c = new LearningCallable(iter);
            list.add(executorService.submit(c));
        }

        executorService.shutdown();
        try {
            executorService.awaitTermination(999, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            throw new RuntimeException("ExecutorService interrupted", e);
        }

        //Don't need to block on futures to get a value out, but we want to re-throw any exceptions encountered
        for (Future<Void> f : list) {
            try {
                f.get();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }
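    /* Sketch of supplying a custom iterator provider directly (illustrative only;
     * "deepWalk" and "graph" are the hypothetical objects from the sketch above, the
     * walk length of 8 is arbitrary, and initialize(..) must already have been called).
     * The constructor arguments mirror those used internally by fit(IGraph, int):
     *
     *   GraphWalkIteratorProvider<String> provider = new RandomWalkGraphIteratorProvider<>(
     *           graph, 8, 12345L, NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED);
     *   deepWalk.fit(provider);   //one iterator per thread, up to nThreads
     */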
    /** Fit the DeepWalk model <b>using a single thread</b>, with a given GraphWalkIterator. If parallel fitting is
     * required, {@link #fit(IGraph, int)} or {@link #fit(GraphWalkIteratorProvider)} should be used.<br>
     * Note that {@link #initialize(IGraph)} or {@link #initialize(int[])} <em>must</em> be called first.
     *
     * @param iterator iterator for graph walks
     */
    public void fit(GraphWalkIterator<V> iterator) {
        if (!initCalled)
            throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)");
        int walkLength = iterator.walkLength();

        while (iterator.hasNext()) {
            IVertexSequence<V> sequence = iterator.next();

            //Skipgram model:
            int[] walk = new int[walkLength + 1];
            int i = 0;
            while (sequence.hasNext())
                walk[i++] = sequence.next().vertexID();

            skipGram(walk);

            long iter = walkCounter.incrementAndGet();
            if (iter % STATUS_UPDATE_FREQUENCY == 0) {
                log.info("Processed {} random walks on graph", iter);
            }
        }
    }

    private void skipGram(int[] walk) {
        for (int mid = windowSize; mid < walk.length - windowSize; mid++) {
            for (int pos = mid - windowSize; pos <= mid + windowSize; pos++) {
                if (pos == mid)
                    continue;

                //pair of vertices: walk[mid] -> walk[pos]
                lookupTable.iterate(walk[mid], walk[pos]);
            }
        }
    }

    public GraphVectorLookupTable lookupTable() {
        return lookupTable;
    }

    public static class Builder<V, E> {
        private int vectorSize = 100;
        private long seed = System.currentTimeMillis();
        private double learningRate = 0.01;
        private int windowSize = 2;

        /** Sets the size of the vectors to be learned for each vertex in the graph */
        public Builder<V, E> vectorSize(int vectorSize) {
            this.vectorSize = vectorSize;
            return this;
        }

        /** Set the learning rate */
        public Builder<V, E> learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        /** Sets the window size used in the skip-gram model */
        public Builder<V, E> windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        /** Seed for random number generation (used for repeatability).
         * Note however that parallel/asynchronous gradient descent might result in behaviour that
         * is not repeatable, in spite of setting the seed */
        public Builder<V, E> seed(long seed) {
            this.seed = seed;
            return this;
        }

        public DeepWalk<V, E> build() {
            DeepWalk<V, E> dw = new DeepWalk<>();
            dw.vectorSize = vectorSize;
            dw.windowSize = windowSize;
            dw.learningRate = learningRate;
            dw.seed = seed;
            return dw;
        }
    }

    @AllArgsConstructor
    private class LearningCallable implements Callable<Void> {
        private final GraphWalkIterator<V> iterator;

        @Override
        public Void call() throws Exception {
            fit(iterator);
            return null;
        }
    }
}
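/* Worked example of the skip-gram windowing in skipGram(..) above (illustrative note,
 * not part of the API): with windowSize = 2 and a walk of 5 vertices [v0, v1, v2, v3, v4],
 * the outer loop runs only for mid = 2, so the pairs trained via lookupTable.iterate(..)
 * are (v2,v0), (v2,v1), (v2,v3) and (v2,v4). Vertices within windowSize of either end of
 * the walk are never used as the center vertex under this implementation.
 */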