package edu.stanford.nlp.ie.crf; import edu.stanford.nlp.math.ArrayMath; import edu.stanford.nlp.sequences.SeqClassifierFlags; /** * @author Mengqiu Wang */ public class NonLinearSecondOrderCliquePotentialFunction implements CliquePotentialFunction { private final double[][] inputLayerWeights4Edge; // first index is number of hidden units in layer one, second index is the input feature indices private final double[][] outputLayerWeights4Edge; // first index is the output class, second index is the number of hidden units private final double[][] inputLayerWeights; // first index is number of hidden units in layer one, second index is the input feature indices private final double[][] outputLayerWeights; // first index is the output class, second index is the number of hidden units private double[] layerOneCache, hiddenLayerCache; private double[] layerOneCache4Edge, hiddenLayerCache4Edge; private final SeqClassifierFlags flags; public NonLinearSecondOrderCliquePotentialFunction(double[][] inputLayerWeights4Edge, double[][] outputLayerWeights4Edge, double[][] inputLayerWeights, double[][] outputLayerWeights, SeqClassifierFlags flags) { this.inputLayerWeights4Edge = inputLayerWeights4Edge; this.outputLayerWeights4Edge = outputLayerWeights4Edge; this.inputLayerWeights = inputLayerWeights; this.outputLayerWeights = outputLayerWeights; this.flags = flags; } public double[] hiddenLayerOutput(double[][] inputLayerWeights, int[] nodeCliqueFeatures, SeqClassifierFlags aFlag, double[] featureVal, int cliqueSize) { double[] layerCache = null; double[] hlCache = null; int layerOneSize = inputLayerWeights.length; if (cliqueSize > 1) { if (layerOneCache4Edge == null || layerOneSize != layerOneCache4Edge.length) layerOneCache4Edge = new double[layerOneSize]; layerCache = layerOneCache4Edge; } else { if (layerOneCache == null || layerOneSize != layerOneCache.length) layerOneCache = new double[layerOneSize]; layerCache = layerOneCache; } for (int i = 0; i < layerOneSize; i++) { double[] ws = inputLayerWeights[i]; double lOneW = 0; double dotProd = 0; for (int m = 0; m < nodeCliqueFeatures.length; m++) { dotProd = ws[nodeCliqueFeatures[m]]; if (featureVal != null) dotProd *= featureVal[m]; lOneW += dotProd; } layerCache[i] = lOneW; } if (!aFlag.useHiddenLayer) return layerCache; // transform layer one through hidden if (cliqueSize > 1) { if (hiddenLayerCache4Edge == null || layerOneSize != hiddenLayerCache4Edge.length) hiddenLayerCache4Edge = new double[layerOneSize]; hlCache = hiddenLayerCache4Edge; } else { if (hiddenLayerCache == null || layerOneSize != hiddenLayerCache.length) hiddenLayerCache = new double[layerOneSize]; hlCache = hiddenLayerCache; } for (int i = 0; i < layerOneSize; i++) { if (aFlag.useSigmoid) { hlCache[i] = sigmoid(layerCache[i]); } else { hlCache[i] = Math.tanh(layerCache[i]); } } return hlCache; } private static double sigmoid(double x) { return 1 / (1 + Math.exp(-x)); } @Override public double computeCliquePotential(int cliqueSize, int labelIndex, int[] cliqueFeatures, double[] featureVal, int posInSent) { double output = 0.0; double[][] inputWeights, outputWeights = null; if (cliqueSize > 1) { inputWeights = inputLayerWeights4Edge; outputWeights = outputLayerWeights4Edge; } else { inputWeights = inputLayerWeights; outputWeights = outputLayerWeights; } double[] hiddenLayer = hiddenLayerOutput(inputWeights, cliqueFeatures, flags, featureVal, cliqueSize); int outputLayerSize = inputWeights.length / outputWeights[0].length; // transform the hidden layer to output layer through linear transformation if (flags.useOutputLayer) { double[] outputWs = null; if (flags.tieOutputLayer) { outputWs = outputWeights[0]; } else { outputWs = outputWeights[labelIndex]; } if (flags.softmaxOutputLayer) { outputWs = ArrayMath.softmax(outputWs); } for (int i = 0; i < inputWeights.length; i++) { if (flags.sparseOutputLayer || flags.tieOutputLayer) { if (i % outputLayerSize == labelIndex) { output += outputWs[ i / outputLayerSize ] * hiddenLayer[i]; } } else { output += outputWs[i] * hiddenLayer[i]; } } } else { output = hiddenLayer[labelIndex]; } return output; } }