/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.classify; import java.util.ArrayList; import cc.mallet.classify.constraints.pr.MaxEntPRConstraint; import cc.mallet.pipe.Pipe; import cc.mallet.types.FeatureVector; import cc.mallet.types.Instance; import cc.mallet.types.LabelVector; import cc.mallet.types.MatrixOps; /** * Auxiliary model (q) for E-step/I-projection in PR training. * * @author Gregory Druck <a href="mailto:gdruck@cs.umass.edu">gdruck@cs.umass.edu</a> */ public class PRAuxClassifier extends Classifier { private static final long serialVersionUID = 1L; private int numLabels; private double[][] parameters; private ArrayList<MaxEntPRConstraint> constraints; public PRAuxClassifier(Pipe pipe, ArrayList<MaxEntPRConstraint> constraints) { super(pipe); this.constraints = constraints; this.parameters = new double[constraints.size()][]; for (int i = 0; i < constraints.size(); i++) { this.parameters[i] = new double[constraints.get(i).numDimensions()]; } this.numLabels = pipe.getTargetAlphabet().size(); } public void getClassificationScores(Instance instance, double[] scores) { FeatureVector input = (FeatureVector)instance.getData(); for (MaxEntPRConstraint feature : constraints) { feature.preProcess(input); } for (int li = 0; li < numLabels; li++) { int ci = 0; for (MaxEntPRConstraint feature : constraints) { scores[li] += feature.getScore(input, li, parameters[ci]); ci++; } } } public void getClassificationProbs(Instance instance, double[] scores) { getClassificationScores(instance,scores); MatrixOps.expNormalize(scores); } @Override public Classification classify(Instance instance) { double[] scores = new double[numLabels]; getClassificationScores(instance,scores); return new Classification (instance, this, new LabelVector (getLabelAlphabet(), scores)); } public double[][] getParameters() { return parameters; } public ArrayList<MaxEntPRConstraint> getConstraintFeatures() { return constraints; } public void zeroExpectations() { for (MaxEntPRConstraint constraint : constraints) { constraint.zeroExpectations(); } } }