package org.deeplearning4j.examples.recurrent.character; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; import org.nd4j.linalg.factory.Nd4j; import java.io.File; import java.io.IOException; import java.nio.charset.Charset; import java.nio.file.Files; import java.util.*; /** A simple DataSetIterator for use in the GravesLSTMCharModellingExample. * Given a text file and a few options, generate feature vectors and labels for training, * where we want to predict the next character in the sequence.<br> * This is done by randomly choosing a position in the text file, at offsets of 0, exampleLength, 2*exampleLength, etc * to start each sequence. Then we convert each character to an index, i.e., a one-hot vector. * Then the character 'a' becomes [1,0,0,0,...], 'b' becomes [0,1,0,0,...], etc * * Feature vectors and labels are both one-hot vectors of same length * @author Alex Black */ public class CharacterIterator implements DataSetIterator { //Valid characters private char[] validCharacters; //Maps each character to an index ind the input/output private Map<Character,Integer> charToIdxMap; //All characters of the input file (after filtering to only those that are valid private char[] fileCharacters; //Length of each example/minibatch (number of characters) private int exampleLength; //Size of each minibatch (number of examples) private int miniBatchSize; private Random rng; //Offsets for the start of each example private LinkedList<Integer> exampleStartOffsets = new LinkedList<>(); /** * @param textFilePath Path to text file to use for generating samples * @param textFileEncoding Encoding of the text file. Can try Charset.defaultCharset() * @param miniBatchSize Number of examples per mini-batch * @param exampleLength Number of characters in each input/output vector * @param validCharacters Character array of valid characters. Characters not present in this array will be removed * @param rng Random number generator, for repeatability if required * @throws IOException If text file cannot be loaded */ public CharacterIterator(String textFilePath, Charset textFileEncoding, int miniBatchSize, int exampleLength, char[] validCharacters, Random rng) throws IOException { if( !new File(textFilePath).exists()) throw new IOException("Could not access file (does not exist): " + textFilePath); if( miniBatchSize <= 0 ) throw new IllegalArgumentException("Invalid miniBatchSize (must be >0)"); this.validCharacters = validCharacters; this.exampleLength = exampleLength; this.miniBatchSize = miniBatchSize; this.rng = rng; //Store valid characters is a map for later use in vectorization charToIdxMap = new HashMap<>(); for( int i=0; i<validCharacters.length; i++ ) charToIdxMap.put(validCharacters[i], i); //Load file and convert contents to a char[] boolean newLineValid = charToIdxMap.containsKey('\n'); List<String> lines = Files.readAllLines(new File(textFilePath).toPath(),textFileEncoding); int maxSize = lines.size(); //add lines.size() to account for newline characters at end of each line for( String s : lines ) maxSize += s.length(); char[] characters = new char[maxSize]; int currIdx = 0; for( String s : lines ){ char[] thisLine = s.toCharArray(); for (char aThisLine : thisLine) { if (!charToIdxMap.containsKey(aThisLine)) continue; characters[currIdx++] = aThisLine; } if(newLineValid) characters[currIdx++] = '\n'; } if( currIdx == characters.length ){ fileCharacters = characters; } else { fileCharacters = Arrays.copyOfRange(characters, 0, currIdx); } if( exampleLength >= fileCharacters.length ) throw new IllegalArgumentException("exampleLength="+exampleLength +" cannot exceed number of valid characters in file ("+fileCharacters.length+")"); int nRemoved = maxSize - fileCharacters.length; System.out.println("Loaded and converted file: " + fileCharacters.length + " valid characters of " + maxSize + " total characters (" + nRemoved + " removed)"); initializeOffsets(); } /** A minimal character set, with a-z, A-Z, 0-9 and common punctuation etc */ public static char[] getMinimalCharacterSet(){ List<Character> validChars = new LinkedList<>(); for(char c='a'; c<='z'; c++) validChars.add(c); for(char c='A'; c<='Z'; c++) validChars.add(c); for(char c='0'; c<='9'; c++) validChars.add(c); char[] temp = {'!', '&', '(', ')', '?', '-', '\'', '"', ',', '.', ':', ';', ' ', '\n', '\t'}; for( char c : temp ) validChars.add(c); char[] out = new char[validChars.size()]; int i=0; for( Character c : validChars ) out[i++] = c; return out; } /** As per getMinimalCharacterSet(), but with a few extra characters */ public static char[] getDefaultCharacterSet(){ List<Character> validChars = new LinkedList<>(); for(char c : getMinimalCharacterSet() ) validChars.add(c); char[] additionalChars = {'@', '#', '$', '%', '^', '*', '{', '}', '[', ']', '/', '+', '_', '\\', '|', '<', '>'}; for( char c : additionalChars ) validChars.add(c); char[] out = new char[validChars.size()]; int i=0; for( Character c : validChars ) out[i++] = c; return out; } public char convertIndexToCharacter( int idx ){ return validCharacters[idx]; } public int convertCharacterToIndex( char c ){ return charToIdxMap.get(c); } public char getRandomCharacter(){ return validCharacters[(int) (rng.nextDouble()*validCharacters.length)]; } public boolean hasNext() { return exampleStartOffsets.size() > 0; } public DataSet next() { return next(miniBatchSize); } public DataSet next(int num) { if( exampleStartOffsets.size() == 0 ) throw new NoSuchElementException(); int currMinibatchSize = Math.min(num, exampleStartOffsets.size()); //Allocate space: //Note the order here: // dimension 0 = number of examples in minibatch // dimension 1 = size of each vector (i.e., number of characters) // dimension 2 = length of each time series/example //Why 'f' order here? See http://deeplearning4j.org/usingrnns.html#data section "Alternative: Implementing a custom DataSetIterator" INDArray input = Nd4j.create(new int[]{currMinibatchSize,validCharacters.length,exampleLength}, 'f'); INDArray labels = Nd4j.create(new int[]{currMinibatchSize,validCharacters.length,exampleLength}, 'f'); for( int i=0; i<currMinibatchSize; i++ ){ int startIdx = exampleStartOffsets.removeFirst(); int endIdx = startIdx + exampleLength; int currCharIdx = charToIdxMap.get(fileCharacters[startIdx]); //Current input int c=0; for( int j=startIdx+1; j<endIdx; j++, c++ ){ int nextCharIdx = charToIdxMap.get(fileCharacters[j]); //Next character to predict input.putScalar(new int[]{i,currCharIdx,c}, 1.0); labels.putScalar(new int[]{i,nextCharIdx,c}, 1.0); currCharIdx = nextCharIdx; } } return new DataSet(input,labels); } public int totalExamples() { return (fileCharacters.length-1) / miniBatchSize - 2; } public int inputColumns() { return validCharacters.length; } public int totalOutcomes() { return validCharacters.length; } public void reset() { exampleStartOffsets.clear(); initializeOffsets(); } private void initializeOffsets() { //This defines the order in which parts of the file are fetched int nMinibatchesPerEpoch = (fileCharacters.length - 1) / exampleLength - 2; //-2: for end index, and for partial example for (int i = 0; i < nMinibatchesPerEpoch; i++) { exampleStartOffsets.add(i * exampleLength); } Collections.shuffle(exampleStartOffsets, rng); } public boolean resetSupported() { return true; } @Override public boolean asyncSupported() { return true; } public int batch() { return miniBatchSize; } public int cursor() { return totalExamples() - exampleStartOffsets.size(); } public int numExamples() { return totalExamples(); } public void setPreProcessor(DataSetPreProcessor preProcessor) { throw new UnsupportedOperationException("Not implemented"); } @Override public DataSetPreProcessor getPreProcessor() { throw new UnsupportedOperationException("Not implemented"); } @Override public List<String> getLabels() { throw new UnsupportedOperationException("Not implemented"); } @Override public void remove() { throw new UnsupportedOperationException(); } }