package org.deeplearning4j.graph.models.deepwalk;
import org.deeplearning4j.graph.data.GraphLoader;
import org.deeplearning4j.graph.graph.Graph;
import org.deeplearning4j.graph.iterator.GraphWalkIterator;
import org.deeplearning4j.graph.iterator.RandomWalkIterator;
import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.IOException;
import java.util.Arrays;
import static org.junit.Assert.*;
/**
 * Numerical gradient checks for the lookup table used by {@link DeepWalk}.
 *
 * <p>For every (input vertex, output vertex) pair the analytic gradients returned by
 * {@code InMemoryGraphLookupTable.vectorsAndGradients(in, out)} are compared against central
 * finite-difference estimates of {@code InMemoryGraphLookupTable.calculateScore(in, out)}.
 * The tests also verify that the probabilities {@code calculateProb(in, out)} over all output
 * vertices lie in [0, 1] and sum to 1 for each input vertex.
 */
public class DeepWalkGradientCheck {

    /** Perturbation applied to a single parameter for the central finite-difference estimate. */
    public static final double epsilon = 1e-6;

    /** Largest relative error between analytic and numerical gradient that is still accepted. */
    public static final double MAX_REL_ERROR = 1e-5;

    /** Tolerance on the deviation from 1.0 of the summed output probabilities. */
    private static final double PROB_SUM_TOLERANCE = 1e-5;

    /** Switches ND4J to double precision before each test. */
    @Before
    public void before() {
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
        Nd4j.factory().setDType(DataBuffer.Type.DOUBLE);
    }

    /** Gradient check on the 7-vertex test graph (vector size 5, window size 2). */
    @Test
    public void checkGradients() throws IOException {
        runGradientCheck("testgraph_7vertices.txt", 7, 5, 2, 8);
    }

    /** Gradient check on the 13-vertex test graph (vector size 10, window size 3). */
    @Test
    public void checkGradients2() throws IOException {
        runGradientCheck("graph13.txt", 13, 10, 3, 10);
    }

    /**
     * Loads a graph, fits a {@link DeepWalk} model on it and checks probabilities and gradients
     * for every (input, output) vertex pair.
     *
     * @param resourceName classpath resource containing the comma-delimited edge list
     * @param nVertices    number of vertices in the graph
     * @param vectorSize   size of the vertex vectors
     * @param windowSize   window size passed to the DeepWalk builder
     * @param walkLength   second constructor argument of {@link RandomWalkIterator}
     *                     (presumably the random walk length - confirm against that class)
     * @throws IOException if the resource cannot be extracted or the graph cannot be loaded
     */
    private static void runGradientCheck(String resourceName, int nVertices, int vectorSize, int windowSize,
                    int walkLength) throws IOException {
        ClassPathResource cpr = new ClassPathResource(resourceName);
        Graph<String, String> graph = GraphLoader.loadUndirectedGraphEdgeListFile(
                        cpr.getTempFileFromArchive().getAbsolutePath(), nVertices, ",");

        Nd4j.getRandom().setSeed(12345);

        DeepWalk<String, String> deepWalk = new DeepWalk.Builder<String, String>().learningRate(0.01)
                        .vectorSize(vectorSize).windowSize(windowSize).build();
        deepWalk.initialize(graph);

        for (int i = 0; i < nVertices; i++) {
            INDArray vector = deepWalk.getVertexVector(i);
            assertArrayEquals(new int[] {1, vectorSize}, vector.shape());
            System.out.println(Arrays.toString(vector.dup().data().asFloat()));
        }

        GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, walkLength);
        deepWalk.fit(iter);

        //Now, to check gradients:
        InMemoryGraphLookupTable table = (InMemoryGraphLookupTable) deepWalk.lookupTable();
        GraphHuffman tree = (GraphHuffman) table.getTree();

