package org.deeplearning4j.examples.nlp.glove;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.util.Collection;
/**
 * Minimal example that trains a {@link Glove} model on the
 * {@code raw_sentences.txt} classpath resource and then logs the results of
 * two queries against the trained model.
 *
 * @author raver119@gmail.com
 */
public class GloVeExample {

    private static final Logger log = LoggerFactory.getLogger(GloVeExample.class);

    /**
     * Builds the model from the corpus, fits it, logs the similarity between
     * "day" and "night" and the 10 words nearest to "day", then exits the JVM
     * with status 0.
     *
     * @param args command-line arguments; not used
     * @throws Exception propagated from resource lookup, iterator creation,
     *                   training, or the model queries
     */
    public static void main(String[] args) throws Exception {
        File inputFile = new ClassPathResource("raw_sentences.txt").getFile();

        // creating SentenceIterator wrapping our training corpus
        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());

        // Split on white spaces in the line to get words
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());

        Glove glove = new Glove.Builder()
                .iterate(iter)
                .tokenizerFactory(t)

                // NOTE(review): presumably the exponent of the GloVe weighting
                // function -- confirm against the Glove.Builder documentation
                .alpha(0.75)

                // NOTE(review): presumably the initial learning rate -- confirm
                .learningRate(0.1)

                // number of epochs for training
                .epochs(25)

                // cutoff for weighting function
                .xMax(100)

                // training is done in batches taken from training corpus
                .batchSize(1000)

                // if set to true, batches will be shuffled before training
                .shuffle(true)

                // if set to true word pairs will be built in both directions, LTR and RTL
                .symmetric(true)
                .build();

        glove.fit();

        // Parameterized logging: the message is only formatted when INFO is enabled.
        double simD = glove.similarity("day", "night");
        log.info("Day/night similarity: {}", simD);

        Collection<String> words = glove.wordsNearest("day", 10);
        log.info("Nearest words to 'day': {}", words);

        // NOTE(review): explicit exit kept from the original; presumably needed
        // because training leaves non-daemon threads running -- confirm.
        System.exit(0);
    }
}