package edu.stanford.nlp.sentiment;
import java.util.List;
import java.util.Map;
import org.ejml.simple.SimpleMatrix;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
// TODO: get rid of the word Sentiment everywhere
public class SentimentCostAndGradient extends AbstractCachingDiffFunction {
private static final Redwood.RedwoodChannels log = Redwood.channels(SentimentCostAndGradient.class);
private final SentimentModel model;
private final List<Tree> trainingBatch;
public SentimentCostAndGradient(SentimentModel model, List<Tree> trainingBatch) {
this.model = model;
this.trainingBatch = trainingBatch;
}
@Override
public int domainDimension() {
// TODO: cache this for speed?
return model.totalParamSize();
}
private static double sumError(Tree tree) {
if (tree.isLeaf()) {
return 0.0;
} else if (tree.isPreTerminal()) {
return RNNCoreAnnotations.getPredictionError(tree);
} else {
double error = 0.0;
for (Tree child : tree.children()) {
error += sumError(child);
}
return RNNCoreAnnotations.getPredictionError(tree) + error;
}
}
/**
* Returns the index with the highest value in the {@code predictions} matrix.
* Indexed from 0.
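* For example (illustrative values), a 3x1 predictions vector of
* {0.1, 0.7, 0.2} yields 1:
* <pre>{@code
* SimpleMatrix predictions = new SimpleMatrix(new double[][] {{0.1}, {0.7}, {0.2}});
* int argmax = getPredictedClass(predictions); // 1
* }</pre>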
*/
private static int getPredictedClass(SimpleMatrix predictions) {
int argmax = 0;
for (int i = 1; i < predictions.getNumElements(); ++i) {
if (predictions.get(i) > predictions.get(argmax)) {
argmax = i;
}
}
return argmax;
}
private static class ModelDerivatives {
// We use TreeMap for each of these so that they stay in a canonical sorted order
// binaryTD stands for Transform Derivatives (see the SentimentModel)
public final TwoDimensionalMap<String, String, SimpleMatrix> binaryTD;
// the derivatives of the tensors for the binary nodes
// will be empty if we aren't using tensors
public final TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD;
// binaryCD stands for Classification Derivatives
// if classification is combined (model.op.combineClassification),
// this map is left empty and unaryCD handles all nodes
public final TwoDimensionalMap<String, String, SimpleMatrix> binaryCD;
// unaryCD stands for unary Classification Derivatives
public final Map<String, SimpleMatrix> unaryCD;
// word vector derivatives
// will be filled on an as-needed basis, as opposed to having all
// the words with a lot of empty vectors
public final Map<String, SimpleMatrix> wordVectorD;
public double error = 0.0;
public ModelDerivatives(SentimentModel model) {
binaryTD = initDerivatives(model.binaryTransform);
binaryTensorTD = (model.op.useTensors) ? initTensorDerivatives(model.binaryTensors) : TwoDimensionalMap.treeMap();
binaryCD = (!model.op.combineClassification) ? initDerivatives(model.binaryClassification) : TwoDimensionalMap.treeMap();
unaryCD = initDerivatives(model.unaryClassification);
// wordVectorD will be filled on an as-needed basis
wordVectorD = Generics.newTreeMap();
}
public void add(ModelDerivatives other) {
addMatrices(binaryTD, other.binaryTD);
addTensors(binaryTensorTD, other.binaryTensorTD);
addMatrices(binaryCD, other.binaryCD);
addMatrices(unaryCD, other.unaryCD);
addMatrices(wordVectorD, other.wordVectorD);
error += other.error;
}
/**
* Add matrices from the second map to the first map, in place.
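* Entries present in both maps are summed into {@code first}; entries
* present only in {@code second} are copied over.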
*/
public static void addMatrices(TwoDimensionalMap<String, String, SimpleMatrix> first,
TwoDimensionalMap<String, String, SimpleMatrix> second) {
for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : first) {
if (second.contains(entry.getFirstKey(), entry.getSecondKey())) {
first.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue().plus(second.get(entry.getFirstKey(), entry.getSecondKey())));
}
}
for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : second) {
if (!first.contains(entry.getFirstKey(), entry.getSecondKey())) {
first.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue());
}
}
}
/**
* Add tensors from the second map to the first map, in place.
*/
public static void addTensors(TwoDimensionalMap<String, String, SimpleTensor> first,
TwoDimensionalMap<String, String, SimpleTensor> second) {
for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : first) {
if (second.contains(entry.getFirstKey(), entry.getSecondKey())) {
first.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue().plus(second.get(entry.getFirstKey(), entry.getSecondKey())));
}
}
for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : second) {
if (!first.contains(entry.getFirstKey(), entry.getSecondKey())) {
first.put(entry.getFirstKey(), entry.getSecondKey(), entry.getValue());
}
}
}
/**
* Add matrices from the second map to the first map, in place.
*/
public static void addMatrices(Map<String, SimpleMatrix> first,
Map<String, SimpleMatrix> second) {
for (Map.Entry<String, SimpleMatrix> entry : first.entrySet()) {
if (second.containsKey(entry.getKey())) {
first.put(entry.getKey(), entry.getValue().plus(second.get(entry.getKey())));
}
}
for (Map.Entry<String, SimpleMatrix> entry : second.entrySet()) {
if (!first.containsKey(entry.getKey())) {
first.put(entry.getKey(), entry.getValue());
}
}
}
/**
* Initialize a TwoDimensionalMap with zero matrices matching the shapes of the matrices in the original map.
*/
private static TwoDimensionalMap<String, String, SimpleMatrix> initDerivatives(TwoDimensionalMap<String, String, SimpleMatrix> map) {
TwoDimensionalMap<String, String, SimpleMatrix> derivatives = TwoDimensionalMap.treeMap();
for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : map) {
int numRows = entry.getValue().numRows();
int numCols = entry.getValue().numCols();
derivatives.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols));
}
return derivatives;
}
/**
* Initialize a TwoDimensionalMap with zero tensors matching the shapes of the tensors in the original map.
*/
private static TwoDimensionalMap<String, String, SimpleTensor> initTensorDerivatives(TwoDimensionalMap<String, String, SimpleTensor> map) {
TwoDimensionalMap<String, String, SimpleTensor> derivatives = TwoDimensionalMap.treeMap();
for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : map) {
int numRows = entry.getValue().numRows();
int numCols = entry.getValue().numCols();
int numSlices = entry.getValue().numSlices();
derivatives.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleTensor(numRows, numCols, numSlices));
}
return derivatives;
}
/**
* Initialize a Map with zero matrices matching the shapes of the matrices in the original map.
*/
private static Map<String, SimpleMatrix> initDerivatives(Map<String, SimpleMatrix> map) {
Map<String, SimpleMatrix> derivatives = Generics.newTreeMap();
for (Map.Entry<String, SimpleMatrix> entry : map.entrySet()) {
int numRows = entry.getValue().numRows();
int numCols = entry.getValue().numCols();
derivatives.put(entry.getKey(), new SimpleMatrix(numRows, numCols));
}
return derivatives;
}
}
private ModelDerivatives scoreDerivatives(List<Tree> trainingBatch) {
// "final" makes this as fast as having separate maps declared in this function
final ModelDerivatives derivatives = new ModelDerivatives(model);
List<Tree> forwardPropTrees = Generics.newArrayList();
for (Tree tree : trainingBatch) {
Tree trainingTree = tree.deepCopy();
// this will attach the error vectors and the node vectors
// to each node in the tree
forwardPropagateTree(trainingTree);
forwardPropTrees.add(trainingTree);
}
for (Tree tree : forwardPropTrees) {
backpropDerivativesAndError(tree, derivatives.binaryTD, derivatives.binaryCD, derivatives.binaryTensorTD, derivatives.unaryCD, derivatives.wordVectorD);
derivatives.error += sumError(tree);
}
return derivatives;
}
class ScoringProcessor implements ThreadsafeProcessor<List<Tree>, ModelDerivatives> {
@Override
public ModelDerivatives process(List<Tree> trainingBatch) {
return scoreDerivatives(trainingBatch);
}
@Override
public ThreadsafeProcessor<List<Tree>, ModelDerivatives> newInstance() {
// should be threadsafe
return this;
}
}
@Override
public void calculate(double[] theta) {
model.vectorToParams(theta);
final ModelDerivatives derivatives;
if (model.op.trainOptions.nThreads == 1) {
derivatives = scoreDerivatives(trainingBatch);
} else {
// TODO: because some addition operations happen in different
// orders now, this results in slightly different values, which
// over time add up to significantly different models even when
// given the same random seed. Probably not a big deal.
// To be more specific, for trees T1, T2, T3, ... Tn,
// when using one thread, we sum the derivatives T1 + T2 ...
// When using multiple threads, we first sum T1 + ... + Tk,
// then sum Tk+1 + ... + T2k, etc, for split size k.
// The splits are then summed in order.
// This different sum order results in slightly different numbers.
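// (Floating-point addition is not associative, so (a + b) + c can
// differ from a + (b + c) in the low-order bits.)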
MulticoreWrapper<List<Tree>, ModelDerivatives> wrapper =
new MulticoreWrapper<>(model.op.trainOptions.nThreads, new ScoringProcessor());
// use wrapper.nThreads in case the number of threads was automatically changed
for (List<Tree> chunk : CollectionUtils.partitionIntoFolds(trainingBatch, wrapper.nThreads())) {
wrapper.put(chunk);
}
wrapper.join();
derivatives = new ModelDerivatives(model);
while (wrapper.peek()) {
ModelDerivatives batchDerivatives = wrapper.poll();
derivatives.add(batchDerivatives);
}
}
// scale the error by the number of sentences so that the
// regularization isn't drowned out for large training batches
double scale = (1.0 / trainingBatch.size());
value = derivatives.error * scale;
value += scaleAndRegularize(derivatives.binaryTD, model.binaryTransform, scale, model.op.trainOptions.regTransformMatrix, false);
value += scaleAndRegularize(derivatives.binaryCD, model.binaryClassification, scale, model.op.trainOptions.regClassification, true);
value += scaleAndRegularizeTensor(derivatives.binaryTensorTD, model.binaryTensors, scale, model.op.trainOptions.regTransformTensor);
value += scaleAndRegularize(derivatives.unaryCD, model.unaryClassification, scale, model.op.trainOptions.regClassification, false, true);
value += scaleAndRegularize(derivatives.wordVectorD, model.wordVectors, scale, model.op.trainOptions.regWordVector, true, false);
derivative = NeuralUtils.paramsToVector(theta.length, derivatives.binaryTD.valueIterator(), derivatives.binaryCD.valueIterator(), SimpleTensor.iteratorSimpleMatrix(derivatives.binaryTensorTD.valueIterator()), derivatives.unaryCD.values().iterator(), derivatives.wordVectorD.values().iterator());
}
private static double scaleAndRegularize(TwoDimensionalMap<String, String, SimpleMatrix> derivatives,
TwoDimensionalMap<String, String, SimpleMatrix> currentMatrices,
double scale, double regCost, boolean dropBiasColumn) {
double cost = 0.0; // the regularization cost
for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : currentMatrices) {
SimpleMatrix D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
SimpleMatrix regMatrix = entry.getValue();
if (dropBiasColumn) {
regMatrix = new SimpleMatrix(regMatrix);
regMatrix.insertIntoThis(0, regMatrix.numCols() - 1, new SimpleMatrix(regMatrix.numRows(), 1));
}
D = D.scale(scale).plus(regMatrix.scale(regCost));
derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D);
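// L2 regularization cost: (regCost / 2) * ||regMatrix||_F^2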
cost += regMatrix.elementMult(regMatrix).elementSum() * regCost / 2.0;
}
return cost;
}
private static double scaleAndRegularize(Map<String, SimpleMatrix> derivatives,
Map<String, SimpleMatrix> currentMatrices,
double scale, double regCost,
boolean activeMatricesOnly, boolean dropBiasColumn) {
double cost = 0.0; // the regularization cost
for (Map.Entry<String, SimpleMatrix> entry : currentMatrices.entrySet()) {
SimpleMatrix D = derivatives.get(entry.getKey());
if (activeMatricesOnly && D == null) {
// Fill in an empty matrix so the length of theta can match.
// TODO: might want to allow for sparse parameter vectors
derivatives.put(entry.getKey(), new SimpleMatrix(entry.getValue().numRows(), entry.getValue().numCols()));
continue;
}
SimpleMatrix regMatrix = entry.getValue();
if (dropBiasColumn) {
regMatrix = new SimpleMatrix(regMatrix);
regMatrix.insertIntoThis(0, regMatrix.numCols() - 1, new SimpleMatrix(regMatrix.numRows(), 1));
}
D = D.scale(scale).plus(regMatrix.scale(regCost));
derivatives.put(entry.getKey(), D);
cost += regMatrix.elementMult(regMatrix).elementSum() * regCost / 2.0;
}
return cost;
}
private static double scaleAndRegularizeTensor(TwoDimensionalMap<String, String, SimpleTensor> derivatives,
TwoDimensionalMap<String, String, SimpleTensor> currentMatrices,
double scale,
double regCost) {
double cost = 0.0; // the regularization cost
for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : currentMatrices) {
SimpleTensor D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
D = D.scale(scale).plus(entry.getValue().scale(regCost));
derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D);
cost += entry.getValue().elementMult(entry.getValue()).elementSum() * regCost / 2.0;
}
return cost;
}
private void backpropDerivativesAndError(Tree tree,
TwoDimensionalMap<String, String, SimpleMatrix> binaryTD,
TwoDimensionalMap<String, String, SimpleMatrix> binaryCD,
TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD,
Map<String, SimpleMatrix> unaryCD,
Map<String, SimpleMatrix> wordVectorD) {
SimpleMatrix delta = new SimpleMatrix(model.op.numHid, 1);
backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, delta);
}
private void backpropDerivativesAndError(Tree tree,
TwoDimensionalMap<String, String, SimpleMatrix> binaryTD,
TwoDimensionalMap<String, String, SimpleMatrix> binaryCD,
TwoDimensionalMap<String, String, SimpleTensor> binaryTensorTD,
Map<String, SimpleMatrix> unaryCD,
Map<String, SimpleMatrix> wordVectorD,
SimpleMatrix deltaUp) {
if (tree.isLeaf()) {
return;
}
SimpleMatrix currentVector = RNNCoreAnnotations.getNodeVector(tree);
String category = tree.label().value();
category = model.basicCategory(category);
// Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class
SimpleMatrix goldLabel = new SimpleMatrix(model.numClasses, 1);
int goldClass = RNNCoreAnnotations.getGoldClass(tree);
if (goldClass >= 0) {
goldLabel.set(goldClass, 1.0);
}
double nodeWeight = model.op.trainOptions.getClassWeight(goldClass);
SimpleMatrix predictions = RNNCoreAnnotations.getPredictions(tree);
// If this node's class is unlabeled, set deltaClass to 0. We could
// make this more efficient by skipping some of the calculations
// below, but this is the easiest way to handle the unlabeled class
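// For softmax with cross-entropy loss, the gradient with respect to
// the pre-softmax scores is (predictions - goldLabel), scaled here by
// the per-class weight.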
SimpleMatrix deltaClass = goldClass >= 0 ? predictions.minus(goldLabel).scale(nodeWeight) : new SimpleMatrix(predictions.numRows(), predictions.numCols());
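// localCD is the gradient for the classification matrix: the outer
// product of deltaClass with the bias-augmented node vector [v; 1]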
SimpleMatrix localCD = deltaClass.mult(NeuralUtils.concatenateWithBias(currentVector).transpose());
double error = -(NeuralUtils.elementwiseApplyLog(predictions).elementMult(goldLabel).elementSum());
error = error * nodeWeight;
RNNCoreAnnotations.setPredictionError(tree, error);
if (tree.isPreTerminal()) { // below us is a word vector
unaryCD.put(category, unaryCD.get(category).plus(localCD));
String word = tree.children()[0].label().value();
word = model.getVocabWord(word);
// Both terms below include the tanh derivative at this node: deltaUp
// already had it applied by the parent before being passed down, and
// deltaFromClass applies it here before the two are summed.
SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
SimpleMatrix deltaFromClass = model.getUnaryClassification(category).transpose().mult(deltaClass);
deltaFromClass = deltaFromClass.extractMatrix(0, model.op.numHid, 0, 1).elementMult(currentVectorDerivative);
SimpleMatrix deltaFull = deltaFromClass.plus(deltaUp);
SimpleMatrix oldWordVectorD = wordVectorD.get(word);
if (oldWordVectorD == null) {
wordVectorD.put(word, deltaFull);
} else {
wordVectorD.put(word, oldWordVectorD.plus(deltaFull));
}
} else {
// Otherwise, this must be a binary node
String leftCategory = model.basicCategory(tree.children()[0].label().value());
String rightCategory = model.basicCategory(tree.children()[1].label().value());
if (model.op.combineClassification) {
unaryCD.put("", unaryCD.get("").plus(localCD));
} else {
binaryCD.put(leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).plus(localCD));
}
SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
SimpleMatrix deltaFromClass = model.getBinaryClassification(leftCategory, rightCategory).transpose().mult(deltaClass);
deltaFromClass = deltaFromClass.extractMatrix(0, model.op.numHid, 0, 1).elementMult(currentVectorDerivative);
SimpleMatrix deltaFull = deltaFromClass.plus(deltaUp);
SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(tree.children()[0]);
SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(tree.children()[1]);
SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
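// The gradient for the transform matrix W is the outer product of the
// incoming delta with the bias-augmented children vector [left; right; 1]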
SimpleMatrix W_df = deltaFull.mult(childrenVector.transpose());
binaryTD.put(leftCategory, rightCategory, binaryTD.get(leftCategory, rightCategory).plus(W_df));
SimpleMatrix deltaDown;
if (model.op.useTensors) {
SimpleTensor Wt_df = getTensorGradient(deltaFull, leftVector, rightVector);
binaryTensorTD.put(leftCategory, rightCategory, binaryTensorTD.get(leftCategory, rightCategory).plus(Wt_df));
deltaDown = computeTensorDeltaDown(deltaFull, leftVector, rightVector, model.getBinaryTransform(leftCategory, rightCategory), model.getBinaryTensor(leftCategory, rightCategory));
} else {
deltaDown = model.getBinaryTransform(leftCategory, rightCategory).transpose().mult(deltaFull);
}
SimpleMatrix leftDerivative = NeuralUtils.elementwiseApplyTanhDerivative(leftVector);
SimpleMatrix rightDerivative = NeuralUtils.elementwiseApplyTanhDerivative(rightVector);
SimpleMatrix leftDeltaDown = deltaDown.extractMatrix(0, deltaFull.numRows(), 0, 1);
SimpleMatrix rightDeltaDown = deltaDown.extractMatrix(deltaFull.numRows(), deltaFull.numRows() * 2, 0, 1);
backpropDerivativesAndError(tree.children()[0], binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, leftDerivative.elementMult(leftDeltaDown));
backpropDerivativesAndError(tree.children()[1], binaryTD, binaryCD, binaryTensorTD, unaryCD, wordVectorD, rightDerivative.elementMult(rightDeltaDown));
}
}
private static SimpleMatrix computeTensorDeltaDown(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector,
SimpleMatrix W, SimpleTensor Wt) {
SimpleMatrix WTDelta = W.transpose().mult(deltaFull);
SimpleMatrix WTDeltaNoBias = WTDelta.extractMatrix(0, deltaFull.numRows() * 2, 0, 1);
int size = deltaFull.getNumElements();
SimpleMatrix deltaTensor = new SimpleMatrix(size*2, 1);
SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
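// Each tensor slice V[s] contributes delta[s] * (V[s] + V[s]') * f to
// the delta passed down to the children, where f = [left; right]; the
// transform matrix contributes W' * delta with its bias row dropped.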
for (int slice = 0; slice < size; ++slice) {
SimpleMatrix scaledFullVector = fullVector.scale(deltaFull.get(slice));
deltaTensor = deltaTensor.plus(Wt.getSlice(slice).plus(Wt.getSlice(slice).transpose()).mult(scaledFullVector));
}
return deltaTensor.plus(WTDeltaNoBias);
}
private static SimpleTensor getTensorGradient(SimpleMatrix deltaFull, SimpleMatrix leftVector, SimpleMatrix rightVector) {
int size = deltaFull.getNumElements();
SimpleTensor Wt_df = new SimpleTensor(size*2, size*2, size);
// TODO: combine this concatenation with computeTensorDeltaDown?
SimpleMatrix fullVector = NeuralUtils.concatenate(leftVector, rightVector);
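// The gradient for tensor slice s is the outer product f * f' scaled
// by the corresponding element of the incoming delta:
// dE/dV[s] = delta[s] * f * f', where f = [left; right]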
for (int slice = 0; slice < size; ++slice) {
Wt_df.setSlice(slice, fullVector.scale(deltaFull.get(slice)).mult(fullVector.transpose()));
}
return Wt_df;
}
/**
* This is the method to call for assigning labels and node vectors
* to the Tree. After calling this, each of the non-leaf nodes will
* have the node vector and the predictions of their classes
* assigned to that subtree's node. The annotations filled in are
* the RNNCoreAnnotations.NodeVector, Predictions, and
* PredictedClass. In general, PredictedClass will be the most
* useful annotation except when training.
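*
* A sketch of typical use on a hypothetical instance
* {@code costAndGradient}, assuming a binarized tree whose nodes are
* CoreLabels (as this class expects):
* <pre>{@code
* Tree copy = tree.deepCopy();
* costAndGradient.forwardPropagateTree(copy);
* SimpleMatrix predictions = RNNCoreAnnotations.getPredictions(copy);
* int predictedClass = RNNCoreAnnotations.getPredictedClass(copy);
* }</pre>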
*/
public void forwardPropagateTree(Tree tree) {
SimpleMatrix nodeVector; // initialized below or an exception is thrown
SimpleMatrix classification; // initialized below or an exception is thrown
if (tree.isLeaf()) {
// We do nothing for the leaves. The preterminals will
// calculate the classification for this word/tag. In fact, the
// recursion should not have gotten here (unless there are
// degenerate trees of just one leaf)
log.info("SentimentCostAndGradient: warning: We reached leaves in forwardPropagate: " + tree);
throw new AssertionError("We should not have reached leaves in forwardPropagate");
} else if (tree.isPreTerminal()) {
classification = model.getUnaryClassification(tree.label().value());
String word = tree.children()[0].label().value();
SimpleMatrix wordVector = model.getWordVector(word);
nodeVector = NeuralUtils.elementwiseApplyTanh(wordVector);
} else if (tree.children().length == 1) {
log.info("SentimentCostAndGradient: warning: Non-preterminal nodes of size 1: " + tree);
throw new AssertionError("Non-preterminal nodes of size 1 should have already been collapsed");
} else if (tree.children().length == 2) {
forwardPropagateTree(tree.children()[0]);
forwardPropagateTree(tree.children()[1]);
String leftCategory = tree.children()[0].label().value();
String rightCategory = tree.children()[1].label().value();
SimpleMatrix W = model.getBinaryTransform(leftCategory, rightCategory);
classification = model.getBinaryClassification(leftCategory, rightCategory);
SimpleMatrix leftVector = RNNCoreAnnotations.getNodeVector(tree.children()[0]);
SimpleMatrix rightVector = RNNCoreAnnotations.getNodeVector(tree.children()[1]);
SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
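// Compose the two children: nodeVector = tanh(W * [left; right; 1]),
// with an additional bilinear term f' * V[s] * f per slice when
// tensors are enabled, where f = [left; right]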
if (model.op.useTensors) {
SimpleTensor tensor = model.getBinaryTensor(leftCategory, rightCategory);
SimpleMatrix tensorIn = NeuralUtils.concatenate(leftVector, rightVector);
SimpleMatrix tensorOut = tensor.bilinearProducts(tensorIn);
nodeVector = NeuralUtils.elementwiseApplyTanh(W.mult(childrenVector).plus(tensorOut));
} else {
nodeVector = NeuralUtils.elementwiseApplyTanh(W.mult(childrenVector));
}
} else {
log.info("SentimentCostAndGradient: warning: Tree not correctly binarized: " + tree);
throw new AssertionError("Tree not correctly binarized");
}
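// The class scores are the classification matrix applied to the
// bias-augmented node vector [v; 1], pushed through a softmax.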
SimpleMatrix predictions = NeuralUtils.softmax(classification.mult(NeuralUtils.concatenateWithBias(nodeVector)));
int index = getPredictedClass(predictions);
if (!(tree.label() instanceof CoreLabel)) {
log.info("SentimentCostAndGradient: warning: No CoreLabels in nodes: " + tree);
throw new AssertionError("Expected CoreLabels in the nodes");
}
CoreLabel label = (CoreLabel) tree.label();
label.set(RNNCoreAnnotations.Predictions.class, predictions);
label.set(RNNCoreAnnotations.PredictedClass.class, index);
label.set(RNNCoreAnnotations.NodeVector.class, nodeVector);
} // end forwardPropagateTree
}