package cc.mallet.fst; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import cc.mallet.optimize.Optimizer; import cc.mallet.optimize.OrthantWiseLimitedMemoryBFGS; import cc.mallet.types.InstanceList; /** * CRF trainer that implements L1-regularization. * * @author Kedar Bellare */ public class CRFTrainerByL1LabelLikelihood extends CRFTrainerByLabelLikelihood { static final double SPARSE_PRIOR = 0.0; double l1Weight = SPARSE_PRIOR; public CRFTrainerByL1LabelLikelihood(CRF crf) { this(crf, SPARSE_PRIOR); } /** * Constructor for CRF trainer. * * @param crf * CRF to train. * @param l1Weight * Weight of L1 term in objective (l1Weight*|w|). Higher L1 * weight means sparser solutions. */ public CRFTrainerByL1LabelLikelihood(CRF crf, double l1Weight) { super(crf); this.l1Weight = l1Weight; } public void setL1RegularizationWeight(double l1Weight) { this.l1Weight = l1Weight; } public Optimizer getOptimizer(InstanceList trainingSet) { getOptimizableCRF(trainingSet); if (opt == null || ocrf != opt.getOptimizable()) opt = new OrthantWiseLimitedMemoryBFGS(ocrf, l1Weight); return opt; } // Serialization private static final long serialVersionUID = 1L; private static final int CURRENT_SERIAL_VERSION = 0; private void writeObject(ObjectOutputStream out) throws IOException { out.writeInt(CURRENT_SERIAL_VERSION); out.writeDouble(l1Weight); } private void readObject(ObjectInputStream in) throws IOException { in.readInt(); // version l1Weight = in.readDouble(); } }