package org.deeplearning4j.graph.graph; import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.graph.api.*; import org.deeplearning4j.graph.data.GraphLoader; import org.deeplearning4j.graph.iterator.RandomWalkIterator; import org.deeplearning4j.graph.iterator.WeightedRandomWalkIterator; import org.deeplearning4j.graph.vertexfactory.VertexFactory; import org.junit.Test; import org.nd4j.linalg.io.ClassPathResource; import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Set; import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.*; public class TestGraph { @Test public void testSimpleGraph() { Graph<String, String> graph = new Graph<>(10, false, new VFactory()); assertEquals(10, graph.numVertices()); for (int i = 0; i < 10; i++) { //Add some undirected edges String str = i + "--" + (i + 1) % 10; Edge<String> edge = new Edge<>(i, (i + 1) % 10, str, false); graph.addEdge(edge); } for (int i = 0; i < 10; i++) { List<Edge<String>> edges = graph.getEdgesOut(i); assertEquals(2, edges.size()); //expect for example 0->1 and 9->0 Edge<String> first = edges.get(0); if (first.getFrom() == i) { //undirected edge: i -> i+1 (or 9 -> 0) assertEquals(i, first.getFrom()); assertEquals((i + 1) % 10, first.getTo()); } else { //undirected edge: i-1 -> i (or 9 -> 0) assertEquals((i + 10 - 1) % 10, first.getFrom()); assertEquals(i, first.getTo()); } Edge<String> second = edges.get(1); assertNotEquals(first.getFrom(), second.getFrom()); if (second.getFrom() == i) { //undirected edge: i -> i+1 (or 9 -> 0) assertEquals(i, second.getFrom()); assertEquals((i + 1) % 10, second.getTo()); } else { //undirected edge: i-1 -> i (or 9 -> 0) assertEquals((i + 10 - 1) % 10, second.getFrom()); assertEquals(i, second.getTo()); } } } private static class VFactory implements VertexFactory<String> { @Override public Vertex<String> create(int vertexIdx) { return new Vertex<>(vertexIdx, String.valueOf(vertexIdx)); } } @Test public void testRandomWalkIterator() { Graph<String, String> graph = new Graph<>(10, false, new VFactory()); assertEquals(10, graph.numVertices()); for (int i = 0; i < 10; i++) { //Add some undirected edges String str = i + "--" + (i + 1) % 10; Edge<String> edge = new Edge<>(i, (i + 1) % 10, str, false); graph.addEdge(edge); } int walkLength = 4; RandomWalkIterator<String> iter = new RandomWalkIterator<>(graph, walkLength, 1235, NoEdgeHandling.EXCEPTION_ON_DISCONNECTED); int count = 0; Set<Integer> startIdxSet = new HashSet<>(); while (iter.hasNext()) { count++; IVertexSequence<String> sequence = iter.next(); int seqCount = 1; int first = sequence.next().vertexID(); int previous = first; while (sequence.hasNext()) { //Possible next vertices for this particular graph: (previous+1)%10 or (previous-1+10)%10 int left = (previous - 1 + 10) % 10; int right = (previous + 1) % 10; int current = sequence.next().vertexID(); assertTrue("expected: " + left + " or " + right + ", got " + current, current == left || current == right); seqCount++; previous = current; } assertEquals(seqCount, walkLength + 1); //walk of 0 -> 1 element, walk of 2 -> 3 elements etc assertFalse(startIdxSet.contains(first)); //Expect to see each node exactly once startIdxSet.add(first); } assertEquals(10, count); //Expect exactly 10 starting nodes assertEquals(10, startIdxSet.size()); } @Test public void testWeightedRandomWalkIterator() throws Exception { //Load a directed, weighted graph from file String path = new ClassPathResource("WeightedGraph.txt").getTempFileFromArchive().getAbsolutePath(); int numVertices = 9; String delim = ","; String[] ignoreLinesStartingWith = new String[] {"//"}; //Comment lines start with "//" IGraph<String, Double> graph = GraphLoader.loadWeightedEdgeListFile(path, numVertices, delim, true, ignoreLinesStartingWith); assertEquals(numVertices, graph.numVertices()); int[] vertexOutDegrees = {2, 2, 1, 2, 2, 1, 1, 1, 1}; for (int i = 0; i < numVertices; i++) assertEquals(vertexOutDegrees[i], graph.getVertexDegree(i)); int[][] edges = new int[][] {{1, 3}, //0->1 and 1->3 {2, 4}, //1->2 and 1->4 {5}, //etc {4, 6}, {5, 7}, {8}, {7}, {8}, {0}}; double[][] edgeWeights = new double[][] {{1, 3}, {12, 14}, {25}, {34, 36}, {45, 47}, {58}, {67}, {78}, {80}}; double[][] edgeWeightsNormalized = new double[edgeWeights.length][0]; for (int i = 0; i < edgeWeights.length; i++) { double sum = 0.0; for (int j = 0; j < edgeWeights[i].length; j++) sum += edgeWeights[i][j]; edgeWeightsNormalized[i] = new double[edgeWeights[i].length]; for (int j = 0; j < edgeWeights[i].length; j++) edgeWeightsNormalized[i][j] = edgeWeights[i][j] / sum; } int walkLength = 5; WeightedRandomWalkIterator<String> iterator = new WeightedRandomWalkIterator<>(graph, walkLength, 12345); int walkCount = 0; Set<Integer> set = new HashSet<>(); while (iterator.hasNext()) { IVertexSequence<String> walk = iterator.next(); assertEquals(walkLength + 1, walk.sequenceLength()); //Walk length of 5 -> 6 vertices (inc starting point) int thisWalkCount = 0; boolean first = true; int lastVertex = -1; while (walk.hasNext()) { Vertex<String> vertex = walk.next(); if (first) { assertFalse(set.contains(vertex.vertexID())); set.add(vertex.vertexID()); lastVertex = vertex.vertexID(); first = false; } else { //Ensure that a directed edge exists from lastVertex -> vertex int currVertex = vertex.vertexID(); assertTrue(ArrayUtils.contains(edges[lastVertex], currVertex)); lastVertex = currVertex; } thisWalkCount++; } assertEquals(walkLength + 1, thisWalkCount); //Walk length of 5 -> 6 vertices (inc starting point) walkCount++; } double[][] transitionProb = new double[numVertices][numVertices]; int nWalks = 2000; for (int i = 0; i < nWalks; i++) { iterator.reset(); while (iterator.hasNext()) { IVertexSequence<String> seq = iterator.next(); int last = -1; while (seq.hasNext()) { int curr = seq.next().vertexID(); if (last != -1) { transitionProb[last][curr] += 1.0; } last = curr; } } } for (int i = 0; i < transitionProb.length; i++) { double sum = 0.0; for (int j = 0; j < transitionProb[i].length; j++) sum += transitionProb[i][j]; for (int j = 0; j < transitionProb[i].length; j++) transitionProb[i][j] /= sum; System.out.println(Arrays.toString(transitionProb[i])); } //Check that transition probs are essentially correct (within bounds of random variation) for (int i = 0; i < numVertices; i++) { for (int j = 0; j < numVertices; j++) { if (!ArrayUtils.contains(edges[i], j)) { assertEquals(0.0, transitionProb[i][j], 0.0); } else { int idx = ArrayUtils.indexOf(edges[i], j); assertEquals(edgeWeightsNormalized[i][idx], transitionProb[i][j], 0.01); } } } for (int i = 0; i < numVertices; i++) assertTrue(set.contains(i)); assertEquals(numVertices, walkCount); } }