/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/
package cc.mallet.types;
import java.util.logging.*;
import java.util.BitSet;
import java.io.*;
import cc.mallet.util.MalletLogger;
/* Where will the new features get extracted in the Pipe? */
/**
 * Induces new features (mostly conjunctions of existing features) that have high
 * rank according to a caller-supplied {@link RankedFeatureVector.Factory} (e.g. a
 * gradient-gain ranker), and can later stamp those induced features onto other
 * {@link InstanceList}s via {@link #induceFeaturesFor}.
 */
public class FeatureInducer implements Serializable
{
	private static Logger logger = MalletLogger.getLogger(FeatureInducer.class.getName());

	// When true, singleton features currently masked out by the FeatureSelection
	// would also be candidates for "induction" (see xxx comments below).
	static boolean addMaskedFeatures = false;
	// Below this many training instances, induction is skipped entirely:
	// there is not enough support in the data to rank candidate features.
	static int minTrainingListSize = 20;

	// Only one of the following two will be non-null
	// NOTE(review): neither field is ever assigned in this class as written
	// (the constructor uses its `ranker` parameter locally only), and neither
	// is serialized — confirm whether these fields are still needed.
	RankedFeatureVector.Factory ranker;
	RankedFeatureVector.PerLabelFactory perLabelRanker;

	// Beam widths for candidate conjunction construction; see constructor.
	int beam1 = 300;
	int beam2 = 1000;

	// The conjunctions (and unmasked singletons) selected during construction;
	// applied to other instance lists by induceFeaturesFor().
	FeatureConjunction.List fcl;

	// xxx Could perhaps build a hash value for each feature that measures its distribution
	// over instances, and avoid conjunctions of features that are *exact* duplicates
	// with this hash value.

