package edu.stanford.nlp.sentiment;
import edu.stanford.nlp.classify.*;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import edu.stanford.nlp.simple.Document;
import edu.stanford.nlp.simple.SentimentClass;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Lazy;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.util.logging.RedwoodConfiguration;
import java.io.File;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import static edu.stanford.nlp.util.logging.Redwood.Util.*;
/**
* A simple sentiment classifier, inspired by Sida's Naive Bayes SVM
* paper.
* The main goal of this class is to avoid the parse tree requirement of
* the RNN approach at: {@link SentimentPipeline}.
*
* @author <a href="mailto:angeli@cs.stanford.edu">Gabor Angeli</a>
*/
public class SimpleSentiment {
/**
 * A logger for this class.
 */
private static final Redwood.RedwoodChannels log = Redwood.channels(SimpleSentiment.class);

/** An appropriate pipeline object for featurizing training data */
private static Lazy<StanfordCoreNLP> pipeline = Lazy.of(() -> {
  // Tokenize + split + POS + lemma is everything featurize() needs; no parsing.
  Properties properties = new Properties();
  properties.setProperty("language", "english");
  properties.setProperty("annotators", "tokenize,ssplit,pos,lemma");
  properties.setProperty("tokenize.class", "PTBTokenizer");
  properties.setProperty("tokenize.language", "en");
  // Each datum is assumed to be a single sentence already.
  properties.setProperty("ssplit.isOneSentence", "true");
  return new StanfordCoreNLP(properties);
});
/**
 * A single datum (presumably read from a training file) that encodes
 * a sentence and an associated sentiment value.
 */
private static class SentimentDatum {
  /** The sentence to classify. */
  public final String sentence;
  /** The sentiment class of the sentence */
  public final SentimentClass sentiment;

  /** The trivial constructor */
  private SentimentDatum(String sentence, SentimentClass sentiment) {
    this.sentence = sentence;
    this.sentiment = sentiment;
  }

  /** Annotate this datum, and return it as a CoreMap. */
  CoreMap asCoreMap() {
    // Empty sentences get a canned stand-in of the right sentiment, so the
    // pipeline and downstream featurizer always have at least one sentence.
    String text = sentence.trim().isEmpty() ? placeholderFor(sentiment) : sentence;
    Annotation ann = new Annotation(text);
    pipeline.get().annotate(ann);
    return ann.get(CoreAnnotations.SentencesAnnotation.class).get(0);
  }

