package edu.stanford.nlp.neural.rnn; import org.ejml.simple.SimpleMatrix; import edu.stanford.nlp.ling.CoreAnnotation; import edu.stanford.nlp.ling.CoreLabel; import edu.stanford.nlp.ling.Label; import edu.stanford.nlp.trees.Tree; /** Annotations used by Tree Recursive Neural Networks. * * @author John Bauer */ public class RNNCoreAnnotations { private RNNCoreAnnotations() {} // only static members /** * Used to denote the vector (distributed representation) at a particular node. * This stores a real vector that represents the semantics of a word or phrase. */ public static class NodeVector implements CoreAnnotation<SimpleMatrix> { @Override public Class<SimpleMatrix> getType() { return SimpleMatrix.class; } } /** * Get the vector (distributed representation) at a particular node. * * @param tree The tree node * @return The vector (distributed representation) of the given tree */ public static SimpleMatrix getNodeVector(Tree tree) { Label label = tree.label(); if (!(label instanceof CoreLabel)) { throw new IllegalArgumentException("CoreLabels required to get the attached node vector"); } return ((CoreLabel) label).get(NodeVector.class); } /** * Used to denote a vector of predictions at a particular node. * This is a vector of real values, typically the output of a softmax classification layer, * which gives the probabilities of each output value. */ public static class Predictions implements CoreAnnotation<SimpleMatrix> { @Override public Class<SimpleMatrix> getType() { return SimpleMatrix.class; } } public static SimpleMatrix getPredictions(Tree tree) { Label label = tree.label(); if (!(label instanceof CoreLabel)) { throw new IllegalArgumentException("CoreLabels required to get the attached predictions"); } return ((CoreLabel) label).get(Predictions.class); } /** * Get the argmax of the class predicteions. * The predicted classes can be an arbitrary set of non-negative integer classes, * but in our current sentiment models, the values used are on a 5-point * scale of 0 = very negative, 1 = negative, 2 = neutral, 3 = positive, * and 4 = very positive. */ public static class PredictedClass implements CoreAnnotation<Integer> { @Override public Class<Integer> getType() { return Integer.class; } } /** Return as an int the predicted class. If it is not defined for a node, * it will return -1 * * @return Either the sentiment level or -1 if none */ public static int getPredictedClass(Tree tree) { Label label = tree.label(); if (!(label instanceof CoreLabel)) { throw new IllegalArgumentException("CoreLabels required to get the attached predicted class"); } Integer val = ((CoreLabel) label).get(PredictedClass.class); return val == null ? -1: val; } /** * The index of the correct class. */ public static class GoldClass implements CoreAnnotation<Integer> { @Override public Class<Integer> getType() { return Integer.class; } } public static int getGoldClass(Tree tree) { Label label = tree.label(); if (!(label instanceof CoreLabel)) { throw new IllegalArgumentException("CoreLabels required to get the attached gold class"); } return ((CoreLabel) label).get(GoldClass.class); } public static void setGoldClass(Tree tree, int goldClass) { Label label = tree.label(); if (!(label instanceof CoreLabel)) { throw new IllegalArgumentException("CoreLabels required to set the attached gold class"); } ((CoreLabel) label).set(GoldClass.class, goldClass); } public static class PredictionError implements CoreAnnotation<Double> { @Override public Class<Double> getType() { return Double.class; } } public static double getPredictionError(Tree tree) { Label label = tree.label(); if (!(label instanceof CoreLabel)) { throw new IllegalArgumentException("CoreLabels required to get the attached prediction error"); } return ((CoreLabel) label).get(PredictionError.class); } public static void setPredictionError(Tree tree, double error) { Label label = tree.label(); if (!(label instanceof CoreLabel)) { throw new IllegalArgumentException("CoreLabels required to set the attached prediction error"); } ((CoreLabel) label).set(PredictionError.class, error); } }