	/**
	 * Ranks the existing features of {@code ilist}, forms candidate conjunctions
	 * among the top-ranked ones (pairs drawn from the top {@code min(beam1,beam2)}
	 * crossed with the top {@code max(beam1,beam2)}), re-ranks the candidates, and
	 * records up to {@code numNewFeatures} of the best ones in {@link #fcl}.
	 * Induced features are added to {@code ilist}'s data alphabet as a side effect
	 * of constructing each {@link FeatureConjunction}.
	 *
	 * @param ranker         produces a RankedFeatureVector (e.g. by gradient gain) for an InstanceList
	 * @param ilist          training instances; must have at least {@code minTrainingListSize}
	 *                       elements or the constructor returns without inducing anything
	 * @param numNewFeatures maximum number of features to induce
	 * @param beam1          one of the two beam widths (order relative to beam2 does not matter)
	 * @param beam2          the other beam width
	 */
	public FeatureInducer (RankedFeatureVector.Factory ranker,
												 InstanceList ilist,
												 int numNewFeatures, int beam1, int beam2)
	{
		this.fcl = new FeatureConjunction.List ();
		this.beam1 = beam1;
		this.beam2 = beam2;
		if (ilist.size() < minTrainingListSize) {
			// Fixed message: the threshold is on the number of training instances, not features.
			logger.info ("FeatureInducer not inducing from fewer than "+minTrainingListSize+" instances.");
			return;
		}
		// Work on a clone of the data alphabet so candidate conjunctions can be
		// added and ranked without polluting the original vocabulary.
		Alphabet tmpDV = (Alphabet) ilist.getDataAlphabet().clone();
		FeatureSelection featuresSelected = ilist.getFeatureSelection();
		InstanceList tmpilist = new InstanceList (tmpDV, ilist.getTargetAlphabet());
		RankedFeatureVector gg = ranker.newRankedFeatureVector (ilist);
		logger.info ("Rank values before this round of conjunction-building");
		int n = Math.min (200, gg.numLocations());
		for (int i = 0; i < n; i++)
			logger.info ("Rank="+i+' '+Double.toString(gg.getValueAtRank(i)) + ' ' + gg.getObjectAtRank(i).toString());
		//for (int i = gg.numLocations()-200; i < gg.numLocations(); i++)
		//System.out.println ("i="+i+' '+Double.toString(gg.getValueAtRank(i)) + ' ' + gg.getObjectAtRank(i).toString());
		//System.out.println ("");

		// fsMin holds the top min(beam1,beam2) features, fsMax the top max(beam1,beam2);
		// conjunction candidates are (roughly) fsMin x fsMax pairs, built below by
		// the FeatureVector(fv, tmpDV, fsMin, fsMax) constructor.
		FeatureSelection fsMin = new FeatureSelection (tmpDV);
		FeatureSelection fsMax = new FeatureSelection (tmpDV);
		int minBeam = Math.min (beam1, beam2);
		int maxBeam = Math.max (beam1, beam2);
		logger.info ("Using minBeam="+minBeam+" maxBeam="+maxBeam);
		int max = maxBeam < gg.numLocations() ? maxBeam : gg.numLocations();
		for (int b = 0; b < max; b++) {
			// Features with zero rank value (and everything below them) carry no signal.
			if (gg.getValueAtRank(b) == 0)
				break;
			int index = gg.getIndexAtRank(b);
			fsMax.add (index);
			if (b < minBeam)
				fsMin.add (index);
		}

		// Prevent it from searching through all of gg2
		//double minGain = gg.getValueAtRank(maxBeam*2);
		// No, there are so many "duplicate" features, that it ends up only adding a few each round.
		//double minGain = Double.NEGATIVE_INFINITY;
		// Just use a constant; anything less than this must not have enough support in the data.
		//double minGain = 5;
		double minGain = 0;

		//// xxx Temporarily remove all feature conjunction pruning
		//System.out.println ("FeatureInducer: Temporarily not pruning any feature conjunctions from consideration.");
		//fsMin = fsMax = null; minGain = Double.NEGATIVE_INFINITY;

		//int[] conjunctions = new int[beam];
		//for (int b = 0; b < beam; b++)
		//conjunctions[b] = gg.getIndexAtRank(b);
		gg = null; // Allow memory to be freed

		// Build a temporary instance list whose feature vectors include the
		// candidate conjunctions, so the ranker can score them directly.
		for (int i = 0; i < ilist.size(); i++) {
			Instance inst = ilist.get(i);
			FeatureVector fv = (FeatureVector) inst.getData ();
			tmpilist.add (new Instance (new FeatureVector (fv, tmpDV, fsMin, fsMax),
																	inst.getTarget(), inst.getName(), inst.getSource()),
										ilist.getInstanceWeight(i));
		}

		logger.info ("Calculating gradient gain of conjunctions, vocab size = "+tmpDV.size());
		RankedFeatureVector gg2 = ranker.newRankedFeatureVector (tmpilist);
		for (int i = 0; i < 200 && i < gg2.numLocations(); i++)
			logger.info ("Conjunction Rank="+i+' '+Double.toString(gg2.getValueAtRank(i))
									 + ' ' + gg2.getObjectAtRank(i).toString());
		int numFeaturesAdded = 0;
		Alphabet origV = ilist.getDataAlphabet();
		// Indices at or above origVSize in tmpDV are candidate conjunctions; below
		// it they are singleton features already present in the original alphabet.
		int origVSize = origV.size();
	nextfeatures:
		for (int i = 0; i < gg2.numLocations(); i++) {
			double gain = gg2.getValueAtRank (i);
			if (gain < minGain) {
				// There are no more new features we could add, because they all have no more gain
				// than the features we started with
				logger.info ("Stopping feature induction: gain["+i+"]="+gain+", minGain="+minGain);
				break;
			}
			if (gg2.getIndexAtRank(i) >= origVSize) {
				// First disjunct above so that we also add singleton features that are currently masked out
				// xxx If addMaskedFeatures == true, we should still check the mask, so we don't
				// "add" and print features that are already unmasked
				String s = (String) gg2.getObjectAtRank(i);
				int[] featureIndices = FeatureConjunction.getFeatureIndices(origV, s);
				// Make sure that the new conjunction doesn't contain duplicate features
				if (FeatureConjunction.isValidConjunction (featureIndices)
						// Don't add features with exactly the same gain value: they are probably an
						// "exactly overlapping duplicate"
						// xxx Note that this might actually increase over-fitting!
						&& (i == 0 || gg2.getValueAtRank(i-1) != gg2.getValueAtRank(i))
						) {
					double newFeatureValue = gg2.getValueAtRank(i);
					// Don't add new conjunctions that have no more gain than any of their constituents
					for (int j = 0; j < featureIndices.length; j++)
						if (gg2.value (featureIndices[j]) >= newFeatureValue) {
							//System.out.println ("Skipping feature that adds no gain "+newFeatureValue+' '+s);
							continue nextfeatures;
						}
					// Constructing the FeatureConjunction adds its symbol to origV.
					fcl.add (new FeatureConjunction (origV, featureIndices));
					// (Removed dead local `int index = origV.size()-1;` — it was never read.)
					// If we have a feature mask, be sure to include this new feature
					logger.info ("Added feature c "+numFeaturesAdded+" "+newFeatureValue+ ' ' + s);
					// xxx Also print the gradient here, if the feature already exists.
					numFeaturesAdded++;
				}
			} else if (featuresSelected != null) {
				int index = gg2.getIndexAtRank (i);
				//System.out.println ("Atomic feature rank "+i+" at index "+index);
				if (!featuresSelected.contains (index)
						// A new atomic feature added to the FeatureSelection
						// Don't add features with exactly the same gain value: they are probably an
						// "exactly overlapping duplicate"
						// xxx Note that this might actually increase over-fitting!
						&& (i == 0 || gg2.getValueAtRank(i-1) != gg2.getValueAtRank(i))) {
					fcl.add (new FeatureConjunction (origV, new int[] {index}));
					logger.info ("Added feature a "+numFeaturesAdded+" "+gg2.getValueAtRank(i)+ ' ' + gg2.getObjectAtRank(i));
					numFeaturesAdded++;
				}
			}
			if (numFeaturesAdded >= numNewFeatures) {
				logger.info ("Stopping feature induction: numFeaturesAdded="+numFeaturesAdded);
				break;
			}
		}
		logger.info ("Finished adding features");
	}

