package edu.stanford.nlp.classify;
import java.util.Arrays;
import edu.stanford.nlp.stats.Counter;
import junit.framework.TestCase;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.Datum;
/** @author Christopher Manning */
public class DatasetTest extends TestCase {
public static void testDataset() {
Dataset<String, String> data = new Dataset<String, String>();
data.add(new BasicDatum<String, String>(Arrays.asList(new String[]{"fever", "cough", "congestion"}), "cold"));
data.add(new BasicDatum<String, String>(Arrays.asList(new String[]{"fever", "cough", "nausea"}), "flu"));
data.add(new BasicDatum<String, String>(Arrays.asList(new String[]{"cough", "congestion"}), "cold"));
// data.summaryStatistics();
assertEquals(4, data.numFeatures());
assertEquals(4, data.numFeatureTypes());
assertEquals(2, data.numClasses());
assertEquals(8, data.numFeatureTokens());
assertEquals(3, data.size());
data.applyFeatureCountThreshold(2);
assertEquals(3, data.numFeatures());
assertEquals(3, data.numFeatureTypes());
assertEquals(2, data.numClasses());
assertEquals(7, data.numFeatureTokens());
assertEquals(3, data.size());
//Dataset data = Dataset.readSVMLightFormat(args[0]);
//double[] scores = data.getInformationGains();
//System.out.println(ArrayMath.mean(scores));
//System.out.println(ArrayMath.variance(scores));
LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<String, String>();
LinearClassifier<String, String> classifier = factory.trainClassifier(data);
Datum<String, String> d = new BasicDatum<String, String>(Arrays.asList(new String[]{"cough", "fever"}));
assertEquals("Classification incorrect", "flu", classifier.classOf(d));
Counter<String> probs = classifier.probabilityOf(d);
assertEquals("Returned probability incorrect", 0.4553, probs.getCount("cold"), 0.0001);
assertEquals("Returned probability incorrect", 0.5447, probs.getCount("flu"), 0.0001);
System.out.println();
}
}