package org.deeplearning4j.iterator;

import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.*;

/**
 * Tests for {@link CnnSentenceDataSetIterator}: verifies that the features, labels and masks
 * produced by the iterator match manually constructed expected arrays, for both feature layouts.
 *
 * Created by Alex on 28/01/2017.
 */
public class TestCnnSentenceDataSetIterator {

    @Test
    public void testSentenceIterator() throws Exception {
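        //Load the sample word2vec model from the test resources; vectorSize is the embedding dimension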
        WordVectors w2v = WordVectorSerializer
                        .readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());
        int vectorSize = w2v.lookupTable().layerSize();

        List<String> sentences = new ArrayList<>();
        //Every word of the first sentence is in the model vocabulary; the second sentence contains an unknown word
        sentences.add("these balance Database model");
        sentences.add("into same THISWORDDOESNTEXIST are");
        int maxLength = 4;
        List<String> s1 = Arrays.asList("these", "balance", "Database", "model");
        List<String> s2 = Arrays.asList("into", "same", "are");

        List<String> labelsForSentences = Arrays.asList("Positive", "Negative");
        INDArray expLabels = Nd4j.create(new double[][] {{0, 1}, {1, 0}}); //Labels are ordered alphabetically: "Negative" -> [1,0], "Positive" -> [0,1]
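
        //The iterator supports two feature layouts: sentence along height -> [minibatch, 1, maxLength, vectorSize],
        //or along width -> [minibatch, 1, vectorSize, maxLength]. Check both.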
        boolean[] alongHeightVals = new boolean[] {true, false};
        for (boolean alongHeight : alongHeightVals) {
            INDArray expectedFeatures;
            if (alongHeight) {
                expectedFeatures = Nd4j.create(2, 1, maxLength, vectorSize);
            } else {
                expectedFeatures = Nd4j.create(2, 1, vectorSize, maxLength);
            }
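
            //Feature mask: 1 where a word is present, 0 for padding. The unknown word in the
            //second sentence is skipped, so only 3 of the maxLength=4 positions are used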
            INDArray expectedFeatureMask = Nd4j.create(new double[][] {{1, 1, 1, 1}, {1, 1, 1, 0}});
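
            //Fill the expected features: one word vector per position, along the sentence dimension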
            for (int i = 0; i < s1.size(); i++) {
                if (alongHeight) {
                    expectedFeatures.get(NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.point(i),
                                    NDArrayIndex.all()).assign(w2v.getWordVectorMatrix(s1.get(i)));
                } else {
                    expectedFeatures.get(NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all(),
                                    NDArrayIndex.point(i)).assign(w2v.getWordVectorMatrix(s1.get(i)));
                }
            }

            for (int i = 0; i < s2.size(); i++) {
                if (alongHeight) {
                    expectedFeatures.get(NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.point(i),
                                    NDArrayIndex.all()).assign(w2v.getWordVectorMatrix(s2.get(i)));
                } else {
                    expectedFeatures.get(NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.all(),
                                    NDArrayIndex.point(i)).assign(w2v.getWordVectorMatrix(s2.get(i)));
                }
            }
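
            //Build the iterator under test from a simple in-memory labeled sentence provider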
            LabeledSentenceProvider p = new CollectionLabeledSentenceProvider(sentences, labelsForSentences, null);
            CnnSentenceDataSetIterator dsi = new CnnSentenceDataSetIterator.Builder().sentenceProvider(p)
                            .wordVectors(w2v).maxSentenceLength(256).minibatchSize(32).sentencesAlongHeight(alongHeight)
                            .build();
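
            //A single call to next() returns both examples (minibatch size 32 >= 2 sentences)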
            DataSet ds = dsi.next();
            assertArrayEquals(expectedFeatures.shape(), ds.getFeatures().shape());
            assertEquals(expectedFeatures, ds.getFeatures());
            assertEquals(expLabels, ds.getLabels());
            assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
            assertNull(ds.getLabelsMaskArray());
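
            //loadSingleSentence should reproduce the corresponding minibatch entry; for the shorter
            //second sentence the returned array covers only its 3 words (no padding)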
            INDArray s1F = dsi.loadSingleSentence(sentences.get(0));
            INDArray s2F = dsi.loadSingleSentence(sentences.get(1));
            INDArray sub1 = ds.getFeatures().get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.all(),
                            NDArrayIndex.all(), NDArrayIndex.all());
            INDArray sub2;
            if (alongHeight) {
                sub2 = ds.getFeatures().get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all(),
                                NDArrayIndex.interval(0, 3), NDArrayIndex.all());
            } else {
                sub2 = ds.getFeatures().get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all(), NDArrayIndex.all(),
                                NDArrayIndex.interval(0, 3));
            }
            assertArrayEquals(sub1.shape(), s1F.shape());
            assertArrayEquals(sub2.shape(), s2F.shape());
            assertEquals(sub1, s1F);
            assertEquals(sub2, s2F);
        }
    }
}