	/**
	 * Convenience constructor: uses {@code numNewFeatures} for both beam widths.
	 */
	public FeatureInducer (RankedFeatureVector.Factory ranker,
												 InstanceList ilist,
												 int numNewFeatures)
	{
		//this (ilist, classifications, numNewFeatures, 200, numNewFeatures);
		//this (ilist, classifications, numNewFeatures, 200, 500);
		this (ranker, ilist, numNewFeatures, numNewFeatures, numNewFeatures);
	}

	// This must be run on test instance lists before they can be transduced, because we have to add the right
	// feature combinations!
	/**
	 * Adds this inducer's recorded feature conjunctions to every instance in
	 * {@code ilist}, in place. Instance data must be either an
	 * {@link AugmentableFeatureVector} or a {@link FeatureVectorSequence} of them.
	 *
	 * @param withFeatureShrinkage  must be false (not implemented)
	 * @param addPerClassFeatures   must be false (not implemented)
	 * @throws IllegalArgumentException if an instance's data is of an unsupported type
	 */
	public void induceFeaturesFor (InstanceList ilist,
																 boolean withFeatureShrinkage, boolean addPerClassFeatures)
	{
		assert (addPerClassFeatures == false);
		assert (withFeatureShrinkage == false);
		FeatureSelection fs = ilist.getFeatureSelection ();
		assert (ilist.getPerLabelFeatureSelection() == null);
		if (fcl.size() == 0)
			return;
		for (int i = 0; i < ilist.size(); i++) {
			//System.out.println ("Induced features for instance #"+i);
			Instance inst = ilist.get(i);
			Object data = inst.getData ();
			if (data instanceof AugmentableFeatureVector) {
				AugmentableFeatureVector afv = (AugmentableFeatureVector) data;
				fcl.addTo (afv, 1.0, fs);
			} else if (data instanceof FeatureVectorSequence) {
				FeatureVectorSequence fvs = (FeatureVectorSequence) data;
				for (int j = 0; j < fvs.size(); j++)
					fcl.addTo ((AugmentableFeatureVector) fvs.get(j), 1.0, fs);
			} else {
				throw new IllegalArgumentException ("Unsupported instance data type "+data.getClass().getName());
			}
		}
	}

	// Serialization: only the beam widths and the conjunction list are persisted.
	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 0;

	private void writeObject (ObjectOutputStream out) throws IOException {
		out.writeInt (CURRENT_SERIAL_VERSION);
		out.writeInt(beam1);
		out.writeInt(beam2);
		out.writeObject(fcl);
	}

	private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
		// Version is read for forward compatibility but currently unused.
		int version = in.readInt ();
		beam1 = in.readInt();
		beam2 = in.readInt();
		fcl = (FeatureConjunction.List)in.readObject();
	}
}