  /** A stand-in sentence carrying the given sentiment, used when the datum's text is empty. */
  private static String placeholderFor(SentimentClass sentiment) {
    switch (sentiment) {
      case VERY_POSITIVE:
        return "cats are super awesome!";
      case POSITIVE:
        return "cats are great";
      case NEUTRAL:
        return "cats have tails";
      case NEGATIVE:
        return "cats suck";
      case VERY_NEGATIVE:
        return "cats are literally the worst, I can't even.";
      default:
        throw new IllegalStateException();
    }
  }
}
/** A simple regex for alpha words. That is, tokens matching [a-zA-Z]+ (one or more ASCII letters). */
private static final Pattern alpha = Pattern.compile("[a-zA-Z]+");
/** A simple regex for number tokens. That is, tokens matching [0-9]+ (one or more digits). */
private static final Pattern number = Pattern.compile("[0-9]+");
/**
 * The underlying classifier we have trained to detect sentiment.
 * Set once in the constructor and never reassigned.
 */
private final Classifier<SentimentClass, String> impl;
/**
 * Featurize a given sentence into unigram counts (with all numbers collapsed
 * to a single {@code **num**} feature) plus bigrams over alphabetic lemmas,
 * anchored with {@code ^} at the start and {@code $} at the end.
 *
 * @param sentence The sentence to featurize; tokens must carry lemmas.
 *
 * @return A counter encoding the featurized sentence.
 */
private static Counter<String> featurize(CoreMap sentence) {
  ClassicCounter<String> counts = new ClassicCounter<>();
  String previousAlpha = "^";  // sentence-start anchor for the bigram chain
  List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
  for (CoreLabel token : tokens) {
    String lemma = token.lemma().toLowerCase();
    // Collapse every numeric token into one shared feature; keep other lemmas as-is.
    counts.incrementCount(number.matcher(lemma).matches() ? "**num**" : lemma);
    // Bigrams are chained over purely alphabetic lemmas only.
    if (alpha.matcher(lemma).matches()) {
      counts.incrementCount(previousAlpha + "__" + lemma);
      previousAlpha = lemma;
    }
  }
  counts.incrementCount(previousAlpha + "__$");  // sentence-end bigram
  return counts;
}
/**
 * Create a new sentiment classifier object.
 * This is really just a shallow wrapper around a classifier...
 * Instances are obtained via {@code train(...)} rather than constructed directly.
 *
 * @param impl The classifier doing the heavy lifting.
 */
private SimpleSentiment(Classifier<SentimentClass, String> impl) {
this.impl = impl;
}
/**
 * Get the sentiment of a sentence.
 *
 * @param sentence The sentence as a core map.
 *                 POS tags and Lemmas are a prerequisite.
 *                 See {@link edu.stanford.nlp.ling.CoreAnnotations.PartOfSpeechAnnotation} and
 *                 {@link edu.stanford.nlp.ling.CoreAnnotations.LemmaAnnotation}.
 *
 * @return The sentiment class of this sentence.
 */
public SentimentClass classify(CoreMap sentence) {
  // Featurize, wrap in a real-valued datum, and let the underlying classifier decide.
  return impl.classOf(new RVFDatum<>(featurize(sentence)));
}
/**
 * Get the sentiment of a piece of raw text, running the featurization pipeline
 * on it first and classifying its first sentence.
 *
 * @param text The text to classify.
 *
 * @return The sentiment class of the (first sentence of the) text.
 *
 * @see SimpleSentiment#classify(CoreMap)
 */
public SentimentClass classify(String text) {
  Annotation ann = new Annotation(text);
  pipeline.get().annotate(ann);
  // Delegate to classify(CoreMap) rather than duplicating the featurize + classOf logic.
  return classify(ann.get(CoreAnnotations.SentencesAnnotation.class).get(0));
}
/**
 * Train a sentiment model from a set of data.
 *
 * @param data The data to train the model from.
 * @param modelLocation An optional location to save the model.
 *                      Note that this stream will be closed in this method,
 *                      and should not be written to thereafter.
 *
 * @return A sentiment classifier, ready to use.
 */
@SuppressWarnings({"OptionalUsedAsFieldOrParameterType", "ConstantConditions"})
public static SimpleSentiment train(
    Stream<SentimentDatum> data,
    Optional<OutputStream> modelLocation) {

  // Some useful variables configuring how we train
  boolean useL1 = true;
  double sigma = 1.0;
  int featureCountThreshold = 5;

  // Featurize the data
  forceTrack("Featurizing");
  RVFDataset<SentimentClass, String> dataset = new RVFDataset<>();
  AtomicInteger datasize = new AtomicInteger(0);
  Counter<SentimentClass> distribution = new ClassicCounter<>();
  data.unordered().parallel()
      .map(datum -> {
        if (datasize.incrementAndGet() % 10000 == 0) {
          log("Added " + datasize.get() + " datums");
        }
        return new RVFDatum<>(featurize(datum.asCoreMap()), datum.sentiment);
      })
      .forEach(x -> {
        // The dataset and distribution are not thread-safe; serialize the adds.
        synchronized (dataset) {
          distribution.incrementCount(x.label());
          dataset.add(x);
        }
      });
  endTrack("Featurizing");

  // Print label distribution
  startTrack("Distribution");
  for (SentimentClass label : SentimentClass.values()) {
    log(String.format("%7d", (int) distribution.getCount(label)) + " " + label);
  }
  endTrack("Distribution");

  // Train the classifier
  forceTrack("Training");
  if (featureCountThreshold > 1) {
    dataset.applyFeatureCountThreshold(featureCountThreshold);
  }
  dataset.randomize(42L);
  LinearClassifierFactory<SentimentClass, String> factory = new LinearClassifierFactory<>();
  factory.setVerbose(true);
  try {
    factory.setMinimizerCreator(() -> {
      QNMinimizer minimizer = new QNMinimizer();
      if (useL1) {
        minimizer.useOWLQN(true, 1 / (sigma * sigma));
      } else {
        factory.setSigma(sigma);
      }
      return minimizer;
    });
  } catch (Exception ignored) {
    // Best effort: if the custom minimizer cannot be configured, fall back to
    // the factory's default (sigma is set unconditionally below).
  }
  factory.setSigma(sigma);
  LinearClassifier<SentimentClass, String> classifier = factory.trainClassifier(dataset);

  // Optionally save the model
  modelLocation.ifPresent(stream -> {
    // try-with-resources closes the stream even if serialization fails (the old
    // code leaked the ObjectOutputStream on a write error).
    try (ObjectOutputStream oos = new ObjectOutputStream(stream)) {
      oos.writeObject(classifier);
    } catch (IOException e) {
      log.err("Could not save model to stream!", e);
    }
  });
  endTrack("Training");

  // Evaluate the model with (numFolds)-fold cross validation
  forceTrack("Evaluating");
  factory.setVerbose(false);
  double sumAccuracy = 0.0;
  Counter<SentimentClass> sumP = new ClassicCounter<>();
  Counter<SentimentClass> sumR = new ClassicCounter<>();
  int numFolds = 4;
  for (int fold = 0; fold < numFolds; ++fold) {
    Pair<GeneralDataset<SentimentClass, String>, GeneralDataset<SentimentClass, String>> trainTest = dataset.splitOutFold(fold, numFolds);
    LinearClassifier<SentimentClass, String> foldClassifier = factory.trainClassifierWithInitialWeights(trainTest.first, classifier);  // convex objective, so this should be OK
    sumAccuracy += foldClassifier.evaluateAccuracy(trainTest.second);
    for (SentimentClass label : SentimentClass.values()) {
      Pair<Double, Double> pr = foldClassifier.evaluatePrecisionAndRecall(trainTest.second, label);
      sumP.incrementCount(label, pr.first);
      // BUG FIX: recall (pr.second) used to be added to sumP as well, leaving
      // sumR permanently empty -- reported recall was always 0 and precision
      // was effectively doubled. Recall belongs in sumR.
      sumR.incrementCount(label, pr.second);
    }
  }
  DecimalFormat df = new DecimalFormat("0.000%");
  log.info("----------");
  double aveAccuracy = sumAccuracy / ((double) numFolds);
  log.info("" + numFolds + "-fold accuracy: " + df.format(aveAccuracy));
  log.info("");
  for (SentimentClass label : SentimentClass.values()) {
    double p = sumP.getCount(label) / numFolds;
    double r = sumR.getCount(label) / numFolds;
    log.info(label + " (P)  = " + df.format(p));
    log.info(label + " (R)  = " + df.format(r));
    log.info(label + " (F1) = " + df.format(2 * p * r / (p + r)));
    log.info("");
  }
  log.info("----------");
  endTrack("Evaluating");

  // Return
  return new SimpleSentiment(classifier);
}
/**
 * Read an IMDB review corpus rooted at the given path, labeling the full
 * contents of every file found (recursively) with the given sentiment.
 *
 * @param path The root directory to scan for review files.
 * @param label The sentiment label attached to every file under the path.
 *
 * @return A parallel stream of datums, one per file.
 */
private static Stream<SentimentDatum> imdb(String path, SentimentClass label) {
  Iterable<File> files = IOUtils.iterFilesRecursive(new File(path));
  return StreamSupport.stream(files.spliterator(), true)
      .map(file -> {
        try {
          // One whole file = one datum.
          return new SentimentDatum(IOUtils.slurpFile(file), label);
        } catch (IOException e) {
          throw new RuntimeIOException(e);
        }
      });
}
/**
 * Read the Stanford Sentiment Treebank TSV format: tab-separated fields with
 * the sentence in column 2 and the integer sentiment in column 3.
 * Header rows and unparseable lines fall back to a canned neutral datum.
 *
 * @param path The TSV file to read.
 *
 * @return A parallel stream of datums, one per line.
 */
private static Stream<SentimentDatum> stanford(String path) {
  Iterable<String> lines = IOUtils.readLines(path);
  return StreamSupport.stream(lines.spliterator(), true).map(line -> {
    String[] fields = line.split("\t");
    boolean unusable = fields.length < 4
        || "Sentiment".equalsIgnoreCase(fields[3])  // header row
        || fields[2].isEmpty();                     // empty sentence
    if (unusable) {
      // Canned neutral stand-in for rows we cannot parse.
      return new SentimentDatum("Cats have tails", SentimentClass.NEUTRAL);
    }
    return new SentimentDatum(fields[2], SentimentClass.fromInt(Integer.parseInt(fields[3])));
  });
}
/**
 * Read a Twitter sentiment CSV: comma-separated fields with the integer
 * sentiment in column 1 and the tweet text starting at column 3 (the text may
 * itself contain commas, so the remaining fields are re-joined).
 * Header rows and unparseable lines fall back to a canned neutral datum.
 *
 * @param path The CSV file to read.
 *
 * @return A parallel stream of datums, one per line.
 */
private static Stream<SentimentDatum> twitter(String path) {
  return StreamSupport.stream(
      IOUtils.readLines(path).spliterator(), true
  ).map(line -> {
    List<String> fields = Arrays.asList(line.split(","));
    // BUG FIX: the guard used to be fields.size() < 3, but fields.get(3) and
    // subList(3, ...) below require at least 4 fields -- a 3-field line threw
    // IndexOutOfBoundsException instead of falling back to the neutral datum.
    if (fields.size() < 4 || "Sentiment".equalsIgnoreCase(fields.get(1)) ||
        fields.get(3).equals("")) {
      return new SentimentDatum("Cats have tails", SentimentClass.NEUTRAL);
    } else {
      int sentiment = Integer.parseInt(fields.get(1));
      // Re-join the tail fields: commas inside the tweet were split off above.
      String text = StringUtils.join(fields.subList(3, fields.size()), ",");
      return new SentimentDatum(text, SentimentClass.fromInt(sentiment));
    }
  });
}
/**
 * Read unlabelled text recursively from a path, splitting each file into
 * sentences and labeling every sentence NEUTRAL.
 *
 * @param path The root directory to scan for text files.
 *
 * @return A parallel stream of neutral datums, one per sentence.
 *
 * @throws IOException If the path cannot be read.
 */
private static Stream<SentimentDatum> unlabelled(String path) throws IOException {
  Iterable<File> files = IOUtils.iterFilesRecursive(new File(path));
  return StreamSupport.stream(files.spliterator(), true)
      .flatMap(file -> {
        Document doc = new Document(IOUtils.slurpReader(IOUtils.readerFromFile(file)));
        return doc.sentences().stream()
            .map(sentence -> new SentimentDatum(sentence.text(), SentimentClass.NEUTRAL));
      });
}
/**
 * Train a sentiment model from several hard-coded corpora, save it, and run a
 * one-sentence sanity check.
 * NOTE: all corpus and model paths below are hard-coded to the original
 * author's machine.
 *
 * @param args Ignored.
 *
 * @throws IOException If a corpus or the model output file cannot be opened.
 */
public static void main(String[] args) throws IOException {
  RedwoodConfiguration.standard().apply();
  startTrack("main");

  // Read the data
  Stream<SentimentDatum> data =
      Stream.concat(
          Stream.concat(
              Stream.concat(
                  imdb("/users/gabor/tmp/aclImdb/train/pos", SentimentClass.POSITIVE),
                  imdb("/users/gabor/tmp/aclImdb/train/neg", SentimentClass.NEGATIVE)),
              Stream.concat(
                  imdb("/users/gabor/tmp/aclImdb/test/pos", SentimentClass.POSITIVE),
                  imdb("/users/gabor/tmp/aclImdb/test/neg", SentimentClass.NEGATIVE)
              )
          ),
          Stream.concat(
              Stream.concat(
                  stanford("/users/gabor/tmp/train.tsv"),
                  stanford("/users/gabor/tmp/test.tsv")
              ),
              Stream.concat(
                  twitter("/users/gabor/tmp/twitter.csv"),
                  unlabelled("/users/gabor/tmp/wikipedia")
              )
          )
      );

  // Train the model.
  // BUG FIX: the output stream used to be closed only on the success path and
  // leaked if train(...) threw; try-with-resources closes it either way.
  SimpleSentiment classifier;
  try (OutputStream stream = IOUtils.getFileOutputStream("/users/gabor/tmp/model.ser.gz")) {
    classifier = SimpleSentiment.train(data, Optional.of(stream));
  }
  log.info(classifier.classify("I think life is great"));
  endTrack("main");
  // 85.8
}
}