package edu.stanford.nlp.classify;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.*;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
/**
* The purpose of this abstract class is to unify {@link Dataset} and {@link RVFDataset}.
* <p>
* Note: Despite these being value classes, at present there are no equals() and hashCode() methods
* defined so you just get the default ones from Object, so different objects aren't equal.
* </p>
*
* @author Kristina Toutanova (kristina@cs.stanford.edu)
* @author Anna Rafferty (various refactoring with subclasses)
* @author Sarah Spikes (sdspikes@cs.stanford.edu) (Templatization)
* @author Ramesh Nallapati (nmramesh@cs.stanford.edu)
* (added an abstract method getDatum, July 17th, 2008)
*
* @param <L> The type of the labels in the Dataset
* @param <F> The type of the features in the Dataset
*/
public abstract class GeneralDataset<L, F> implements Serializable, Iterable<RVFDatum<L, F>> {
// Version identifier used by Java serialization for this class.
private static final long serialVersionUID = 19157757130054829L;
// Index of label objects; the ints stored in labels are ids into this index.
public Index<L> labelIndex;
// Index of feature objects; the ints stored in data are ids into this index.
public Index<F> featureIndex;
// labels[i] is the label id of datum i. Only the first size entries are in use.
protected int[] labels;
// data[i] lists the feature ids of datum i. Only the first size rows are in use.
protected int[][] data;
// Number of datums currently stored in the dataset.
protected int size;
/**
 * No-op constructor: every field is left at its default value.
 */
public GeneralDataset() {
}
/**
 * Returns the index of class labels used by this dataset.
 *
 * @return the label index field itself (not a copy)
 */
public Index<L> labelIndex() {
  return this.labelIndex;
}
/**
 * Returns the index of features used by this dataset.
 *
 * @return the feature index field itself (not a copy)
 */
public Index<F> featureIndex() {
  return this.featureIndex;
}
/**
 * Returns the number of entries in the feature index.
 *
 * @return the size of the feature index
 */
public int numFeatures() {
  final int featureCount = featureIndex.size();
  return featureCount;
}
/**
 * Returns the number of entries in the label index.
 *
 * @return the size of the label index
 */
public int numClasses() {
  final int classCount = labelIndex.size();
  return classCount;
}
/**
 * Shrinks the internal labels array so that it holds exactly {@code size}
 * entries, stores the shrunken array back into the dataset and returns it.
 *
 * @return the dataset's own labels array (not a defensive copy)
 */
public int[] getLabelsArray() {
  final int[] trimmed = trimToSize(labels);
  labels = trimmed;
  return trimmed;
}
/**
 * Shrinks the internal data array so that it holds exactly {@code size}
 * rows, stores the shrunken array back into the dataset and returns it.
 *
 * @return the dataset's own data array (not a defensive copy)
 */
public int[][] getDataArray() {
  final int[][] trimmed = trimToSize(data);
  data = trimmed;
  return trimmed;
}
/**
 * Returns the values array of this dataset. The exact contents are defined
 * by the implementing subclass.
 * NOTE(review): presumably one row of feature values per datum, parallel to
 * {@link #getDataArray()} -- confirm against the subclasses.
 *
 * @return a two-dimensional array of values
 */
public abstract double[][] getValuesArray();
/**
 * Empties the Dataset so that it can start collecting data again, using a
 * default initial capacity of 10 datums.
 */
public void clear() {
  final int defaultCapacity = 10;
  clear(defaultCapacity);
}
/**
 * Empties the Dataset so that it can start collecting data again.
 * Delegates the actual reset to {@link #initialize(int)}.
 *
 * @param numDatums initial capacity of the emptied dataset
 */
public void clear(int numDatums) {
  this.initialize(numDatums);
}
/**
 * Resets the values of the dataset so that it is empty with an initial
 * capacity of {@code numDatums}.
 * Should be called only by appropriate methods within the class, such as
 * {@link #clear(int)}, which take care of the other parts of emptying the data.
 *
 * @param numDatums initial capacity of dataset
 */
protected abstract void initialize(int numDatums);
/**
 * Returns the datum at the given position as an {@link RVFDatum}.
 * This is the accessor used by {@link #iterator()} and
 * {@link #printSVMLightFormat(PrintWriter)}.
 *
 * @param index position of the datum in the dataset
 * @return the datum at {@code index}
 */
public abstract RVFDatum<L, F> getRVFDatum(int index);
/**
 * Returns the datum at the given position.
 *
 * @param index position of the datum in the dataset
 * @return the datum at {@code index}
 */
public abstract Datum<L,F> getDatum(int index);
/**
 * Adds a single datum to this dataset.
 *
 * @param d the datum to add
 */
public abstract void add(Datum<L, F> d);
/**
 * Counts, over all datums currently in the dataset, how many times each
 * feature id appears in the data array.
 *
 * @return an array with one slot per entry of the feature index, holding the
 *         number of occurrences of that feature
 */
public float[] getFeatureCounts() {
  final float[] totals = new float[featureIndex.size()];
  final int numRows = size;
  for (int row = 0; row < numRows; row++) {
    for (int featureId : data[row]) {
      totals[featureId] += 1.0;
    }
  }
  return totals;
}
/**
 * Applies a minimum feature count threshold to the Dataset: every feature
 * occurring fewer than <i>k</i> times is removed, the feature index is
 * rebuilt from the survivors, and all datums are rewritten with the new ids.
 *
 * @param k minimum number of occurrences a feature needs in order to be kept
 */
public void applyFeatureCountThreshold(int k) {
  final float[] occurrences = getFeatureCounts();
  final Index<F> keptFeatures = new HashIndex<>();
  // oldToNew[oldId] is the id in the rebuilt index, or -1 for a dropped feature
  final int[] oldToNew = new int[featureIndex.size()];
  for (int oldId = 0; oldId < oldToNew.length; oldId++) {
    F feature = featureIndex.get(oldId);
    if (occurrences[oldId] < k) {
      oldToNew[oldId] = -1;
      continue;
    }
    oldToNew[oldId] = keptFeatures.size();
    keptFeatures.add(feature);
  }
  featureIndex = keptFeatures;

  // Rewrite each datum, keeping surviving features in their original order.
  for (int row = 0; row < size; row++) {
    final int[] oldRow = data[row];
    final int[] buffer = new int[oldRow.length];
    int kept = 0;
    for (int oldId : oldRow) {
      final int newId = oldToNew[oldId];
      if (newId >= 0) {
        buffer[kept++] = newId;
      }
    }
    data[row] = Arrays.copyOf(buffer, kept);
  }
}
/**
 * Keeps only the given features in the Dataset: every feature that is not a
 * member of {@code features} is removed, the feature index is rebuilt from
 * the survivors, and all datums are rewritten with the new ids.
 *
 * @param features the features to keep
 */
public void retainFeatures(Set<F> features) {
  final Index<F> keptFeatures = new HashIndex<>();
  // oldToNew[oldId] is the id in the rebuilt index, or -1 for a dropped feature
  final int[] oldToNew = new int[featureIndex.size()];
  for (int oldId = 0; oldId < oldToNew.length; oldId++) {
    F feature = featureIndex.get(oldId);
    if (!features.contains(feature)) {
      oldToNew[oldId] = -1;
      continue;
    }
    oldToNew[oldId] = keptFeatures.size();
    keptFeatures.add(feature);
  }
  featureIndex = keptFeatures;

  // Rewrite each datum, keeping surviving features in their original order.
  for (int row = 0; row < size; row++) {
    final int[] oldRow = data[row];
    final int[] buffer = new int[oldRow.length];
    int kept = 0;
    for (int oldId : oldRow) {
      final int newId = oldToNew[oldId];
      if (newId >= 0) {
        buffer[kept++] = newId;
      }
    }
    data[row] = Arrays.copyOf(buffer, kept);
  }
}
/**
 * Applies a maximum feature count threshold to the Dataset: every feature
 * occurring more than <i>k</i> times is removed, the feature index is
 * rebuilt from the survivors, and all datums are rewritten with the new ids.
 *
 * @param k maximum number of occurrences a feature may have and still be kept
 */
public void applyFeatureMaxCountThreshold(int k) {
  final float[] occurrences = getFeatureCounts();
  final HashIndex<F> keptFeatures = new HashIndex<>();
  // oldToNew[oldId] is the id in the rebuilt index, or -1 for a dropped feature
  final int[] oldToNew = new int[featureIndex.size()];
  for (int oldId = 0; oldId < oldToNew.length; oldId++) {
    F feature = featureIndex.get(oldId);
    if (occurrences[oldId] > k) {
      oldToNew[oldId] = -1;
      continue;
    }
    oldToNew[oldId] = keptFeatures.size();
    keptFeatures.add(feature);
  }
  featureIndex = keptFeatures;

  // Rewrite each datum, keeping surviving features in their original order.
  for (int row = 0; row < size; row++) {
    final int[] oldRow = data[row];
    final int[] buffer = new int[oldRow.length];
    int kept = 0;
    for (int oldId : oldRow) {
      final int newId = oldToNew[oldId];
      if (newId >= 0) {
        buffer[kept++] = newId;
      }
    }
    data[row] = Arrays.copyOf(buffer, kept);
  }
}
/**
 * Returns the number of feature tokens in the Dataset, i.e. the summed
 * length of the feature arrays of all datums.
 *
 * @return total number of feature occurrences
 */
public int numFeatureTokens() {
  final int numRows = size;
  int total = 0;
  int row = 0;
  while (row < numRows) {
    total += data[row].length;
    row++;
  }
  return total;
}
/**
 * Returns the number of distinct feature types in the Dataset, which is the
 * size of the feature index.
 *
 * @return number of entries in the feature index
 */
public int numFeatureTypes() {
  final int distinctFeatures = featureIndex.size();
  return distinctFeatures;
}
/**
 * Adds every Datum of the given collection to this dataset, in iteration
 * order, by calling {@link #add(Datum)} on each of them.
 *
 * @param data collection of datums you would like to add to the dataset
 */
public void addAll(Iterable<? extends Datum<L,F>> data) {
  Iterator<? extends Datum<L, F>> pending = data.iterator();
  while (pending.hasNext()) {
    add(pending.next());
  }
}
/**
 * Divides out a (devtest) split of the dataset versus the rest of it (as a training set).
 *
 * @param start Begin devtest with this index (inclusive)
 * @param end End devtest before this index (exclusive)
 * @return A Pair of data sets, the first being the remainder of size this.size() - (end-start)
 *         and the second being of size (end-start)
 */
public abstract Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> split (int start, int end);
/**
 * Divides out a (devtest) split from the start of the dataset and the rest of it (as a training set).
 *
 * @param fractionSplit The first fractionSplit of datums (rounded down) will be the second split
 * @return A Pair of data sets, the first being the remainder of size
 *         ceiling(this.size() * (1-fractionSplit)) drawn from the end of the dataset and the
 *         second of size floor(this.size() * fractionSplit) drawn from the start of the dataset.
 */
public abstract Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> split (double fractionSplit);
/**
 * Divides out one cross-validation fold (as devtest data) versus the rest of
 * the dataset (as training data). Every fold has size()/numFolds items,
 * except the last one, which also receives the remainder.
 *
 * @param fold The number of this fold (must be between 0 and numFolds - 1, inclusive)
 * @param numFolds The number of folds to divide the data into (must be at least 2 and
 *                 less than or equal to the size of the data set)
 * @return A Pair of data sets, the first being roughly (numFolds-1)/numFolds of the data items
 *         (for use as training data), and the second being roughly 1/numFolds of the data,
 *         taken from the fold<sup>th</sup> part of the data (for use as devTest data)
 * @throws IllegalArgumentException if fold or numFolds is outside the ranges given above
 */
public Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> splitOutFold(int fold, int numFolds) {
  final boolean foldCountOk = numFolds >= 2 && numFolds <= size();
  final boolean foldOk = fold >= 0 && fold < numFolds;
  if (!foldCountOk || !foldOk) {
    throw new IllegalArgumentException("Illegal request for fold " + fold + " of " + numFolds +
        " on data set of size " + size());
  }
  final int foldSize = size() / numFolds;
  final int begin = foldSize * fold;
  // The last fold absorbs the items left over by the integer division.
  final boolean isLastFold = (fold == numFolds - 1);
  final int finish = isLastFold ? size() : begin + foldSize;
  return split(begin, finish);
}
/**
 * Returns the number of examples ({@link Datum}s) currently held in the Dataset.
 *
 * @return the current value of the size field
 */
public int size() {
  return this.size;
}
/**
 * Replaces the data array with a copy holding exactly {@code size} rows.
 */
protected void trimData() {
  this.data = trimToSize(this.data);
}
/**
 * Replaces the labels array with a copy holding exactly {@code size} entries.
 */
protected void trimLabels() {
  this.labels = trimToSize(this.labels);
}
/**
 * Copies the first {@code size} entries of the given array into a new array
 * of exactly that length.
 * NOTE(review): the copy runs while holding the monitor of System.class; the
 * reason is not visible in this file -- confirm before removing it.
 *
 * @param i the array to trim; must hold at least {@code size} entries
 * @return a new array of length {@code size}
 */
protected int[] trimToSize(int[] i) {
  final int[] trimmed = new int[size];
  synchronized (System.class) {
    System.arraycopy(i, 0, trimmed, 0, size);
  }
  return trimmed;
}
/**
 * Copies the first {@code size} rows of the given array into a new outer
 * array of exactly that length. The rows themselves are shared, not copied.
 * NOTE(review): the copy runs while holding the monitor of System.class; the
 * reason is not visible in this file -- confirm before removing it.
 *
 * @param i the array to trim; must hold at least {@code size} rows
 * @return a new outer array of length {@code size}
 */
protected int[][] trimToSize(int[][] i) {
  final int[][] trimmed = new int[size][];
  synchronized (System.class) {
    System.arraycopy(i, 0, trimmed, 0, size);
  }
  return trimmed;
}
/**
 * Copies the first {@code size} rows of the given array into a new outer
 * array of exactly that length. The rows themselves are shared, not copied.
 * NOTE(review): the copy runs while holding the monitor of System.class; the
 * reason is not visible in this file -- confirm before removing it.
 *
 * @param i the array to trim; must hold at least {@code size} rows
 * @return a new outer array of length {@code size}
 */
protected double[][] trimToSize(double[][] i) {
  final double[][] trimmed = new double[size][];
  synchronized (System.class) {
    System.arraycopy(i, 0, trimmed, 0, size);
  }
  return trimmed;
}
/**
 * Randomizes the data array (and the labels array along with it) in place.
 * Note: this cannot change the values array or the datum weights,
 * so redefine this for RVFDataset and WeightedDataset!
 * This uses the Fisher-Yates (or Durstenfeld-Knuth) shuffle, which is unbiased:
 * position j is swapped with a position drawn uniformly from 0..j, inclusive.
 * The same algorithm is used by shuffle() in j.u.Collections, and so you should get compatible
 * results if using it on a Collection with the same seed (as of JDK1.7, at least).
 * NOTE(review): subclasses that redefine this method should use the same
 * inclusive bound -- verify against them.
 *
 * @param randomSeed A seed for the Random object (allows you to reproduce the same ordering)
 */
// todo: Probably should be renamed 'shuffle' to be consistent with Java Collections API
public void randomize(long randomSeed) {
  Random rand = new Random(randomSeed);
  for (int j = size - 1; j > 0; j--) {
    // Swap item j with an item at the same or a lower position. The bound must
    // be j + 1 (nextInt's bound is exclusive): with a bound of j an item could
    // never stay in place, so only cyclic permutations would be produced.
    int randIndex = rand.nextInt(j + 1);
    int[] tmp = data[randIndex];
    data[randIndex] = data[j];
    data[j] = tmp;
    int tmpl = labels[randIndex];
    labels[randIndex] = labels[j];
    labels[j] = tmpl;
  }
}
/**
 * Randomizes the data array (and the labels array along with it) in place,
 * applying exactly the same permutation to a list of side information that
 * holds one element per datum.
 * Note: this cannot change the values array or the datum weights.
 * This uses the Fisher-Yates (or Durstenfeld-Knuth) shuffle, which is unbiased:
 * position j is swapped with a position drawn uniformly from 0..j, inclusive.
 * The same algorithm is used by shuffle() in j.u.Collections, and so you should get compatible
 * results if using it on a Collection with the same seed (as of JDK1.7, at least).
 *
 * @param randomSeed A seed for the Random object (allows you to reproduce the same ordering)
 * @param sideInformation A modifiable list with one entry per datum; it is reordered in place
 * @param <E> The type of the side information elements
 * @throws IllegalArgumentException if sideInformation does not have exactly one entry per datum
 */
public <E> void shuffleWithSideInformation(long randomSeed, List<E> sideInformation) {
  if (size != sideInformation.size()) {
    throw new IllegalArgumentException("shuffleWithSideInformation: sideInformation not of same size as Dataset");
  }
  Random rand = new Random(randomSeed);
  for (int j = size - 1; j > 0; j--) {
    // Swap item j with an item at the same or a lower position. The bound must
    // be j + 1 (nextInt's bound is exclusive): with a bound of j an item could
    // never stay in place, so only cyclic permutations would be produced.
    int randIndex = rand.nextInt(j + 1);
    int[] tmp = data[randIndex];
    data[randIndex] = data[j];
    data[j] = tmp;
    int tmpl = labels[randIndex];
    labels[randIndex] = labels[j];
    labels[j] = tmpl;
    E tmpE = sideInformation.get(randIndex);
    sideInformation.set(randIndex, sideInformation.get(j));
    sideInformation.set(j, tmpE);
  }
}
/**
 * Draws a random sample of this dataset into a new dataset of the same kind
 * (an RVFDataset for an RVFDataset, a Dataset for a Dataset).
 *
 * @param randomSeed A seed for the Random object (allows you to reproduce the same sample)
 * @param sampleFrac The sample holds (int) (size() * sampleFrac) datums
 * @param sampleWithReplacement If true, the same datum may be drawn several times;
 *                              if false, every datum is drawn at most once
 * @return A new dataset holding the sampled datums
 * @throws IllegalArgumentException if sampling without replacement is asked to produce
 *                                  more datums than the dataset contains
 * @throws RuntimeException if this dataset is neither an RVFDataset nor a Dataset
 */
public GeneralDataset<L,F> sampleDataset(long randomSeed, double sampleFrac, boolean sampleWithReplacement) {
  int sampleSize = (int)(this.size()*sampleFrac);
  // Without replacement there are only size() distinct datums to draw from;
  // asking for more used to make the sampling loop below spin forever.
  if (!sampleWithReplacement && sampleSize > this.size()) {
    throw new IllegalArgumentException("Cannot sample " + sampleSize +
        " datums without replacement from a dataset of size " + this.size());
  }
  Random rand = new Random(randomSeed);
  GeneralDataset<L,F> subset;
  if (this instanceof RVFDataset) {
    subset = new RVFDataset<>();
  } else if (this instanceof Dataset) {
    subset = new Dataset<>();
  } else {
    throw new RuntimeException("Can't handle this type of GeneralDataset.");
  }
  if (sampleWithReplacement) {
    for (int i = 0; i < sampleSize; i++) {
      int datumNum = rand.nextInt(this.size());
      subset.add(this.getDatum(datumNum));
    }
  } else {
    // Rejection sampling: redraw whenever an index has been used already.
    // The loop is driven by the set of used indices, so it terminates
    // regardless of how the subset reports its own size.
    Set<Integer> indicesSampled = Generics.newHashSet();
    while (indicesSampled.size() < sampleSize) {
      int datumNum = rand.nextInt(this.size());
      if (indicesSampled.add(datumNum)) {
        subset.add(this.getDatum(datumNum));
      }
    }
  }
  return subset;
}
/**
 * Prints some statistics summarizing the dataset.
 * What is printed, and where to, is decided by the implementing subclass.
 */
public abstract void summaryStatistics();
/**
 * Returns an iterator over the class labels of the Dataset, as supplied by
 * the label index.
 *
 * @return An iterator over the entries of the label index
 */
public Iterator<L> labelIterator() {
  final Iterator<L> classLabels = labelIndex.iterator();
  return classLabels;
}
/**
 * Re-expresses another dataset in terms of the feature and label indices of
 * this dataset. Useful when two Datasets are created independently and one
 * wants to train a model on one dataset and test on the other. -Ramesh.
 * <p>
 * Both indices of this dataset are locked while the datums are copied and
 * are unlocked again afterwards, even if copying fails.
 * NOTE(review): the indices are unlocked unconditionally, so an index that
 * was already locked by the caller ends up unlocked -- confirm this is intended.
 * </p>
 *
 * @param dataset the dataset whose datums are to be mapped
 * @return a new GeneralDataset (an RVFDataset if {@code dataset} is one, otherwise a Dataset)
 *         whose features and ids map exactly to those of this GeneralDataset.
 */
public GeneralDataset<L,F> mapDataset(GeneralDataset<L,F> dataset){
  GeneralDataset<L,F> newDataset;
  if (dataset instanceof RVFDataset) {
    newDataset = new RVFDataset<>(this.featureIndex, this.labelIndex);
  } else {
    newDataset = new Dataset<>(this.featureIndex, this.labelIndex);
  }
  this.featureIndex.lock();
  this.labelIndex.lock();
  try {
    for (int i = 0; i < dataset.size(); i++) {
      newDataset.add(dataset.getDatum(i));
    }
  } finally {
    // Always release the indices; otherwise a failure while copying would
    // leave this dataset's indices locked for good.
    this.featureIndex.unlock();
    this.labelIndex.unlock();
  }
  return newDataset;
}
/**
 * Builds a new datum with the features of {@code d} and a translated label.
 * The label is looked up in {@code labelMapping}; when the lookup yields
 * null, {@code defaultLabel} is used instead.
 *
 * @param d the datum to translate
 * @param labelMapping mapping from old labels to new labels
 * @param defaultLabel label to use when the mapping yields null
 * @return an RVFDatum if {@code d} is an RVFDatum, otherwise a BasicDatum
 */
public static <L,L2,F> Datum<L2,F> mapDatum(Datum<L,F> d, Map<L,L2> labelMapping, L2 defaultLabel) {
  // TODO: How to copy datum?
  final L2 mapped = labelMapping.get(d.label());
  final L2 targetLabel = (mapped != null) ? mapped : defaultLabel;
  if (!(d instanceof RVFDatum)) {
    return new BasicDatum<>(d.asFeatures(), targetLabel);
  }
  return new RVFDatum<>(((RVFDatum<L, F>) d).asFeaturesCounter(), targetLabel);
}
/**
 * Re-expresses another dataset in terms of the feature index of this dataset
 * while converting its labels to another set of labels. Every datum is
 * translated with {@link #mapDatum(Datum, Map, Object)} before it is added.
 * <p>
 * The feature and label indices of this dataset are locked while the datums
 * are copied and are unlocked again afterwards, even if copying fails.
 * {@code newLabelIndex} itself is not locked by this method.
 * NOTE(review): the indices are unlocked unconditionally, so an index that
 * was already locked by the caller ends up unlocked -- confirm this is intended.
 * </p>
 *
 * @param dataset the dataset whose datums are to be mapped
 * @param newLabelIndex the label index of the resulting dataset
 * @param labelMapping mapping from the labels of {@code dataset} to the new labels
 * @param defaultLabel label to use when {@code labelMapping} yields null
 * @return a new GeneralDataset (an RVFDataset if {@code dataset} is one, otherwise a Dataset)
 *         whose features and ids map exactly to those of this GeneralDataset. But labels are
 *         converted to be another set of labels
 */
public <L2> GeneralDataset<L2,F> mapDataset(GeneralDataset<L,F> dataset, Index<L2> newLabelIndex, Map<L,L2> labelMapping, L2 defaultLabel)
{
  GeneralDataset<L2,F> newDataset;
  if (dataset instanceof RVFDataset) {
    newDataset = new RVFDataset<>(this.featureIndex, newLabelIndex);
  } else {
    newDataset = new Dataset<>(this.featureIndex, newLabelIndex);
  }
  this.featureIndex.lock();
  this.labelIndex.lock();
  try {
    for (int i = 0; i < dataset.size(); i++) {
      Datum<L,F> d = dataset.getDatum(i);
      Datum<L2,F> d2 = mapDatum(d, labelMapping, defaultLabel);
      newDataset.add(d2);
    }
  } finally {
    // Always release the indices; otherwise a failure while copying would
    // leave this dataset's indices locked for good.
    this.featureIndex.unlock();
    this.labelIndex.unlock();
  }
  return newDataset;
}
/**
 * Dumps the Dataset to {@link System#out System.out} as a training/test file for SVMLight. <br>
 * class [fno:val]+
 * The features must occur in consecutive order.
 */
public void printSVMLightFormat() {
  PrintWriter pw = new PrintWriter(System.out);
  printSVMLightFormat(pw);
  // A PrintWriter built on an OutputStream buffers its output and does not
  // flush on println, so flush explicitly or the output may never appear.
  // The writer is deliberately not closed: that would close System.out.
  pw.flush();
}
/**
 * Maps our label ids to label strings that are compatible with svm_light.
 * With more than two classes, label id i becomes the string for i + 1;
 * otherwise the fixed pair "+1", "-1" is returned.
 *
 * @return array of svm_light label strings, indexed by label id
 */
public String[] makeSvmLabelMap() {
  final int classCount = numClasses();
  if (classCount <= 2) {
    return new String[]{"+1", "-1"};
  }
  final String[] svmLabels = new String[classCount];
  for (int labelId = 0; labelId < classCount; labelId++) {
    svmLabels[labelId] = Integer.toString(labelId + 1);
  }
  return svmLabels;
}
// todo: Fix javadoc, have unit tested
/**
 * Prints the Dataset in SVM Light format, one line per datum.
 *
 * Each line starts with the datum's label, translated through
 * {@link #makeSvmLabelMap()}, followed by one "featureNumber:value" pair per
 * feature, sorted by increasing feature number. Feature numbers are the ids
 * from the feature index plus one.
 *
 * The label translation is needed because svm_light has special conventions
 * that our labels do not follow by default, e.g. in a multiclass setting it
 * assumes that labels start at 1 whereas our labels start at 0.
 *
 * @param pw the writer to print to; it is neither flushed nor closed here
 */
public void printSVMLightFormat(PrintWriter pw) {
  final String[] svmLabels = makeSvmLabelMap();
  for (int row = 0; row < size; row++) {
    // Assumes each data item has only a few features on: collect the values
    // keyed by feature id, then sort the ids.
    final Counter<F> featureValues = getRVFDatum(row).asFeaturesCounter();
    final ClassicCounter<Integer> valuesById = new ClassicCounter<>();
    for (F feature : featureValues.keySet()) {
      valuesById.setCount(featureIndex.indexOf(feature), featureValues.getCount(feature));
    }
    final Set<Integer> idSet = valuesById.keySet();
    final Integer[] sortedIds = idSet.toArray(new Integer[idSet.size()]);
    Arrays.sort(sortedIds);

    final StringBuilder line = new StringBuilder();
    line.append(svmLabels[labels[row]]);
    line.append(' ');
    for (Integer featureId : sortedIds) {
      line.append(featureId + 1);
      line.append(':');
      line.append(valuesById.getCount(featureId));
      line.append(' ');
    }
    pw.println(line.toString());
  }
}
/**
 * Returns an iterator that walks through the dataset from the first datum to
 * the last, handing out each one as an {@link RVFDatum}. Removal is not supported.
 *
 * @return a read-only iterator over the datums of this dataset
 */
public Iterator<RVFDatum<L, F>> iterator() {
  return new Iterator<RVFDatum<L,F>>() {
    // position of the datum handed out by the next call to next()
    private int cursor = 0;

    @Override
    public boolean hasNext() {
      return cursor < size();
    }

    @Override
    public RVFDatum<L, F> next() {
      if (!hasNext()) {
        throw new NoSuchElementException();
      }
      // advance before fetching, so the cursor moves on even if the fetch fails
      final int current = cursor;
      cursor = current + 1;
      return getRVFDatum(current);
    }

    @Override
    public void remove() {
      throw new UnsupportedOperationException();
    }
  };
}
/**
 * Counts how many datums carry each label. As a side effect the labels
 * array is trimmed to exactly {@code size} entries.
 *
 * @return a counter from label to the number of datums with that label
 */
public ClassicCounter<L> numDatumsPerLabel(){
  labels = trimToSize(labels);
  final ClassicCounter<L> perLabel = new ClassicCounter<>();
  for (int pos = 0; pos < labels.length; pos++) {
    final L label = labelIndex.get(labels[pos]);
    perLabel.incrementCount(label);
  }
  return perLabel;
}
/**
 * Prints the sparse feature matrix using
 * {@link #printSparseFeatureMatrix(PrintWriter)} to {@link System#out
 * System.out}.
 * The implementation is supplied by the subclass.
 */
public abstract void printSparseFeatureMatrix();
/**
 * Prints a sparse feature matrix representation of the Dataset. Prints the actual
 * {@link Object#toString()} representations of features.
 *
 * @param pw the writer to print the matrix to
 */
public abstract void printSparseFeatureMatrix(PrintWriter pw);
}