package edu.stanford.nlp.coref.neural;
import java.io.Serializable;
import java.util.List;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import org.ejml.simple.SimpleMatrix;
/**
 * Stores the weights and implements the matrix operations used by a {@link NeuralCorefAlgorithm}.
 *
 * <p>The instance holds three projection matrices, a first-layer bias, two lists of
 * layer parameters (one for anaphoricity scoring, one for pairwise scoring) and the word
 * embeddings. All fields are assigned once in the constructor and never reassigned; the
 * supplied objects are stored by reference, not copied.
 *
 * @author Kevin Clark
 */
public class NeuralCorefModel implements Serializable {
  private static final long serialVersionUID = 2139427931784505653L;

  // Field names and types are part of the serialized form; keep them stable.
  private final SimpleMatrix antecedentMatrix;
  private final SimpleMatrix anaphorMatrix;
  private final SimpleMatrix pairFeaturesMatrix;
  private final SimpleMatrix pairwiseFirstLayerBias;
  private final List<SimpleMatrix> anaphoricityModel;
  private final List<SimpleMatrix> pairwiseModel;
  private final Embedding wordEmbeddings;

  /**
   * Creates a model from already-loaded parameters.
   *
   * @param antecedentMatrix matrix multiplied with a mention embedding in
   *     {@link #getAntecedentEmbedding(SimpleMatrix)}
   * @param anaphorMatrix matrix multiplied with a mention embedding in
   *     {@link #getAnaphorEmbedding(SimpleMatrix)}
   * @param pairFeaturesMatrix matrix multiplied with the pair features in
   *     {@link #getPairwiseScore(SimpleMatrix, SimpleMatrix, SimpleMatrix)}
   * @param pairwiseFirstLayerBias bias added before the first activation of the pairwise scorer
   * @param anaphoricityModel alternating weight/bias matrices used for anaphoricity scoring
   * @param pairwiseModel alternating weight/bias matrices used for pairwise scoring
   * @param wordEmbeddings embeddings returned by {@link #getWordEmbeddings()}
   */
  public NeuralCorefModel(SimpleMatrix antecedentMatrix, SimpleMatrix anaphorMatrix,
      SimpleMatrix pairFeaturesMatrix, SimpleMatrix pairwiseFirstLayerBias,
      List<SimpleMatrix> anaphoricityModel, List<SimpleMatrix> pairwiseModel,
      Embedding wordEmbeddings) {
    this.wordEmbeddings = wordEmbeddings;
    this.pairwiseModel = pairwiseModel;
    this.anaphoricityModel = anaphoricityModel;
    this.pairwiseFirstLayerBias = pairwiseFirstLayerBias;
    this.pairFeaturesMatrix = pairFeaturesMatrix;
    this.anaphorMatrix = anaphorMatrix;
    this.antecedentMatrix = antecedentMatrix;
  }

  /**
   * Scores a single mention with the anaphoricity layers.
   *
   * @param mentionEmbedding embedding of the mention
   * @param anaphoricityFeatures additional features, combined with the embedding through
   *     {@link NeuralUtils#concatenate}
   * @return the scalar produced by running the combined input through the anaphoricity layers
   */
  public double getAnaphoricityScore(SimpleMatrix mentionEmbedding,
      SimpleMatrix anaphoricityFeatures) {
    SimpleMatrix networkInput = NeuralUtils.concatenate(mentionEmbedding, anaphoricityFeatures);
    return score(networkInput, anaphoricityModel);
  }

  /**
   * Scores an (antecedent, anaphor) pair with the pairwise layers.
   *
   * <p>The first layer sums the two embeddings, the projected pair features and the
   * first-layer bias, then applies {@link NeuralUtils#elementwiseApplyReLU}. The embeddings are
   * added as given; presumably they come from {@link #getAntecedentEmbedding(SimpleMatrix)} and
   * {@link #getAnaphorEmbedding(SimpleMatrix)} (confirm against the caller).
   *
   * @param antecedentEmbedding embedding of the candidate antecedent
   * @param anaphorEmbedding embedding of the anaphor
   * @param pairFeatures features of the pair, multiplied by the pair-features matrix
   * @return the scalar produced by running the first-layer output through the pairwise layers
   */
  public double getPairwiseScore(SimpleMatrix antecedentEmbedding, SimpleMatrix anaphorEmbedding,
      SimpleMatrix pairFeatures) {
    // Keep the summation order (antecedent + anaphor) + features + bias unchanged.
    SimpleMatrix mentionSum = antecedentEmbedding.plus(anaphorEmbedding);
    SimpleMatrix projectedFeatures = pairFeaturesMatrix.mult(pairFeatures);
    SimpleMatrix preActivation = mentionSum.plus(projectedFeatures).plus(pairwiseFirstLayerBias);
    SimpleMatrix hidden = NeuralUtils.elementwiseApplyReLU(preActivation);
    return score(hidden, pairwiseModel);
  }

  /**
   * Runs {@code features} through a stack of affine layers and reduces the result to a scalar.
   *
   * <p>{@code weights} is read two entries at a time: the entry at an even index is the weight
   * matrix and the entry right after it is the matching bias. After each affine step the
   * activation is applied only if the weight matrix has more than one row (a single-row weight
   * is presumably the final output layer).
   *
   * @param features input column to the first layer
   * @param weights alternating weight and bias matrices
   * @return the sum of all elements of the last layer's output
   */
  private static double score(SimpleMatrix features, List<SimpleMatrix> weights) {
    SimpleMatrix activation = features;
    for (int weightIdx = 0, biasIdx = 1; weightIdx < weights.size();
        weightIdx += 2, biasIdx += 2) {
      SimpleMatrix layerWeight = weights.get(weightIdx);
      SimpleMatrix product = layerWeight.mult(activation);
      activation = product.plus(weights.get(biasIdx));
      boolean hasMultipleOutputs = layerWeight.numRows() > 1;
      if (hasMultipleOutputs) {
        activation = NeuralUtils.elementwiseApplyReLU(activation);
      }
    }
    return activation.elementSum();
  }

  /**
   * Projects a mention embedding with the anaphor matrix.
   *
   * @param mentionEmbedding embedding of the mention
   * @return the anaphor matrix multiplied by {@code mentionEmbedding}
   */
  public SimpleMatrix getAnaphorEmbedding(SimpleMatrix mentionEmbedding) {
    SimpleMatrix projected = anaphorMatrix.mult(mentionEmbedding);
    return projected;
  }

  /**
   * Projects a mention embedding with the antecedent matrix.
   *
   * @param mentionEmbedding embedding of the mention
   * @return the antecedent matrix multiplied by {@code mentionEmbedding}
   */
  public SimpleMatrix getAntecedentEmbedding(SimpleMatrix mentionEmbedding) {
    SimpleMatrix projected = antecedentMatrix.mult(mentionEmbedding);
    return projected;
  }

  /**
   * Returns the word embeddings supplied at construction time.
   *
   * @return the stored {@link Embedding} instance (not a copy)
   */
  public Embedding getWordEmbeddings() {
    return this.wordEmbeddings;
  }
}