package edu.stanford.nlp.classify;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.logging.Redwood;
/**
* An interfacing class for {@link ClassifierFactory} that incrementally
* builds a more memory-efficient representation of a {@link List} of
* {@link Datum} objects for the purposes of training a {@link Classifier}
* with a {@link ClassifierFactory}.
*
* @author Roger Levy (rog@stanford.edu)
* @author Anna Rafferty (various refactoring with GeneralDataset/RVFDataset)
* @author Sarah Spikes (sdspikes@cs.stanford.edu) (templatization)
* @author nmramesh@cs.stanford.edu {@link #getL1NormalizedTFIDFDatum(Datum, Counter)} and {@link #getL1NormalizedTFIDFDataset()}
*
* @param <L> Label type
* @param <F> Feature type
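*
* <p>A typical usage sketch (a factory such as {@code LinearClassifierFactory}
* is assumed; feature and label values are illustrative):
* <pre>{@code
* Dataset<String, String> data = new Dataset<>();
* data.add(Arrays.asList("feat1", "feat2"), "labelA");
* data.add(Arrays.asList("feat2", "feat3"), "labelB");
* Classifier<String, String> c =
*     new LinearClassifierFactory<String, String>().trainClassifier(data);
* }</pre>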
*/
public class Dataset<L, F> extends GeneralDataset<L, F> {
private static final long serialVersionUID = -3883164942879961091L;
static final Redwood.RedwoodChannels logger = Redwood.channels(Dataset.class);
public Dataset() {
this(10);
}
public Dataset(int numDatums) {
initialize(numDatums);
}
public Dataset(int numDatums, Index<F> featureIndex, Index<L> labelIndex) {
initialize(numDatums);
this.featureIndex = featureIndex;
this.labelIndex = labelIndex;
}
public Dataset(Index<F> featureIndex, Index<L> labelIndex) {
this(10, featureIndex, labelIndex);
}
/**
* Constructor that fully specifies a Dataset. Needed for MulticlassDataset.
*/
public Dataset(Index<L> labelIndex, int[] labels, Index<F> featureIndex, int[][] data) {
this (labelIndex, labels, featureIndex, data, data.length);
}
/**
* Constructor that fully specifies a Dataset. Needed for MulticlassDataset.
*/
public Dataset(Index<L> labelIndex, int[] labels, Index<F> featureIndex, int[][] data, int size) {
this.labelIndex = labelIndex;
this.labels = labels;
this.featureIndex = featureIndex;
this.data = data;
this.size = size;
}
/** {@inheritDoc} */
@Override
public Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> split(double percentDev) {
return split(0, (int)(percentDev * size()));
}
/** {@inheritDoc} */
@Override
public Pair<GeneralDataset<L, F>,GeneralDataset<L, F>> split(int start, int end) {
int devSize = end - start;
int trainSize = size() - devSize;
int[][] devData = new int[devSize][];
int[] devLabels = new int[devSize];
int[][] trainData = new int[trainSize][];
int[] trainLabels = new int[trainSize];
synchronized (System.class) {
System.arraycopy(data, start, devData, 0, devSize);
System.arraycopy(labels, start, devLabels, 0, devSize);
System.arraycopy(data, 0, trainData, 0, start);
System.arraycopy(data, end, trainData, start, size() - end);
System.arraycopy(labels, 0, trainLabels, 0, start);
System.arraycopy(labels, end, trainLabels, start, size() - end);
}
if (this instanceof WeightedDataset<?,?>) {
float[] trainWeights = new float[trainSize];
float[] devWeights = new float[devSize];
WeightedDataset<L, F> w = (WeightedDataset<L, F>)this;
synchronized (System.class) {
System.arraycopy(w.weights, start, devWeights, 0, devSize);
System.arraycopy(w.weights, 0, trainWeights, 0, start);
System.arraycopy(w.weights, end, trainWeights, start, size() - end);
}
WeightedDataset<L, F> dev = new WeightedDataset<>(labelIndex, devLabels, featureIndex, devData, devSize, devWeights);
WeightedDataset<L, F> train = new WeightedDataset<>(labelIndex, trainLabels, featureIndex, trainData, trainSize, trainWeights);
return new Pair<>(train, dev);
}
Dataset<L, F> dev = new Dataset<>(labelIndex, devLabels, featureIndex, devData, devSize);
Dataset<L, F> train = new Dataset<>(labelIndex, trainLabels, featureIndex, trainData, trainSize);
return new Pair<>(train, dev);
}
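/**
* Returns a new Dataset containing (int) (p * size()) datums sampled uniformly
* at random, without replacement, from this dataset.
*
* @param p the fraction of datums to keep, in (0, 1]
* @param seed seed for the random number generator, for reproducibility
* @return the sampled sub-dataset
*/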
public Dataset<L, F> getRandomSubDataset(double p, int seed) {
int newSize = (int)(p * size());
Set<Integer> indicesToKeep = Generics.newHashSet();
Random r = new Random(seed);
int s = size();
while (indicesToKeep.size() < newSize) {
indicesToKeep.add(r.nextInt(s));
}
int[][] newData = new int[newSize][];
int[] newLabels = new int[newSize];
int i = 0;
for (int j : indicesToKeep) {
newData[i] = data[j];
newLabels[i] = labels[j];
i++;
}
return new Dataset<>(labelIndex, newLabels, featureIndex, newData);
}
@Override
public double[][] getValuesArray() {
// a plain Dataset has no real-valued features (every feature is implicitly present with value 1), so there is no values array
return null;
}
/**
* Constructs a Dataset by reading in a file in SVM light format.
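* Each line is expected to look like {@code label feature:value feature:value ...},
* optionally followed by a {@code #}-comment; each value is read as an integer
* count of that feature. A usage sketch (the path is hypothetical):
* <pre>{@code
* Dataset<String, String> data = Dataset.readSVMLightFormat("/path/to/train.svmlight");
* data.summaryStatistics();
* }</pre>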
*/
public static Dataset<String, String> readSVMLightFormat(String filename) {
return readSVMLightFormat(filename, new HashIndex<>(), new HashIndex<>());
}
/**
* Constructs a Dataset by reading in a file in SVM light format.
* The lines parameter is filled with the lines of the file for further processing
* (if lines is null, it is assumed no line information is desired)
*/
public static Dataset<String, String> readSVMLightFormat(String filename, List<String> lines) {
return readSVMLightFormat(filename, new HashIndex<>(), new HashIndex<>(), lines);
}
/**
* Constructs a Dataset by reading in a file in SVM light format.
* The created dataset uses the given feature and label indices.
*/
public static Dataset<String, String> readSVMLightFormat(String filename, Index<String> featureIndex, Index<String> labelIndex) {
return readSVMLightFormat(filename, featureIndex, labelIndex, null);
}
/**
* Constructs a Dataset by reading in a file in SVM light format.
* The created dataset uses the given feature and label indices.
*/
public static Dataset<String, String> readSVMLightFormat(String filename, Index<String> featureIndex, Index<String> labelIndex, List<String> lines) {
Dataset<String, String> dataset;
try {
dataset = new Dataset<>(10, featureIndex, labelIndex);
for (String line : ObjectBank.getLineIterator(new File(filename))) {
if (lines != null) {
lines.add(line);
}
dataset.add(svmLightLineToDatum(line));
}
} catch (Exception e) {
throw new RuntimeException(e);
}
return dataset;
}
private static int line1 = 0; // lines read so far; used in error messages from svmLightLineToDatum()
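/**
* Converts one line of SVMlight format into a Datum. For example, the
* (hypothetical) line {@code "+1 2:1 5:3 # a comment"} yields a datum with
* label "+1", feature "2" once, feature "5" three times, and a constant
* class-bias feature.
*/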
public static Datum<String, String> svmLightLineToDatum(String l) {
line1++;
l = l.replaceAll("#.*", ""); // remove any trailing comments
String[] line = l.split("\\s+");
Collection<String> features = new ArrayList<>();
for (int i = 1; i < line.length; i++) {
String[] f = line[i].split(":");
if (f.length != 2) {
logger.info("Dataset error: line " + line1);
continue; // skip the malformed feature:value pair rather than crash on f[1] below
}
int val = (int) Double.parseDouble(f[1]);
for (int j = 0; j < val; j++) {
features.add(f[0]);
}
}
features.add(String.valueOf(Integer.MAX_VALUE)); // a constant feature for a class
Datum<String, String> d = new BasicDatum<>(features, line[0]);
return d;
}
/**
* Counts the number of datums in which each feature appears.
*
* @return a Counter mapping each feature to the number of datums containing it
*/
public Counter<F> getFeatureCounter() {
Counter<F> featureCounts = new ClassicCounter<>();
for (int i = 0; i < this.size(); i++) {
Datum<L, F> datum = getDatum(i);
// use a set so that each feature counts at most once per datum
Set<F> featureSet = Generics.newHashSet(datum.asFeatures());
for (F key : featureSet) {
featureCounts.incrementCount(key, 1.0);
}
}
return featureCounts;
}
/**
* Converts a datum's count-based features into L1-normalized TF-IDF features.
* The IDF used is {@code log((size() + 1) / (docCount + 0.5))}.
*
* @param datum a datum with a collection of features
* @param featureDocCounts a counter of the document count of each feature
* @return an RVFDatum with L1-normalized TF-IDF feature values
*/
public RVFDatum<L, F> getL1NormalizedTFIDFDatum(Datum<L, F> datum, Counter<F> featureDocCounts) {
Counter<F> tfidfFeatures = new ClassicCounter<>();
for (F feature : datum.asFeatures()) {
if (featureDocCounts.containsKey(feature)) {
tfidfFeatures.incrementCount(feature, 1.0);
}
}
double l1norm = 0;
for (F feature : tfidfFeatures.keySet()) {
double idf = Math.log(((double) (this.size() + 1)) / (featureDocCounts.getCount(feature) + 0.5));
double tf = tfidfFeatures.getCount(feature);
tfidfFeatures.setCount(feature, tf * idf);
l1norm += tf * idf;
}
// normalize so the feature values sum to 1
for (F feature : tfidfFeatures.keySet()) {
double tfidf = tfidfFeatures.getCount(feature);
tfidfFeatures.setCount(feature, tfidf / l1norm);
}
RVFDatum<L, F> rvfDatum = new RVFDatum<>(tfidfFeatures, datum.label());
return rvfDatum;
}
/**
* Converts this dataset to an RVFDataset using L1-normalized TF-IDF features.
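* <p>A minimal usage sketch (feature and label values are illustrative):
* <pre>{@code
* Dataset<String, String> data = new Dataset<>();
* data.add(Arrays.asList("the", "cat"), "A");
* data.add(Arrays.asList("the", "dog"), "B");
* RVFDataset<String, String> tfidf = data.getL1NormalizedTFIDFDataset();
* }</pre>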
* @return an RVFDataset with L1-normalized TF-IDF feature values
*/
public RVFDataset<L, F> getL1NormalizedTFIDFDataset() {
RVFDataset<L, F> rvfDataset = new RVFDataset<>(this.size(), this.featureIndex, this.labelIndex);
Counter<F> featureDocCounts = getFeatureCounter();
for (int i = 0; i < this.size(); i++) {
Datum<L, F> datum = this.getDatum(i);
RVFDatum<L, F> rvfDatum = getL1NormalizedTFIDFDatum(datum, featureDocCounts);
rvfDataset.add(rvfDatum);
}
return rvfDataset;
}
@Override
public void add(Datum<L, F> d) {
add(d.asFeatures(), d.label());
}
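/**
* Adds a datum with the given features and label, adding any previously
* unseen features to the feature index.
*/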
public void add(Collection<F> features, L label) {
add(features, label, true);
}
public void add(Collection<F> features, L label, boolean addNewFeatures) {
ensureSize();
addLabel(label);
addFeatures(features, addNewFeatures);
size++;
}
/**
* Adds a datum defined by its feature indices and label index.
* Careful with this one! Make sure that all indices are valid!
* @param features the feature indices of the datum
* @param label the label index of the datum
*/
public void add(int[] features, int label) {
ensureSize();
addLabelIndex(label);
addFeatureIndices(features);
size++;
}
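/**
* Doubles the capacity of the backing labels and data arrays once they are
* full; slots beyond {@code size} stay unused until more datums are added.
*/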
protected void ensureSize() {
if (labels.length == size) {
int[] newLabels = new int[size * 2];
int[][] newData = new int[size * 2][];
synchronized (System.class) {
System.arraycopy(labels, 0, newLabels, 0, size);
System.arraycopy(data, 0, newData, 0, size);
}
labels = newLabels;
data = newData;
}
}
protected void addLabel(L label) {
labelIndex.add(label);
labels[size] = labelIndex.indexOf(label);
}
protected void addLabelIndex(int label) {
labels[size] = label;
}
protected void addFeatures(Collection<F> features) {
addFeatures(features, true);
}
protected void addFeatures(Collection<F> features, boolean addNewFeatures) {
int[] intFeatures = new int[features.size()];
int j = 0;
for (F feature : features) {
if (addNewFeatures) {
featureIndex.add(feature);
}
int index = featureIndex.indexOf(feature);
if (index >= 0) {
intFeatures[j] = index; // reuse the index already looked up
j++;
}
}
data[size] = new int[j];
synchronized (System.class) {
System.arraycopy(intFeatures, 0, data[size], 0, j);
}
}
protected void addFeatureIndices(int [] features) {
data[size] = features;
}
@Override
protected final void initialize(int numDatums) {
labelIndex = new HashIndex<>();
featureIndex = new HashIndex<>();
labels = new int[numDatums];
data = new int[numDatums][];
size = 0;
}
/**
* @return the datum at the given index
*/
@Override
public Datum<L, F> getDatum(int index) {
return new BasicDatum<>(featureIndex.objects(data[index]), labelIndex.get(labels[index]));
}
/**
* @return the datum at the given index, as an RVFDatum with feature counts
*/
@Override
public RVFDatum<L, F> getRVFDatum(int index) {
ClassicCounter<F> c = new ClassicCounter<>();
for (F key : featureIndex.objects(data[index])) {
c.incrementCount(key);
}
return new RVFDatum<>(c, labelIndex.get(labels[index]));
}
/**
* Prints some summary statistics to stderr for the Dataset.
*/
@Override
public void summaryStatistics() {
logger.info(toSummaryStatistics());
}
/** A String that is multiple lines of text giving summary statistics.
* (It does not end with a newline, though.)
*
* @return A textual summary of the Dataset
*/
public String toSummaryStatistics() {
StringBuilder sb = new StringBuilder();
sb.append("numDatums: ").append(size).append('\n');
sb.append("numDatumsPerLabel: ").append(this.numDatumsPerLabel()).append('\n');
sb.append("numLabels: ").append(labelIndex.size()).append(" [");
Iterator<L> iter = labelIndex.iterator();
while (iter.hasNext()) {
sb.append(iter.next());
if (iter.hasNext()) {
sb.append(", ");
}
}
sb.append("]\n");
sb.append("numFeatures (Phi(X) types): ").append(featureIndex.size()).append(" [");
int sz = Math.min(5, featureIndex.size());
for (int i = 0; i < sz; i++) {
if (i > 0) {
sb.append(", ");
}
sb.append(featureIndex.get(i));
}
if (sz < featureIndex.size()) {
sb.append(", ...");
}
sb.append(']');
return sb.toString();
}
/**
* Applies feature count thresholds to the Dataset.
* Each feature is matched against the patterns in order; the first pattern
* that matches determines its fate: the feature is kept only if it occurs
* at least threshold_i times. Features that match no pattern are kept
* unconditionally. See the usage sketch below.
*
* @param thresholds a list of (pattern, threshold) pairs
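*
* <p>A usage sketch (pattern and threshold are illustrative):
* <pre>{@code
* // keep "word-..." features only if they occur at least 5 times;
* // features not matching the pattern are kept unconditionally
* dataset.applyFeatureCountThreshold(
*     Collections.singletonList(new Pair<>(Pattern.compile("word-.*"), 5)));
* }</pre>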
*/
public void applyFeatureCountThreshold(List<Pair<Pattern, Integer>> thresholds) {
// get feature counts
float[] counts = getFeatureCounts();
// build a new featureIndex
Index<F> newFeatureIndex = new HashIndex<>();
LOOP:
for (F f : featureIndex) {
for (Pair<Pattern, Integer> threshold : thresholds) {
Pattern p = threshold.first();
Matcher m = p.matcher(f.toString());
if (m.matches()) {
if (counts[featureIndex.indexOf(f)] >= threshold.second) {
newFeatureIndex.add(f);
}
continue LOOP;
}
}
// we only get here if it didn't match anything on the list
newFeatureIndex.add(f);
}
counts = null;
int[] featMap = new int[featureIndex.size()];
for (int i = 0; i < featMap.length; i++) {
featMap[i] = newFeatureIndex.indexOf(featureIndex.get(i));
}
featureIndex = null;
for (int i = 0; i < size; i++) {
List<Integer> featList = new ArrayList<>(data[i].length);
for (int j = 0; j < data[i].length; j++) {
if (featMap[data[i][j]] >= 0) {
featList.add(featMap[data[i][j]]);
}
}
data[i] = new int[featList.size()];
for (int j = 0; j < data[i].length; j++) {
data[i][j] = featList.get(j);
}
}
featureIndex = newFeatureIndex;
}
/**
* Prints the full feature matrix in tab-delimited form. These can be BIG
* matrices, so be careful!
*/
public void printFullFeatureMatrix(PrintWriter pw) {
String sep = "\t";
for (int i = 0; i < featureIndex.size(); i++) {
pw.print(sep + featureIndex.get(i));
}
pw.println();
for (int i = 0; i < size; i++) { // iterate over datums, not the (possibly larger) backing array
pw.print(labelIndex.get(labels[i]));
Set<Integer> feats = Generics.newHashSet();
for (int j = 0; j < data[i].length; j++) {
int feature = data[i][j];
feats.add(Integer.valueOf(feature));
}
for (int j = 0; j < featureIndex.size(); j++) {
if (feats.contains(Integer.valueOf(j))) {
pw.print(sep + '1');
} else {
pw.print(sep + '0');
}
}
pw.println(); // end the row for this datum
}
}
/** {@inheritDoc} */
@Override
public void printSparseFeatureMatrix() {
printSparseFeatureMatrix(new PrintWriter(System.out, true));
}
/** {@inheritDoc} */
@Override
public void printSparseFeatureMatrix(PrintWriter pw) {
String sep = "\t";
for (int i = 0; i < size; i++) {
pw.print(labelIndex.get(labels[i]));
int[] datum = data[i];
for (int j : datum) {
pw.print(sep + featureIndex.get(j));
}
pw.println();
}
}
public void changeLabelIndex(Index<L> newLabelIndex) {
labels = trimToSize(labels);
for (int i = 0; i < labels.length; i++) {
labels[i] = newLabelIndex.indexOf(labelIndex.get(labels[i]));
}
labelIndex = newLabelIndex;
}
public void changeFeatureIndex(Index<F> newFeatureIndex) {
data = trimToSize(data);
labels = trimToSize(labels);
int[][] newData = new int[data.length][];
for (int i = 0; i < data.length; i++) {
int[] newD = new int[data[i].length];
int k = 0;
for (int j = 0; j < data[i].length; j++) {
int newIndex = newFeatureIndex.indexOf(featureIndex.get(data[i][j]));
if (newIndex >= 0) {
newD[k++] = newIndex;
}
}
newData[i] = new int[k];
synchronized (System.class) {
System.arraycopy(newD, 0, newData[i], 0, k);
}
}
data = newData;
featureIndex = newFeatureIndex;
}
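/**
* Keeps only the numFeatures features with the highest binary information
* gain with respect to the label, as scored by {@link #getInformationGains()},
* and discards all others from the dataset.
*/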
public void selectFeaturesBinaryInformationGain(int numFeatures) {
double[] scores = getInformationGains();
selectFeatures(numFeatures,scores);
}
/**
* Generic method to select features based on the feature scores vector provided as an argument.
* @param numFeatures number of features to be selected.
* @param scores a vector whose length equals the total number of features in the data.
*/
public void selectFeatures(int numFeatures, double[] scores) {
List<ScoredObject<F>> scoredFeatures = new ArrayList<>();
for (int i = 0; i < scores.length; i++) {
scoredFeatures.add(new ScoredObject<>(featureIndex.get(i), scores[i]));
}
Collections.sort(scoredFeatures, ScoredComparator.DESCENDING_COMPARATOR);
Index<F> newFeatureIndex = new HashIndex<>();
for (int i = 0; i < scoredFeatures.size() && i < numFeatures; i++) {
newFeatureIndex.add(scoredFeatures.get(i).object());
//logger.info(scoredFeatures.get(i));
}
for (int i = 0; i < size; i++) {
int[] newData = new int[data[i].length];
int curIndex = 0;
for (int j = 0; j < data[i].length; j++) {
int index;
if ((index = newFeatureIndex.indexOf(featureIndex.get(data[i][j]))) != -1) {
newData[curIndex++] = index;
}
}
int[] newDataTrimmed = new int[curIndex];
synchronized (System.class) {
System.arraycopy(newData, 0, newDataTrimmed, 0, curIndex);
}
data[i] = newDataTrimmed;
}
featureIndex = newFeatureIndex;
}
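/**
* Computes, for each feature X, the information gain of its binary
* presence/absence with respect to the label Y:
* IG(X) = H(Y) - P(X present) H(Y | X present) - P(X absent) H(Y | X absent).
*
* @return an array of information gains, indexed parallel to featureIndex
*/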
public double[] getInformationGains() {
// assert size > 0;
// data = trimToSize(data); // Don't need to trim to size, and trimming is dangerous if the dataset is empty (you can't add to it thereafter)
labels = trimToSize(labels);
// counts the number of times word X is present
ClassicCounter<F> featureCounter = new ClassicCounter<>();
// counts the number of times a document has label Y
ClassicCounter<L> labelCounter = new ClassicCounter<>();
// counts the number of times the document has label Y given word X is present
TwoDimensionalCounter<F,L> condCounter = new TwoDimensionalCounter<>();
for (int i = 0; i < labels.length; i++) {
labelCounter.incrementCount(labelIndex.get(labels[i]));
// convert the document to binary feature representation
boolean[] doc = new boolean[featureIndex.size()];
//logger.info(i);
for (int j = 0; j < data[i].length; j++) {
doc[data[i][j]] = true;
}
for (int j = 0; j < doc.length; j++) {
if (doc[j]) {
featureCounter.incrementCount(featureIndex.get(j));
condCounter.incrementCount(featureIndex.get(j), labelIndex.get(labels[i]), 1.0);
}
}
}
double entropy = 0.0;
for (int i = 0; i < labelIndex.size(); i++) {
double labelCount = labelCounter.getCount(labelIndex.get(i));
double p = labelCount / size();
entropy -= p * (Math.log(p) / Math.log(2));
}
double[] ig = new double[featureIndex.size()];
Arrays.fill(ig, entropy);
for (int i = 0; i < featureIndex.size(); i++) {
F feature = featureIndex.get(i);
double featureCount = featureCounter.getCount(feature);
double notFeatureCount = size() - featureCount;
double pFeature = featureCount / size();
double pNotFeature = (1.0 - pFeature);
if (featureCount == 0) { ig[i] = 0; continue; }
if (notFeatureCount == 0) { ig[i] = 0; continue; }
double sumFeature = 0.0;
double sumNotFeature = 0.0;
for (int j = 0; j < labelIndex.size(); j++) {
L label = labelIndex.get(j);
double featureLabelCount = condCounter.getCount(feature, label);
double notFeatureLabelCount = labelCounter.getCount(label) - featureLabelCount; // docs with this label that lack the feature
// yes, these dont sum to 1. that is correct.
// one is the prob of the label, given that the
// feature is present, and the other is the prob
// of the label given that the feature is absent
double p = featureLabelCount / featureCount;
double pNot = notFeatureLabelCount / notFeatureCount;
if (featureLabelCount != 0) {
sumFeature += p * (Math.log(p) / Math.log(2));
}
if (notFeatureLabelCount != 0) {
sumNotFeature += pNot * (Math.log(pNot) / Math.log(2));
}
//System.out.println(pNot+" "+(Math.log(pNot)/Math.log(2)));
}
//logger.info(pFeature+" * "+sumFeature+" = +"+);
//logger.info("^ "+pNotFeature+" "+sumNotFeature);
ig[i] += pFeature*sumFeature + pNotFeature*sumNotFeature;
/* earlier the line above used to be: ig[i] = pFeature*sumFeature + pNotFeature*sumNotFeature;
* This completely ignored the entropy term computed above. So added the "+=" to take that into account.
* -Ramesh (nmramesh@cs.stanford.edu)
*/
}
return ig;
}
public void updateLabels(int[] labels) {
if (labels.length != size())
throw new IllegalArgumentException(
"size of labels array does not match dataset size");
this.labels = labels;
}
@Override
public String toString() {
return "Dataset of size " + size;
}
public String toSummaryString() {
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
pw.println("Number of data points: " + size());
pw.println("Number of active feature tokens: " + numFeatureTokens());
pw.println("Number of active feature types:" + numFeatureTypes());
return pw.toString();
}
/**
* Sorts the counter by feature keys and writes the datum in SVMlight format,
* with feature indices shifted to be 1-based as SVMlight expects.
*/
public static void printSVMLightFormat(PrintWriter pw, ClassicCounter<Integer> c, int classNo) {
Integer[] features = c.keySet().toArray(new Integer[c.keySet().size()]);
Arrays.sort(features);
StringBuilder sb = new StringBuilder();
sb.append(classNo);
sb.append(' ');
for (int f: features) {
sb.append(f + 1).append(':').append(c.getCount(f)).append(' '); // 1-based feature indices
}
pw.println(sb.toString());
}
}