package edu.stanford.nlp.classify; import edu.stanford.nlp.ling.BasicDatum; import edu.stanford.nlp.ling.Datum; import edu.stanford.nlp.ling.RVFDatum; import edu.stanford.nlp.stats.ClassicCounter; import junit.framework.Assert; import org.junit.Test; /** * Created by sonalg on 11/24/14. */ public class ShiftParamsLogisticClassifierITest { private static <L, F> BasicDatum<L, F> newDatum(L label, F[] features, Double[] counts) { ClassicCounter<F> counter = new ClassicCounter<F>(); for (int i = 0; i < features.length; i++) { counter.setCount(features[i], counts[i]); } return new BasicDatum<L, F>(counter.keySet(), label); } private static void testStrBinaryDatums(double d1f1, double d1f2, double d2f1, double d2f2) throws Exception { Dataset<String, String> trainData = new Dataset<String, String>(); Datum<String, String> d1 = newDatum("alpha", new String[]{"f1", "f2"}, new Double[]{d1f1, d1f2}); Datum<String, String> d2 = newDatum("beta", new String[]{"f1", "f2"}, new Double[]{d2f1, d2f2}); trainData.add(d1); trainData.add(d2); LogPrior prior = new LogPrior(LogPrior.LogPriorType.QUADRATIC, 1.0, 0.1); ShiftParamsLogisticClassifierFactory<String, String> lfc = new ShiftParamsLogisticClassifierFactory<String, String>(prior, 0.01); MultinomialLogisticClassifier<String, String> lc = lfc.trainClassifier(trainData); // Try the obvious (should get train data with 100% acc) Assert.assertEquals(d1.label(), lc.classOf(d1)); Assert.assertEquals(d2.label(), lc.classOf(d2)); } @Test public void testStrBinaryDatums() throws Exception { // testStrBinaryDatums(1.0, 0.0, 1.0, 0.0); // testStrBinaryDatums(1.0, 0.0, -1.0, 0.0); // testStrBinaryDatums(0.0, 1.0, 0.0, -1.0); // testStrBinaryDatums(0.0, -1.0, 0.0, 1.0); // testStrBinaryDatums(1.0, 1.0, -1.0, -1.0); // testStrBinaryDatums(0.0, 1.0, 1.0, 0.0); // testStrBinaryDatums(1.0, 0.0, 0.0, 1.0); } }