package edu.stanford.nlp.classify;
import java.util.Arrays;
import junit.framework.TestCase;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.util.Pair;
/**
 * Tests the two {@code split} overloads of {@link GeneralDataset} on a small
 * hand-built {@link Dataset} of symptom features labelled "cold" or "flu".
 *
 * @author Christopher Manning
 */
public class GeneralDatasetTest extends TestCase {

  /**
   * Checks the sizes (and, where stated, the labels) of the two datasets
   * returned by {@code split(int, int)} and {@code split(double)}.
   *
   * In every case below the assertions expect the first element of the
   * returned pair to be the larger remainder and the second element to be
   * the split-off portion.
   */
  public static void testCreateFolds() {
    GeneralDataset<String, String> data = new Dataset<String, String>();
    data.add(new BasicDatum<String, String>(Arrays.asList("fever", "cough", "congestion"), "cold"));
    data.add(new BasicDatum<String, String>(Arrays.asList("fever", "cough", "nausea"), "flu"));
    data.add(new BasicDatum<String, String>(Arrays.asList("cough", "congestion"), "cold"));
    data.add(new BasicDatum<String, String>(Arrays.asList("cough", "congestion"), "cold"));
    data.add(new BasicDatum<String, String>(Arrays.asList("fever", "nausea"), "flu"));
    data.add(new BasicDatum<String, String>(Arrays.asList("cough", "sore throat"), "cold"));

    // Split on an index range in the middle of the six items.
    Pair<GeneralDataset<String,String>,GeneralDataset<String,String>> devTrainTest =
        data.split(3, 5);
    assertEquals(4, devTrainTest.first().size());
    assertEquals(2, devTrainTest.second().size());
    assertEquals("cold", devTrainTest.first().getDatum(devTrainTest.first().size() - 1).label());
    assertEquals("flu", devTrainTest.second().getDatum(devTrainTest.second().size() - 1).label());

    // Split on an index range at the start of the dataset.
    Pair<GeneralDataset<String,String>,GeneralDataset<String,String>> devTrainTest2 =
        data.split(0, 2);
    assertEquals(4, devTrainTest2.first().size());
    assertEquals(2, devTrainTest2.second().size());

    // A fractional split of one third of six items is expected to agree with split(0, 2).
    Pair<GeneralDataset<String,String>,GeneralDataset<String,String>> devTrainTest3 =
        data.split(1.0 / 3.0);
    assertEquals(devTrainTest2.first().size(), devTrainTest3.first().size());
    assertEquals(devTrainTest2.first().labelIndex(), devTrainTest3.first().labelIndex());
    assertEquals(devTrainTest2.second().size(), devTrainTest3.second().size());
    // Compare the label arrays ACROSS the two splits. The previous version
    // compared each devTrainTest2 array with itself, which could never fail.
    assertTrue("labels of first() differ between split(0, 2) and split(1.0/3.0)",
        Arrays.equals(devTrainTest2.first().labels, devTrainTest3.first().labels));
    assertTrue("labels of second() differ between split(0, 2) and split(1.0/3.0)",
        Arrays.equals(devTrainTest2.second().labels, devTrainTest3.second().labels));

    // With a seventh item, one third no longer divides evenly; the expected
    // sizes below correspond to the fractional count being rounded down (7/3 -> 2).
    data.add(new BasicDatum<String, String>(Arrays.asList("fever", "nausea"), "flu"));
    Pair<GeneralDataset<String,String>,GeneralDataset<String,String>> devTrainTest4 =
        data.split(1.0 / 3.0);
    assertEquals(5, devTrainTest4.first().size());
    assertEquals(2, devTrainTest4.second().size());

    // One eighth of seven items is less than one item, so the second dataset is expected to be empty.
    Pair<GeneralDataset<String,String>,GeneralDataset<String,String>> devTrainTest5 =
        data.split(1.0 / 8.0);
    assertEquals(7, devTrainTest5.first().size());
    assertEquals(0, devTrainTest5.second().size());

    // Sonal did this, but I think she got it wrong and either should have passed in test ratio or have taken p.second()
    // double trainRatio = 0.9;
    // Pair<GeneralDataset<String,String>,GeneralDataset<String,String>> p = data.split(0, (int) Math.floor(data.size() * trainRatio));
    // assertEquals(6, p.first().size());
  }

}