package hex; import hex.FrameTask.DataInfo; import hex.glm.GLMModel; import hex.gram.Gram.GramTask; import hex.la.DMatrix; import org.junit.Test; import water.*; import water.fvec.*; import water.fvec.Frame; import water.fvec.NFSFileVec; import water.fvec.ParseDataset2; import water.fvec.RebalanceDataSet; import water.util.Utils; import java.io.File; import java.util.Arrays; import static junit.framework.Assert.assertEquals; import static junit.framework.Assert.assertTrue; /** * Created by tomasnykodym on 11/14/14. */ public class MatrixTest extends TestUtil { private static Frame getFrameForFile(Key outputKey, String path){ File f = TestUtil.find_test_file(path); Key k = NFSFileVec.make(f); Frame fr = ParseDataset2.parse(outputKey, new Key[]{k}); return fr; } @Test public void testTranspose(){ Futures fs = new Futures(); Key parsed = Key.make("prostate_parsed"); Key modelKey = Key.make("prostate_model"); GLMModel model = null; File f = TestUtil.find_test_file("smalldata/glm_test/prostate_cat_replaced.csv"); Frame fr = getFrameForFile(parsed, "smalldata/glm_test/prostate_cat_replaced.csv"); fr.remove("RACE").remove(fs); Key k = Key.make("rebalanced"); H2O.submitTask(new RebalanceDataSet(fr, k, 64)).join(); fr.delete(); fr = DKV.get(k).get(); Frame tr = DMatrix.transpose(fr); tr.reloadVecs(); for(int i = 0; i < fr.numRows(); ++i) for(int j = 0; j < fr.numCols(); ++j) assertEquals(fr.vec(j).at(i),tr.vec(i).at(j),1e-4); fr.delete(); for(Vec v:tr.vecs()) v.remove(fs); fs.blockForPending(); // checkLeakedKeys(); } @Test // simple small & dense, compare t(X) %*% X against gram computed by glm task. public void testMultiplication(){ Key parsed = Key.make("prostate_parsed"); Futures fs = new Futures(); Frame fr = getFrameForFile(parsed, "smalldata/glm_test/prostate_cat_replaced.csv"); fr.remove("RACE").remove(fs); Key k = Key.make("rebalanced"); H2O.submitTask(new RebalanceDataSet(fr, k, 64)).join(); fr.delete(); fr = DKV.get(k).get(); Frame tr = DMatrix.transpose(fr); tr.reloadVecs(); Frame z = DMatrix.mmul(tr,fr); DataInfo dinfo = new DataInfo(fr, 0, false, false, DataInfo.TransformType.NONE); GramTask gt = new GramTask(null, dinfo, false,false).doAll(dinfo._adaptedFrame); gt._gram.mul(gt._nobs); double [][] gram = gt._gram.getDenseXX(); for(int i = 0; i < gram.length; ++i) for(int j = 0; j < gram[i].length; ++j) assertEquals("position " + i + ", " + j, gram[i][j], z.vec(j).at(i),1e-4); fr.delete(); for(Vec v:tr.vecs()) v.remove(fs); for(Vec v:z.vecs()) v.remove(fs); // for(Vec v:z2.vecs()) // v.remove(fs); fs.blockForPending(); checkLeakedKeys(); } // @Test // bigger & sparse, compare X2 <- H2 %*% M2 against R // public void testMultiplicationSparse() { // Futures fs = new Futures(); // Key xParsed = Key.make("xParsed"), hParsed = Key.make("hParsed"), mParsed = Key.make("mParsed"); // Frame C = getFrameForFile(xParsed, "smalldata/sparse_matrices/C.svm"); // C.remove(0).remove(fs); // Frame A = getFrameForFile(hParsed, "smalldata/sparse_matrices/A.svm"); // A.remove(0).remove(fs); // Frame B = getFrameForFile(mParsed, "smalldata/sparse_matrices/B.svm"); // B.remove(0).remove(fs); // Frame C2 = DMatrix.mmul(A,B); // for(int i = 0; i < C.numRows(); ++i) // for(int j = 0; j < C.numCols(); ++j) // we match only up to 1e-3? // assertEquals("@ " + i + ", " + j + " " + C.vec(j).at(i) + " != " + C2.vec(j).at(i), C.vec(j).at(i),C2.vec(j).at(i),1e-3); // C.delete(); // A.delete(); // B.delete(); // for(Vec v:C2.vecs()) // v.remove(fs); // fs.blockForPending(); // checkLeakedKeys(); // } @Test public void testTransposeSparse(){ Key parsed = Key.make("arcene_parsed"); GLMModel model = null; String[] data = new String[] { "1 2:.2 5:.5 9:.9\n-1 1:.1 4:.4 8:.8\n", "1 2:.2 5:.5 9:.9\n1 3:.3 6:.6\n", "-1 7:.7 8:.8 9:.9\n1 20:2.\n", "+1 1:.1 5:.5 6:.6 10:1\n1 19:1.9\n", "1 2:.2 5:.5 9:.9\n-1 1:.1 4:.4 8:.8\n", "1 2:.2 5:.5 9:.9\n1 3:.3 6:.6\n", "-1 7:.7 8:.8 9:.9\n1 20:2.\n", "+1 1:.1 5:.5 6:.6 10:1\n1 19:1.9\n", "1 2:.2 5:.5 9:.9\n-1 1:.1 4:.4 8:.8\n", "1 2:.2 5:.5 9:.9\n1 3:.3 6:.6\n", "-1 7:.7 8:.8 9:.9\n1 20:2.\n", "+1 1:.1 5:.5 6:.6 10:1\n1 19:1.9\n" }; Key k = FVecTest.makeByteVec(Key.make("svmtest_bits").toString(),data); Frame fr = ParseDataset2.parse(parsed, new Key[]{k}); Frame tr = DMatrix.transpose(fr); tr.reloadVecs(); for(int i = 0; i < fr.numRows(); ++i) for(int j = 0; j < fr.numCols(); ++j) assertEquals("at " + i + ", " + j + ":",fr.vec(j).at(i),tr.vec(i).at(j),1e-4); fr.delete(); Futures fs = new Futures(); for(Vec v:tr.vecs()) v.remove(fs); fs.blockForPending(); // checkLeakedKeys(); } }