package org.deeplearning4j.examples.multigpu.w2vsentiment;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import java.io.*;
import java.net.URL;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Generates binary datasets from raw text so they can be reused on GPU for faster tuning sessions.
 * The idea behind this approach is simple: if you are going to run over the same corpus again and again,
 * it is more efficient to do all the preprocessing once and save the datasets as binary files than to
 * repeat the same work on the fly every time.
 */
public class DataSetsBuilder {
    private static final Logger log = org.slf4j.LoggerFactory.getLogger(DataSetsBuilder.class);

    /** Data URL for downloading */
    public static final String DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz";
    /** Location to save and extract the training/testing data */
    public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_w2vSentiment/");
    /**
     * Default location (local file system) of the Google News vectors.
     * Either edit this constant or pass {@code -w/--wordvectors} on the command line.
     * A value starting with {@code /PATH/TO/YOUR/VECTORS/} is treated as "not configured".
     */
    public static final String WORD_VECTORS_PATH = "/home/raver119/develop/GoogleNews-vectors-negative300.bin.gz";
    /** Directory the serialized training datasets are written to */
    public static final String TRAIN_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_w2vSentiment_train/");
    /** Directory the serialized test datasets are written to */
    public static final String TEST_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_w2vSentiment_test/");

    /** Prefix of the placeholder path that marks the word vectors location as "not configured" */
    private static final String WORD_VECTORS_PLACEHOLDER_PREFIX = "/PATH/TO/YOUR/VECTORS/";
    /** File name of the downloaded archive inside {@link #DATA_PATH} */
    private static final String ARCHIVE_NAME = "aclImdb_v1.tar.gz";
    /** Name of the top-level directory the archive is expected to extract to */
    private static final String EXTRACTED_DIR_NAME = "aclImdb";
    /** Size (in bytes) of the copy buffer used during extraction */
    private static final int BUFFER_SIZE = 4096;

    @Parameter(names = {"-b","--batch"}, description = "BatchSize")
    private int batchSize = 64;

    @Parameter(names = {"-l","--length"}, description = "Truncate max review length to")
    private int truncateReviewsToLength = 256;

    @Parameter(names = {"-w","--wordvectors"}, description = "Path to the word vectors file (defaults to WORD_VECTORS_PATH)")
    private String wordVectorsPath = WORD_VECTORS_PATH;

    /**
     * Parses the command line, downloads/extracts the raw data if needed, and writes the test and
     * train datasets as binary files to {@link #TEST_PATH} and {@link #TRAIN_PATH}.
     *
     * @param args command line arguments ({@code -b}, {@code -l}, {@code -w})
     * @throws Exception if the word vectors are missing, or downloading, extracting or loading fails
     */
    public void run(String[] args) throws Exception {
        JCommander jcmdr = new JCommander(this);
        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            //User provides invalid input -> report the problem and print the usage info
            log.error("Invalid arguments: {}", e.getMessage());
            jcmdr.usage();
            try {
                // short pause before exiting, kept from the original implementation
                Thread.sleep(500);
            } catch (InterruptedException interrupted) {
                // we are exiting anyway, but keep the interrupt flag intact
                Thread.currentThread().interrupt();
            }
            System.exit(1);
        }

        if (wordVectorsPath.startsWith(WORD_VECTORS_PLACEHOLDER_PREFIX)) {
            throw new RuntimeException("Please set the WORD_VECTORS_PATH (or pass -w/--wordvectors) before running this example");
        }
        // Fail fast: no point in downloading the corpus if the vectors cannot be loaded afterwards
        File wordVectorsFile = new File(wordVectorsPath);
        if (!wordVectorsFile.isFile()) {
            throw new FileNotFoundException("Word vectors file not found: " + wordVectorsFile.getAbsolutePath());
        }

        downloadData();

        WordVectors wordVectors = WordVectorSerializer.loadStaticModel(wordVectorsFile);
        SentimentExampleIterator train = new SentimentExampleIterator(DATA_PATH, wordVectors, batchSize, truncateReviewsToLength, true);
        SentimentExampleIterator test = new SentimentExampleIterator(DATA_PATH, wordVectors, batchSize, truncateReviewsToLength, false);

        log.info("Saving test data...");
        saveDatasets(test, TEST_PATH);

