/* Copyright (C) 2010 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.fst.semi_supervised; import java.util.ArrayList; import java.util.BitSet; import java.util.concurrent.Callable; import java.util.concurrent.Executors; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import cc.mallet.fst.CRF; import cc.mallet.fst.SumLattice; import cc.mallet.fst.SumLatticeDefault; import cc.mallet.fst.semi_supervised.constraints.GEConstraint; import cc.mallet.optimize.Optimizable; import cc.mallet.types.FeatureVectorSequence; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; import cc.mallet.types.MatrixOps; /** * Optimizable for CRF using Generalized Expectation constraints that * consider either a single label or a pair of labels of a linear chain CRF. * * See: * "Generalized Expectation Criteria for Semi-Supervised Learning of Conditional Random Fields" * Gideon Mann and Andrew McCallum * ACL 2008 * * @author Gregory Druck */ public class CRFOptimizableByGE implements Optimizable.ByGradientValue { private static final int DEFAULT_GPV = 10; private CRF crf; private ArrayList<GEConstraint> constraints; private InstanceList data; private int numThreads; private double gpv; private double weight; // indicator that keeps track of whether // the gradient / value need to be re-computed private int cache; private double cachedValue; private CRF.Factors cachedGradient; // lists of source states / transition indices // for each destination state // used in GELattice private int[][] reverseTrans; private int[][] reverseTransIndices; // instances in which at least one // constraint fires private BitSet instancesWithConstraints; private ThreadPoolExecutor executor; /** * @param crf CRF * @param constraints List of GEConstraints * @param data Unlabeled data. * @param map Map between states and labels. * @param numThreads Number of threads to use for training (DEFAULT=1) */ public CRFOptimizableByGE(CRF crf, ArrayList<GEConstraint> constraints, InstanceList data, StateLabelMap map, int numThreads) { this(crf,constraints,data,map,numThreads,1); } public CRFOptimizableByGE(CRF crf, ArrayList<GEConstraint> constraints, InstanceList data, StateLabelMap map, int numThreads, double weight) { this.crf = crf; this.constraints = constraints; this.cache = Integer.MAX_VALUE; this.cachedValue = Double.NaN; this.cachedGradient = new CRF.Factors(crf); this.data = data; this.numThreads = numThreads; this.weight = weight; instancesWithConstraints = new BitSet(data.size()); for (GEConstraint constraint : constraints) { constraint.setStateLabelMap(map); BitSet bitset = constraint.preProcess(data); instancesWithConstraints.or(bitset); } this.gpv = DEFAULT_GPV; if (numThreads > 1) { this.executor = (ThreadPoolExecutor)Executors.newFixedThreadPool(numThreads); } createReverseTransitionMatrices(crf); } /** * Initializes data structures for mapping between a * destination state and its source states / transition indices. * * @param crf CRF */ public void createReverseTransitionMatrices(CRF crf) { int[] counts = new int[crf.numStates()]; for (int si = 0; si < crf.numStates(); si++) { CRF.State prevState = (CRF.State)crf.getState(si); for (int di = 0; di < prevState.numDestinations(); di++) { int sj = prevState.getDestinationState(di).getIndex(); counts[sj]++; } } this.reverseTrans = new int[crf.numStates()][]; this.reverseTransIndices = new int[crf.numStates()][]; for (int i = 0; i < counts.length; i++) { this.reverseTrans[i] = new int[counts[i]]; this.reverseTransIndices[i] = new int[counts[i]]; } int[] indices = new int[crf.numStates()]; for (int si = 0; si < crf.numStates(); si++) { CRF.State prevState = (CRF.State)crf.getState(si); for (int di = 0; di < prevState.numDestinations(); di++) { int sj = prevState.getDestinationState(di).getIndex(); this.reverseTrans[sj][indices[sj]] = si; this.reverseTransIndices[sj][indices[sj]] = di; indices[sj]++; } } } public int getNumParameters() { return crf.getNumParameters(); } public void getParameters(double[] buffer) { crf.getParameters().getParameters(buffer); } public double getParameter(int index) { return crf.getParameters().getParameter(index); } public void setParameters(double[] params) { crf.getParameters().setParameters(params); crf.weightsValueChanged(); } public void setParameter(int index, double value) { crf.getParameters().setParameter(index, value); crf.weightsValueChanged(); } public void cacheValueAndGradient() { // compute and cache lattices //System.gc(); //System.err.println("Used Memory "+String.format("%.3f", (Runtime.getRuntime().totalMemory()-Runtime.getRuntime().freeMemory())/1000000.) + " before lattice"); ArrayList<SumLattice> lattices = new ArrayList<SumLattice>(); if (numThreads == 1) { for (int ii = 0; ii < data.size(); ii++) { if (instancesWithConstraints.get(ii)) { SumLatticeDefault lattice = new SumLatticeDefault( this.crf, (FeatureVectorSequence)data.get(ii).getData(), null, null, true); lattices.add(lattice); } else { lattices.add(null); } } } else { // mutli-threaded version ArrayList<Callable<Void>> tasks = new ArrayList<Callable<Void>>(); if (data.size() < numThreads) { numThreads = data.size(); } int increment = data.size() / numThreads; int start = 0; int end = increment; for (int thread = 0; thread < numThreads; thread++) { tasks.add(new SumLatticeTask(crf,data,instancesWithConstraints,start,end)); start += increment; if (thread == numThreads - 2) { end = data.size(); } else { end += increment; } } try { // run all threads and wait for them to finish executor.invokeAll(tasks); } catch (InterruptedException ie) { ie.printStackTrace(); } for (Callable<Void> task : tasks) { lattices.addAll(((SumLatticeTask)task).getLattices()); } assert(lattices.size() == data.size()) : lattices.size() + " " + data.size(); } System.err.println("Done computing lattices."); for (GEConstraint constraint : constraints) { constraint.zeroExpectations(); constraint.computeExpectations(lattices); } System.err.println("Done computing expectations."); //System.gc(); //System.err.println("Used Memory "+String.format("%.3f", (Runtime.getRuntime().totalMemory()-Runtime.getRuntime().freeMemory())/1000000.) + " after lattice"); // compute GE value this.cachedValue = 0; for (GEConstraint constraint : constraints) { this.cachedValue += constraint.getValue(); } cachedGradient.zero(); // compute GE gradient if (numThreads == 1) { for (int ii = 0; ii < data.size(); ii++) { if (instancesWithConstraints.get(ii)) { SumLattice lattice = lattices.get(ii); FeatureVectorSequence fvs = (FeatureVectorSequence)data.get(ii).getData(); new GELattice(fvs, lattice.getGammas(), lattice.getXis(), crf, reverseTrans, reverseTransIndices, cachedGradient,this.constraints, false); } } } else { // multi-threaded version ArrayList<Callable<Void>> tasks = new ArrayList<Callable<Void>>(); if (data.size() < numThreads) { numThreads = data.size(); } int increment = data.size() / numThreads; int start = 0; int end = increment; for (int thread = 0; thread < numThreads; thread++) { ArrayList<GEConstraint> constraintsCopy = new ArrayList<GEConstraint>(); for (GEConstraint constraint : constraints) { constraintsCopy.add(constraint.copy()); } tasks.add(new GELatticeTask(crf,data,lattices,constraintsCopy,instancesWithConstraints, reverseTrans,reverseTransIndices,start,end)); start += increment; if (thread == numThreads - 2) { end = data.size(); } else { end += increment; } } try { // run all threads and wait for them to finish executor.invokeAll(tasks); } catch (InterruptedException ie) { ie.printStackTrace(); } for (Callable<Void> task : tasks) { cachedGradient.plusEquals(((GELatticeTask)task).getGradient(), 1); } } System.err.println("Done computing gradient."); this.cachedValue += crf.getParameters().gaussianPrior(gpv); cachedGradient.plusEqualsGaussianPriorGradient(crf.getParameters(), gpv); System.err.println("Done computing regularization."); if (weight != 1) { this.cachedValue *= weight; } System.err.println("GE Value = " + this.cachedValue); } public void setGaussianPriorVariance(double variance) { this.gpv = variance; } public void getValueGradient(double[] buffer) { if (crf.getWeightsValueChangeStamp() != cache) { cacheValueAndGradient(); cache = crf.getWeightsValueChangeStamp(); } // TODO this will also multiply the prior, if active! cachedGradient.getParameters(buffer); if (weight != 1) { MatrixOps.timesEquals(buffer, weight); } } public double getValue() { if (crf.getWeightsValueChangeStamp() != cache) { cacheValueAndGradient(); cache = crf.getWeightsValueChangeStamp(); } return this.cachedValue; } /** * Should be called after training is complete * to shutdown all threads. */ public void shutdown() { if (executor == null) return; executor.shutdown(); try { executor.awaitTermination(30, TimeUnit.SECONDS); } catch (InterruptedException e) { e.printStackTrace(); } assert(executor.shutdownNow().size() == 0) : "All tasks didn't finish"; } } class SumLatticeTask implements Callable<Void> { private int start; private int end; private ArrayList<SumLatticeDefault> lattices; private InstanceList data; private CRF crf; private BitSet instancesWithConstraints; public SumLatticeTask(CRF crf, InstanceList data, BitSet instancesWithConstraints, int start, int end) { this.crf = crf; this.data = data; this.start = start; this.end = end; this.lattices = new ArrayList<SumLatticeDefault>(); this.instancesWithConstraints = instancesWithConstraints; } public ArrayList<SumLatticeDefault> getLattices() { return this.lattices; } public Void call() throws Exception { for (int ii = start; ii < end; ii++) { if (instancesWithConstraints.get(ii)) { Instance instance = data.get(ii); SumLatticeDefault lattice = new SumLatticeDefault( this.crf, (FeatureVectorSequence)instance.getData(), null, null, true); lattices.add(lattice); } else { lattices.add(null); } } return null; } } class GELatticeTask implements Callable<Void> { private int start; private int end; private ArrayList<GEConstraint> constraints; private ArrayList<SumLattice> lattices; private InstanceList data; private CRF crf; private CRF.Factors gradient; private BitSet instancesWithConstraints; private int[][] reverseTrans; private int[][] reverseTransIndices; /** * @param crf CRF * @param data Unlabeled data * @param lattices Cached SumLattices * @param constraints List of GEConstraints * @param instancesWithConstraints BitSet which indices whether any constraints fire for an instance * @param reverseTrans Source state indices for each destination state * @param reverseTransIndices Transition indices for each destination state * @param start Position in unlabeled data where this thread starts computing * @param end Position in unlabeled data where this thread stops computing */ public GELatticeTask(CRF crf, InstanceList data, ArrayList<SumLattice> lattices, ArrayList<GEConstraint> constraints, BitSet instancesWithConstraints, int[][] reverseTrans, int[][] reverseTransIndices, int start, int end) { this.crf = crf; this.data = data; this.lattices = lattices; this.constraints = constraints; this.start = start; this.end = end; this.gradient = new CRF.Factors(crf); this.instancesWithConstraints = instancesWithConstraints; this.reverseTrans = reverseTrans; this.reverseTransIndices = reverseTransIndices; } public CRF.Factors getGradient() { return this.gradient; } public Void call() throws Exception { for (int ii = start; ii < end; ii++) { if (instancesWithConstraints.get(ii)) { SumLattice lattice = lattices.get(ii); FeatureVectorSequence fvs = (FeatureVectorSequence)data.get(ii).getData(); new GELattice(fvs, lattice.getGammas(), lattice.getXis(), crf, reverseTrans, reverseTransIndices, gradient,this.constraints, false); } } return null; } }