package org.deeplearning4j.graph.models.deepwalk;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.graph.api.Edge;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.data.GraphLoader;
import org.deeplearning4j.graph.graph.Graph;
import org.deeplearning4j.graph.iterator.GraphWalkIterator;
import org.deeplearning4j.graph.iterator.RandomWalkIterator;
import org.deeplearning4j.graph.iterator.parallel.GraphWalkIteratorProvider;
import org.deeplearning4j.graph.iterator.parallel.WeightedRandomWalkGraphIteratorProvider;
import org.deeplearning4j.graph.models.GraphVectors;
import org.deeplearning4j.graph.models.loader.GraphVectorSerializer;
import org.deeplearning4j.graph.vertexfactory.StringVertexFactory;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import static org.junit.Assert.*;
/**
 * Tests for {@link DeepWalk}: basic fitting on small graphs loaded from classpath resources,
 * parallel fitting on a randomly generated graph, nearest-vertex queries, and a
 * save/load round trip through {@link GraphVectorSerializer}.
 */
public class TestDeepWalk {

    /** Learning rate shared by every test that builds its model through {@link #newDeepWalk(int, int)}. */
    private static final double LEARNING_RATE = 0.01;

    /**
     * Builds a DeepWalk model with String vertices, the shared learning rate and the given sizes.
     *
     * @param vectorSize length of each vertex vector
     * @param windowSize window size passed to the builder
     * @param <E>        edge value type of the graph the model will be fitted on
     * @return a new DeepWalk instance on which {@code initialize} has not yet been called
     */
    private static <E> DeepWalk<String, E> newDeepWalk(int vectorSize, int windowSize) {
        return new DeepWalk.Builder<String, E>().learningRate(LEARNING_RATE).vectorSize(vectorSize)
                        .windowSize(windowSize).build();
    }

    /**
     * Asserts that every vertex vector has shape {@code [1, vectorSize]} and prints its values.
     *
     * @param deepWalk   model to query
     * @param nVertices  number of vertices to check (indices {@code 0..nVertices-1})
     * @param vectorSize expected vector length
     */
    private static void checkAndPrintVectors(DeepWalk<String, String> deepWalk, int nVertices, int vectorSize) {
        for (int i = 0; i < nVertices; i++) {
            INDArray vector = deepWalk.getVertexVector(i);
            assertArrayEquals(new int[] {1, vectorSize}, vector.shape());
            System.out.println(Arrays.toString(vector.dup().data().asFloat()));
        }
    }

    /**
     * Returns whether {@code values} contains {@code target}.
     *
     * @param values array to scan
     * @param target value to look for
     * @return {@code true} as soon as a matching element is found, {@code false} otherwise
     */
    private static boolean contains(int[] values, int target) {
        for (int value : values) {
            if (value == target) {
                return true;
            }
        }
        return false;
    }