        log.info("Saving train data...");
        saveDatasets(train, TRAIN_PATH);
    }

    /**
     * Makes sure the raw data set is available under {@link #DATA_PATH}: downloads the archive if it
     * is not there yet and extracts it if the extracted directory is missing.
     *
     * @throws Exception if the data directory cannot be created, or the download or extraction fails
     */
    private static void downloadData() throws Exception {
        //Create directory if required
        File directory = new File(DATA_PATH);
        if (!directory.mkdirs() && !directory.isDirectory()) {
            throw new IOException("Could not create data directory: " + directory.getAbsolutePath());
        }

        File archiveFile = new File(directory, ARCHIVE_NAME);
        File extractedFile = new File(directory, EXTRACTED_DIR_NAME);

        if (!archiveFile.exists()) {
            System.out.println("Starting data download (80MB)...");
            // Download to a temporary name first, so an interrupted download is never mistaken
            // for a complete archive on the next run
            File partialFile = new File(directory, ARCHIVE_NAME + ".part");
            FileUtils.copyURLToFile(new URL(DATA_URL), partialFile);
            if (!partialFile.renameTo(archiveFile)) {
                throw new IOException("Could not move downloaded file " + partialFile.getAbsolutePath()
                        + " to " + archiveFile.getAbsolutePath());
            }
            System.out.println("Data (.tar.gz file) downloaded to " + archiveFile.getAbsolutePath());
            //Extract tar.gz file to output directory
            extractTarGz(archiveFile.getPath(), DATA_PATH);
        } else {
            System.out.println("Data (.tar.gz file) already exists at " + archiveFile.getAbsolutePath());
            if (!extractedFile.exists()) {
                //Archive is present but was not extracted yet
                extractTarGz(archiveFile.getPath(), DATA_PATH);
            } else {
                System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath());
            }
        }
    }

    /**
     * Extracts a gzip-compressed tar archive into the given directory.
     *
     * @param filePath   path of the {@code .tar.gz} archive
     * @param outputPath directory the entries are extracted into
     * @throws IOException if reading or writing fails, a directory cannot be created, or an entry
     *                     would be written outside of {@code outputPath}
     */
    private static void extractTarGz(String filePath, String outputPath) throws IOException {
        int fileCount = 0;
        int dirCount = 0;
        File outputDir = new File(outputPath);
        String canonicalOutputDir = outputDir.getCanonicalPath();
        // one buffer is enough for the whole archive, entries are processed sequentially
        byte[] data = new byte[BUFFER_SIZE];

        System.out.print("Extracting files");
        try (TarArchiveInputStream tais = new TarArchiveInputStream(
                new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(filePath))))) {
            TarArchiveEntry entry;
            // Read the tar entries using the getNextEntry method
            while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {
                File target = resolveInside(outputDir, canonicalOutputDir, entry.getName());

                if (entry.isDirectory()) {
                    ensureDirectory(target);
                    dirCount++;
                } else {
                    // a file entry may show up before (or without) an entry for its directory
                    File parent = target.getParentFile();
                    if (parent != null) {
                        ensureDirectory(parent);
                    }
                    try (OutputStream dest = new BufferedOutputStream(new FileOutputStream(target), BUFFER_SIZE)) {
                        int count;
                        while ((count = tais.read(data, 0, BUFFER_SIZE)) != -1) {
                            dest.write(data, 0, count);
                        }
                    }
                    fileCount++;
                    // progress indicator: one dot per 1000 extracted files
                    if (fileCount % 1000 == 0) {
                        System.out.print(".");
                    }
                }
            }
        }

        System.out.println("\n" + fileCount + " files and " + dirCount + " directories extracted to: " + outputPath);
    }

    /**
     * Resolves an archive entry name against the output directory and rejects names that would
     * end up outside of it (e.g. via {@code ../} segments).
     */
    private static File resolveInside(File outputDir, String canonicalOutputDir, String entryName) throws IOException {
        File target = new File(outputDir, entryName);
        String canonicalTarget = target.getCanonicalPath();
        String prefix = canonicalOutputDir.endsWith(File.separator)
                ? canonicalOutputDir
                : canonicalOutputDir + File.separator;
        if (!canonicalTarget.equals(canonicalOutputDir) && !canonicalTarget.startsWith(prefix)) {
            throw new IOException("Archive entry is outside of the target directory: " + entryName);
        }
        return target;
    }

    /** Creates the directory (including parents) if needed and fails loudly when that is not possible. */
    private static void ensureDirectory(File dir) throws IOException {
        if (!dir.mkdirs() && !dir.isDirectory()) {
            throw new IOException("Could not create directory: " + dir.getAbsolutePath());
        }
    }

    /**
     * Drains the iterator and saves every dataset it yields as {@code dataset-<index>.bin}
     * (index starting at 0) inside the given directory.
     *
     * @param iterator source of the datasets
     * @param dir      target directory; created if it does not exist
     * @throws IllegalStateException if the target directory cannot be created
     */
    protected void saveDatasets(DataSetIterator iterator, String dir) {
        File outputDir = new File(dir);
        if (!outputDir.mkdirs() && !outputDir.isDirectory()) {
            throw new IllegalStateException("Could not create output directory: " + outputDir.getAbsolutePath());
        }

        int saved = 0;
        while (iterator.hasNext()) {
            File target = new File(outputDir, "dataset-" + saved + ".bin");
            iterator.next().save(target);
            saved++;

            if (saved % 500 == 0) {
                log.info("{} datasets saved so far...", saved);
            }
        }
        log.info("{} datasets saved to {}", saved, outputDir.getAbsolutePath());
    }

    public static void main(String[] args) throws Exception {
        new DataSetsBuilder().run(args);
    }
}