package hex.optimization; import hex.DataInfo; import hex.glm.GLM.GLMGradientSolver; import hex.glm.GLMModel.GLMParameters; import hex.glm.GLMModel.GLMParameters.Family; import hex.glm.GLMModel.GLMWeightsFun; import hex.optimization.OptimizationUtils.GradientInfo; import hex.optimization.OptimizationUtils.GradientSolver; import org.junit.BeforeClass; import org.junit.Test; import water.*; import water.fvec.Frame; import water.util.ArrayUtils; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** * Created by tomasnykodym on 9/16/14. */ public class L_BFGS_Test extends TestUtil { @BeforeClass public static void setup() { stall_till_cloudsize(1); try { Thread.sleep(100); } catch (InterruptedException e) { e.printStackTrace(); } } // test on Rosenbrock's function (known optimum at (a,a^2), minimum = 0) @Test public void rosenbrock() { final double a = 1, b = 100; GradientSolver gs = new GradientSolver() { @Override public GradientInfo getGradient(double[] beta) { final double[] g = new double[2]; final double x = beta[0]; final double y = beta[1]; final double xx = x * x; g[0] = -2 * a + 2 * x - 4 * b * (y * x - x * xx); g[1] = 2 * b * (y - xx); double objVal = (a - x) * (a - x) + b * (y - xx) * (y - xx); return new GradientInfo(objVal, g); } @Override public GradientInfo getObjective(double[] beta) { return getGradient(beta); } }; L_BFGS lbfgs = new L_BFGS().setGradEps(1e-12); L_BFGS.Result r = lbfgs.solve(gs, L_BFGS.startCoefs(2, 987654321)); assertTrue("LBFGS failed to solve Rosenbrock function optimization",r.ginfo._objVal < 1e-4); } @Test public void logistic() { Key parsedKey = Key.make("prostate"); DataInfo dinfo = null; try { GLMParameters glmp = new GLMParameters(Family.binomial, Family.binomial.defaultLink); glmp._alpha = new double[]{0}; glmp._lambda = new double[]{1e-5}; Frame source = parse_test_file(parsedKey, "smalldata/glm_test/prostate_cat_replaced.csv"); source.add("CAPSULE", source.remove("CAPSULE")); source.remove("ID").remove(); Frame valid = new Frame(source._names.clone(),source.vecs().clone()); dinfo = new DataInfo(source, valid, 1, false, DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, /* weights */ false, /* offset */ false, /* fold */ false); DKV.put(dinfo._key,dinfo); glmp._obj_reg = 1/380.0; GLMGradientSolver solver = new GLMGradientSolver(null,glmp, dinfo, 1e-5, null); L_BFGS lbfgs = new L_BFGS().setGradEps(1e-8); double [] beta = MemoryManager.malloc8d(dinfo.fullN()+1); beta[beta.length-1] = new GLMWeightsFun(glmp).link(source.vec("CAPSULE").mean()); L_BFGS.Result r = lbfgs.solve(solver, beta, solver.getGradient(beta),new L_BFGS.ProgressMonitor(){ int _i = 0; public boolean progress(double [] beta, GradientInfo ginfo){ System.out.println(++_i +":" + ginfo._objVal + ", " + ArrayUtils.l2norm2(ginfo._gradient,false)); return true; } }); assertEquals(378.34, 2 * r.ginfo._objVal * source.numRows(), 1e-1); } finally { if(dinfo != null) DKV.remove(dinfo._key); Value v = DKV.get(parsedKey); if (v != null) { v.<Frame>get().delete(); } } } // Test LSM on arcene - wide dataset with ~10k columns // test warm start and max #iteratoions @Test public void testArcene() { Key parsedKey = Key.make("arcene_parsed"); DataInfo dinfo = null; try { Frame source = parse_test_file(parsedKey, "smalldata/glm_test/arcene.csv"); Frame valid = new Frame(source._names.clone(),source.vecs().clone()); GLMParameters glmp = new GLMParameters(Family.gaussian); glmp._lambda = new double[]{1e-5}; glmp._alpha = new double[]{0}; glmp._obj_reg = 0.01; dinfo = new DataInfo(source, valid, 1, false, DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, /* weights */ false, /* offset */ false, /* fold */ false); DKV.put(dinfo._key,dinfo); GradientSolver solver = new GLMGradientSolver(null,glmp, dinfo, 1e-5, null); L_BFGS lbfgs = new L_BFGS().setMaxIter(20); double [] beta = MemoryManager.malloc8d(dinfo.fullN()+1); beta[beta.length-1] = new GLMWeightsFun(glmp).link(source.lastVec().mean()); L_BFGS.Result r1 = lbfgs.solve(solver, beta.clone(), solver.getGradient(beta),new L_BFGS.ProgressMonitor(){ int _i = 0; public boolean progress(double [] beta, GradientInfo ginfo){ System.out.println(++_i +":" + ginfo._objVal); return true; } }); lbfgs.setMaxIter(50); final int iter = r1.iter; L_BFGS.Result r2 = lbfgs.solve(solver, r1.coefs, r1.ginfo, new L_BFGS.ProgressMonitor(){ int _i = 0; public boolean progress(double [] beta, GradientInfo ginfo){ System.out.println(iter + " + " + ++_i +":" + ginfo._objVal); return true; } }); System.out.println(); lbfgs = new L_BFGS().setMaxIter(100); L_BFGS.Result r3 = lbfgs.solve(solver, beta.clone(), solver.getGradient(beta),new L_BFGS.ProgressMonitor(){ int _i = 0; public boolean progress(double [] beta, GradientInfo ginfo){ System.out.println(++_i +":" + ginfo._objVal + ", " + ArrayUtils.l2norm2(ginfo._gradient,false)); return true; } }); assertEquals(r1.iter,20); // assertEquals (r1.iter + r2.iter,r3.iter); // should be equal? got mismatch by 2 assertEquals(r2.ginfo._objVal,r3.ginfo._objVal,1e-8); assertEquals( .5 * glmp._lambda[0] * ArrayUtils.l2norm(r3.coefs,true) + r3.ginfo._objVal, 1e-4, 5e-4); assertTrue("iter# expected < 100, got " + r3.iter, r3.iter < 100); } finally { if(dinfo != null) DKV.remove(dinfo._key); Value v = DKV.get(parsedKey); if (v != null) { v.<Frame>get().delete(); } } } }