    /**
     * Very basic test: load a graph, initialize, call fit repeatedly, and make sure nothing throws
     * and that the vertex vectors keep the expected shape.
     *
     * @throws IOException if the graph resource cannot be read
     */
    @Test
    public void testBasic() throws IOException {
        int nVertices = 7;
        ClassPathResource cpr = new ClassPathResource("testgraph_7vertices.txt");
        Graph<String, String> graph = GraphLoader
                        .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), nVertices, ",");

        int vectorSize = 5;
        int windowSize = 2;

        DeepWalk<String, String> deepWalk = newDeepWalk(vectorSize, windowSize);
        deepWalk.initialize(graph);

        checkAndPrintVectors(deepWalk, nVertices, vectorSize);

        GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8);

        deepWalk.fit(iter);

        // Re-fit several times on the same (reset) iterator; vector shapes must be unchanged
        for (int t = 0; t < 5; t++) {
            iter.reset();
            deepWalk.fit(iter);
            System.out.println("--------------------");
            checkAndPrintVectors(deepWalk, nVertices, vectorSize);
        }
    }

    /**
     * Fits on a 1000-vertex random graph via {@code fit(graph, walkLength)}; passes if no exception
     * is thrown.
     */
    @Test
    public void testParallel() {
        IGraph<String, String> graph = generateRandomGraph(1000, 10);

        int vectorSize = 20;
        int windowSize = 2;

        DeepWalk<String, String> deepWalk = newDeepWalk(vectorSize, windowSize);
        deepWalk.initialize(graph);

        deepWalk.fit(graph, 8);
    }

    /**
     * Generates a graph with {@code nVertices} vertices where each vertex gets
     * {@code nEdgesPerVertex} edges to targets drawn from a fixed-seed {@link Random},
     * so the result is the same on every call with the same arguments.
     *
     * @param nVertices       number of vertices in the graph
     * @param nEdgesPerVertex number of edges added from each vertex
     * @return the generated graph
     */
    private static Graph<String, String> generateRandomGraph(int nVertices, int nEdgesPerVertex) {
        Random r = new Random(12345);

        Graph<String, String> graph = new Graph<>(nVertices, new StringVertexFactory());
        for (int i = 0; i < nVertices; i++) {
            for (int j = 0; j < nEdgesPerVertex; j++) {
                int to = r.nextInt(nVertices);
                Edge<String> edge = new Edge<>(i, to, i + "--" + to, false);
                graph.addEdge(edge);
            }
        }
        return graph;
    }

    /**
     * Checks {@code verticesNearest}: the returned vertices must be ordered by non-increasing
     * similarity to the query vertex, and no vertex outside the returned set may be more similar
     * than the least similar returned vertex.
     */
    @Test
    public void testVerticesNearest() {
        int nVertices = 20;
        Graph<String, String> graph = generateRandomGraph(nVertices, 8);

        int vectorSize = 5;
        int windowSize = 2;

        DeepWalk<String, String> deepWalk = newDeepWalk(vectorSize, windowSize);
        deepWalk.initialize(graph);

        deepWalk.fit(graph, 10);

        int topN = 5;
        int nearestTo = 4;
        int[] nearest = deepWalk.verticesNearest(nearestTo, topN);
        double[] cosSim = new double[topN];
        double minSimNearest = 1;
        for (int i = 0; i < topN; i++) {
            cosSim[i] = deepWalk.similarity(nearest[i], nearestTo);
            minSimNearest = Math.min(minSimNearest, cosSim[i]);
            if (i > 0) {
                assertTrue("nearest vertices not sorted by similarity at position " + i + ": " + cosSim[i]
                                + " > " + cosSim[i - 1], cosSim[i] <= cosSim[i - 1]);
            }
        }

        for (int i = 0; i < nVertices; i++) {
            // Skip the query vertex itself and anything already in the top-N set
            if (i == nearestTo || contains(nearest, i)) {
                continue;
            }

            double sim = deepWalk.similarity(i, nearestTo);
            System.out.println(i + "\t" + nearestTo + "\t" + sim);
            assertTrue("vertex " + i + " (sim=" + sim + ") is closer than the top-" + topN + " minimum "
                            + minSimNearest, sim <= minSimNearest);
        }
    }

    /**
     * Writes a fitted model with {@link GraphVectorSerializer}, loads it back, and checks that the
     * vertex count, vector size and every vector component survive the round trip (relative
     * error below 1e-6).
     *
     * @throws IOException if the temporary file cannot be created, written or read
     */
    @Test
    public void testLoadingSaving() throws IOException {
        // Unique temp file rather than a fixed name, so repeated/concurrent runs cannot collide
        File out = File.createTempFile("dl4jdwtestout", ".txt");
        try {
            int nVertices = 20;
            Graph<String, String> graph = generateRandomGraph(nVertices, 8);

            int vectorSize = 5;
            int windowSize = 2;

            DeepWalk<String, String> deepWalk = newDeepWalk(vectorSize, windowSize);
            deepWalk.initialize(graph);

            deepWalk.fit(graph, 10);

            GraphVectorSerializer.writeGraphVectors(deepWalk, out.getAbsolutePath());

            @SuppressWarnings("unchecked")
            GraphVectors<String, String> vectors =
                            (GraphVectors<String, String>) GraphVectorSerializer.loadTxtVectors(out);

            assertEquals(deepWalk.numVertices(), vectors.numVertices());
            assertEquals(deepWalk.getVectorSize(), vectors.getVectorSize());

            for (int i = 0; i < nVertices; i++) {
                INDArray vecDW = deepWalk.getVertexVector(i);
                INDArray vecLoaded = vectors.getVertexVector(i);

                for (int j = 0; j < vectorSize; j++) {
                    double d1 = vecDW.getDouble(j);
                    double d2 = vecLoaded.getDouble(j);
                    double denominator = Math.abs(d1) + Math.abs(d2);
                    // Both values exactly zero: they match, and 0/0 would yield NaN and fail the check
                    double relError = denominator == 0.0 ? 0.0 : Math.abs(d1 - d2) / denominator;
                    assertTrue("vertex " + i + ", component " + j + ": saved=" + d1 + ", loaded=" + d2
                                    + ", relError=" + relError, relError < 1e-6);
                }
            }
        } finally {
            if (!out.delete()) {
                out.deleteOnExit();
            }
        }
    }

    /**
     * Trains on the 13-vertex graph resource for 200 epochs with a linearly decaying learning
     * rate, then prints the similarities to vertex 0 and every vertex vector. Contains no
     * assertions; it passes if nothing throws.
     *
     * @throws IOException if the graph resource cannot be read
     */
    @Test
    public void testDeepWalk13Vertices() throws IOException {

        int nVertices = 13;

        ClassPathResource cpr = new ClassPathResource("graph13.txt");
        Graph<String, String> graph = GraphLoader
                        .loadUndirectedGraphEdgeListFile(cpr.getTempFileFromArchive().getAbsolutePath(), nVertices, ",");

        System.out.println(graph);

        Nd4j.getRandom().setSeed(12345);

        int nEpochs = 200;

        //Set up network
        DeepWalk<String, String> deepWalk =
                        new DeepWalk.Builder<String, String>().vectorSize(50).windowSize(4).seed(12345).build();

        //Run learning
        // NOTE(review): initialize(graph) is never called explicitly in this test; presumably
        // fit(graph, walkLength) initializes on first use - confirm against DeepWalk.
        for (int i = 0; i < nEpochs; i++) {
            // Learning rate decays linearly from 0.03 towards 0 over the epochs
            deepWalk.setLearningRate(0.03 / nEpochs * (nEpochs - i));
            deepWalk.fit(graph, 10);
        }

        //Calculate similarity(0,i)
        for (int i = 0; i < nVertices; i++) {
            System.out.println(deepWalk.similarity(0, i));
        }

        for (int i = 0; i < nVertices; i++) {
            System.out.println(deepWalk.getVertexVector(i));
        }
    }

    /**
     * Fits on a weighted graph using a {@link WeightedRandomWalkGraphIteratorProvider}; passes if
     * no exception is thrown.
     *
     * @throws IOException if the graph resource cannot be read
     */
    @Test
    public void testDeepWalkWeightedParallel() throws IOException {

        //Load graph
        String path = new ClassPathResource("WeightedGraph.txt").getTempFileFromArchive().getAbsolutePath();
        int numVertices = 9;
        String delim = ",";
        String[] ignoreLinesStartingWith = new String[] {"//"}; //Comment lines start with "//"

        IGraph<String, Double> graph =
                        GraphLoader.loadWeightedEdgeListFile(path, numVertices, delim, true, ignoreLinesStartingWith);

        //Set up DeepWalk
        int vectorSize = 5;
        int windowSize = 2;

        DeepWalk<String, Double> deepWalk = newDeepWalk(vectorSize, windowSize);
        deepWalk.initialize(graph);

        //Can't use fit(graph, walkLength) here: it defaults to an unweighted random walk.

        //Create GraphWalkIteratorProvider. The GraphWalkIteratorProvider is used to create multiple GraphWalkIterator objects.
        //Here, it is used to create a GraphWalkIterator, one for each thread
        int walkLength = 5;
        GraphWalkIteratorProvider<String> iteratorProvider =
                        new WeightedRandomWalkGraphIteratorProvider<>(graph, walkLength);

        //Fit in parallel
        deepWalk.fit(iteratorProvider);
    }
}