package org.wikibrain.matrix;

import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import org.junit.Ignore;
import org.junit.Test;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Random;

import static junit.framework.Assert.assertTrue;
import static org.junit.Assert.assertEquals;

/**
 * Unit tests for {@link SparseMatrixRow}: building a row from a map or from
 * parallel key/value arrays and reading its columns back by position.
 */
public class TestSparseMatrixRow {

    /** Row index given to every row built by these tests. */
    private static final int ROW_INDEX = 34;

    /** Fixture column ids, listed in ascending order. */
    private final int[] keys = new int[] { 3, 9, 11, 26, 54 };

    /** Fixture column values, parallel to {@link #keys}. */
    private final float[] vals = new float[] {1.0f, 0.7f, 2.0f, 0.1f, -0.1f};

    /**
     * Builds the fixture row and checks that the row index, the column count,
     * and every (column id, value) pair read back as expected.
     */
    @Test
    public void testWrite() {
        MatrixRow row = createRow();
        assertEquals(ROW_INDEX, row.getRowIndex());
        assertEquals(keys.length, row.getNumCols());
        for (int i = 0; i < keys.length; i++) {
            int actualKey = row.getColIndex(i);
            float actualVal = row.getColValue(i);

            // Clamp the expected value into [MIN_SCORE, MAX_SCORE]; the test
            // expects the row to hand back values limited to that range.
            float expected = vals[i];
            expected = Math.min(expected, SparseMatrixRow.MAX_SCORE);
            expected = Math.max(expected, SparseMatrixRow.MIN_SCORE);

            assertEquals(keys[i], actualKey);
            assertEquals(expected, actualVal, 0.0001);
        }
    }

    /**
     * Builds rows from randomly generated sparse vectors and checks that the
     * columns come back in ascending column-id order with values that match
     * the source vector (within a tolerance of 0.01).
     *
     * <p>The random seed is chosen fresh for each run and included in every
     * assertion message so that a failing run can be reproduced.
     */
    @Test
    public void testSorting() {
        int maxColumns = 100000;
        long seed = System.nanoTime();
        String seedNote = "seed=" + seed;
        Random random = new Random(seed);

        for (int i = 0; i < 100; i++) {
            TIntFloatMap vector = new TIntFloatHashMap();
            int n = random.nextInt(maxColumns);
            for (int c = 0; c < n; c++) {
                // Duplicate random keys overwrite each other, so the vector
                // may end up with fewer than n entries.
                vector.put(random.nextInt(), random.nextFloat());
            }

            int[] colIds = vector.keys();
            float[] colVals = vector.values();
            SparseMatrixRow row = new SparseMatrixRow(new ValueConf(), ROW_INDEX, colIds, colVals);
            assertEquals(seedNote, colIds.length, row.getNumCols());

            // Sort only after the row has been built: from here on colIds is
            // the expected column order, no longer parallel to colVals.
            Arrays.sort(colIds);
            for (int j = 0; j < row.getNumCols(); j++) {
                int actualKey = row.getColIndex(j);
                float actualVal = row.getColValue(j);
                assertEquals(seedNote, colIds[j], actualKey);
                assertTrue(seedNote, vector.containsKey(actualKey));
                assertEquals(seedNote, vector.get(actualKey), actualVal, 0.01);
            }
        }
    }

    /**
     * Creates the fixture row from {@link #keys} and {@link #vals}.
     *
     * @return a {@link SparseMatrixRow} with row index {@link #ROW_INDEX},
     *         built from an insertion-ordered map of the fixture entries
     */
    public MatrixRow createRow() {
        LinkedHashMap<Integer, Float> m = new LinkedHashMap<Integer, Float>();
        assertEquals(keys.length, vals.length);
        for (int i = 0; i < keys.length; i++) {
            m.put(keys[i], vals[i]);
        }
        return new SparseMatrixRow(new ValueConf(), ROW_INDEX, m);
    }
}