package org.shanbo.feluca.classification.lr; import java.io.FileReader; import java.io.IOException; import java.util.Properties; import org.shanbo.feluca.classification.common.Evaluator; import org.shanbo.feluca.data2.DataEntry; import org.shanbo.feluca.paddle.common.Utilities; public class TestSGDL2LR { public static void testTrain(String model) throws Exception{ AbstractSGDLogisticRegression lr = new SGDL2LR(); Properties p = new Properties(); p.setProperty("alpha", "0.1"); p.setProperty("lambda", "0.1"); p.setProperty("loops", "15"); p.setProperty("-w1", "2"); p.setProperty("w0type", "0"); lr.setProperties(p); lr.loadData(DataEntry.shuffledDataEntry("/home/lgn/data/avazutrain33")); // lr.crossValidation(4, new Evaluator.BinaryAccuracy()); System.out.println(lr.toString()); lr.train(); lr.saveModel(model); System.out.println(lr.toString()); } public static void testTest(String model,String predict) throws Exception{ SGDL2LR lr = new SGDL2LR(); Properties p = new Properties(); p.load(new FileReader("/home/lgn/data/avazutrain33/avazutrain33.sta")); lr.loadModel(model, p); DataEntry testSet = DataEntry.createDataEntry("/home/lgn/data/avazutest33", false); lr.predict(testSet, predict, new Evaluator.BinaryAccuracy()); } public static void main(String[] args) throws Exception { String model = "/home/lgn/kaggle/avazu.model"; String predict = "/home/lgn/kaggle/avazu.predict"; testTrain(model); System.out.println("==============================="); testTest(model, predict); } }