        //For each pair of input/output vertices: check gradients
        for (int i = 0; i < nVertices; i++) { //in
            //First: check probabilities p(out|in)
            double sumProb = 0.0;
            for (int j = 0; j < nVertices; j++) {
                double prob = table.calculateProb(i, j);
                assertTrue(prob >= 0.0 && prob <= 1.0);
                sumProb += prob;
            }
            assertTrue("Output probabilities do not sum to 1.0 (i=" + i + "), sum=" + sumProb,
                            Math.abs(sumProb - 1.0) < PROB_SUM_TOLERANCE);

            for (int j = 0; j < nVertices; j++) { //out
                //p(j|i)
                int[] pathInnerNodes = tree.getPathInnerNodes(j);

                //Calculate gradients:
                INDArray[][] vecAndGrads = table.vectorsAndGradients(i, j);
                assertEquals(2, vecAndGrads.length);
                assertEquals(pathInnerNodes.length + 1, vecAndGrads[0].length);
                assertEquals(pathInnerNodes.length + 1, vecAndGrads[1].length);

                //Two types of gradients to test:
                //(a) gradient of loss fn. wrt inner node vector representation
                //(b) gradient of loss fn. wrt vector for input word

                //(a) Inner nodes: gradient for path position p is stored at index p + 1
                for (int p = 0; p < pathInnerNodes.length; p++) {
                    INDArray innerNodeVector = table.getInnerNodeVector(pathInnerNodes[p]);
                    INDArray innerNodeGrad = vecAndGrads[1][p + 1];
                    checkParamGradients(table, i, j, innerNodeVector, innerNodeGrad,
                                    "innerNode grad: i=" + i + ", j=" + j + ", p=" + p);
                }

                //(b) Input vertex vector: gradient is stored at index 0
                INDArray vertexVector = table.getVector(i);
                INDArray vectorGrad = vecAndGrads[1][0];
                assertArrayEquals(vectorGrad.shape(), vertexVector.shape());
                checkParamGradients(table, i, j, vertexVector, vectorGrad, "vector grad: i=" + i + ", j=" + j);

                System.out.println();
            }
        }
    }

    /**
     * Compares each element of an analytic gradient against a central finite-difference estimate
     * of {@code table.calculateScore(in, out)}, failing the test when the relative error exceeds
     * {@link #MAX_REL_ERROR} or is NaN.
     *
     * <p>{@code params} is modified in place while probing and restored to its original value
     * after each element.
     *
     * @param table  lookup table whose score is differentiated
     * @param in     index of the input vertex
     * @param out    index of the output vertex
     * @param params parameter vector to perturb; must be the array the table reads when scoring
     * @param grads  analytic gradient with respect to {@code params}
     * @param label  prefix identifying the parameter vector in log and failure messages
     */
    private static void checkParamGradients(InMemoryGraphLookupTable table, int in, int out, INDArray params,
                    INDArray grads, String label) {
        for (int v = 0; v < params.length(); v++) {
            double backpropGradient = grads.getDouble(v);
            double origParamValue = params.getDouble(v);

            params.putScalar(v, origParamValue + epsilon);
            double scorePlus = table.calculateScore(in, out);
            params.putScalar(v, origParamValue - epsilon);
            double scoreMinus = table.calculateScore(in, out);
            params.putScalar(v, origParamValue); //reset param so it doesn't affect later calcs

            double numericalGradient = (scorePlus - scoreMinus) / (2 * epsilon);

            double relError;
            if (backpropGradient == 0.0 && numericalGradient == 0.0) {
                //Avoid 0/0: two exactly-zero gradients agree perfectly
                relError = 0.0;
            } else {
                relError = Math.abs(backpropGradient - numericalGradient)
                                / (Math.abs(backpropGradient) + Math.abs(numericalGradient));
            }

            String msg = label + ", v=" + v + " - relError: " + relError + ", scorePlus=" + scorePlus
                            + ", scoreMinus=" + scoreMinus + ", numGrad=" + numericalGradient
                            + ", backpropGrad = " + backpropGradient;

            //Negated <= (rather than >) so that a NaN relative error is reported as a failure
            if (!(relError <= MAX_REL_ERROR)) {
                fail(msg);
            } else {
                System.out.println(msg);
            }
        }
    }

    /**
     * Tests whether a given bit of a long value is set. Currently not referenced by any test in
     * this class.
     *
     * @param in     value to inspect
     * @param bitNum zero-based bit position (0 = least significant bit)
     * @return {@code true} if the bit at {@code bitNum} is set in {@code in}
     */
    private static boolean getBit(long in, int bitNum) {
        long mask = 1L << bitNum;
        return (in & mask) != 0L;
    }
}