package org.deeplearning4j.examples.recurrent.seq2seq; import org.deeplearning4j.nn.graph.ComputationGraph; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; /** * Created by susaneraly on 1/11/17. */ /* Note this is a helper class with methods to step through the decoder, one time step at a time. This process is common to all seq2seq models and will eventually be wrapped in a class in dl4j (along with an easier API). Track issue: https://github.com/deeplearning4j/deeplearning4j/issues/2635 */ public class Seq2SeqPredicter { private ComputationGraph net; private INDArray decoderInputTemplate = null; public Seq2SeqPredicter(ComputationGraph net) { this.net = net; } /* Given an input to the computation graph (which is expected to a be a seq2seq model) Predict the output given the encoder input (which is fixed) + the first time step from the decoder input All other time steps in the decoder input will be ignored */ public INDArray output(MultiDataSet testSet) { if (testSet.getFeatures()[0].size(0) > 2) { return output(testSet, false); } else { return output(testSet, true); } } public INDArray output(MultiDataSet testSet, boolean print) { INDArray correctOutput = testSet.getLabels()[0]; INDArray ret = Nd4j.zeros(correctOutput.shape()); decoderInputTemplate = testSet.getFeatures()[1].dup(); int currentStepThrough = 0; int stepThroughs = correctOutput.size(2)-1; while (currentStepThrough < stepThroughs) { if (print) { System.out.println("In time step "+currentStepThrough); System.out.println("\tEncoder input and Decoder input:"); System.out.println(CustomSequenceIterator.mapToString(testSet.getFeatures()[0],decoderInputTemplate, " + ")); } ret = stepOnce(testSet, currentStepThrough); if (print) { System.out.println("\tDecoder output:"); System.out.println("\t"+String.join("\n\t",CustomSequenceIterator.oneHotDecode(ret))); } currentStepThrough++; } ret = net.output(false,testSet.getFeatures()[0],decoderInputTemplate)[0]; if (print) { System.out.println("Final time step "+currentStepThrough); System.out.println("\tEncoder input and Decoder input:"); System.out.println(CustomSequenceIterator.mapToString(testSet.getFeatures()[0],decoderInputTemplate, " + ")); System.out.println("\tDecoder output:"); System.out.println("\t"+String.join("\n\t",CustomSequenceIterator.oneHotDecode(ret))); } return ret; } /* Will do a forward pass through encoder + decoder with the given input Updates the decoder input template from time = 1 to time t=n+1; Returns the output from this forward pass */ private INDArray stepOnce(MultiDataSet testSet, int n) { INDArray currentOutput = net.output(false, testSet.getFeatures()[0], decoderInputTemplate)[0]; copyTimeSteps(n,currentOutput,decoderInputTemplate); return currentOutput; } /* Copies timesteps time = 0 to time = t in "fromArr" to time = 1 to time = t+1 in "toArr" */ private void copyTimeSteps(int t, INDArray fromArr, INDArray toArr) { INDArray fromView = fromArr.get(NDArrayIndex.all(),NDArrayIndex.all(),NDArrayIndex.interval(0,t,true)); INDArray toView = toArr.get(NDArrayIndex.all(),NDArrayIndex.all(),NDArrayIndex.interval(1,t+1,true)); toView.assign(fromView.dup()); } }