package org.numenta.nupic.util;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.util.Arrays;
import org.junit.Test;
import gnu.trove.list.array.TIntArrayList;
/**
 * Tests for the behaviour that {@link AbstractSparseBinaryMatrix} offers to
 * its subclasses, exercised through a minimal anonymous subclass created by
 * {@link #getTestMatrix()}.
 */
public class AbstractSparseBinaryMatrixTest {

    /**
     * Creates a fresh test double with dimensions {@code {2, 2, 2}}.
     * <p>
     * The double stores only two values: one for flat index {@code 0}
     * (coordinates {@code [0, 0, 0]}) and one shared by every other
     * index / coordinate. Operations the tests do not need are no-ops.
     *
     * @return a new, independent matrix instance
     */
    public AbstractSparseBinaryMatrix getTestMatrix() {
        return new AbstractSparseBinaryMatrix(new int[] { 2, 2, 2 }) {
            private static final long serialVersionUID = 1L;

            /** Value reported for flat index 0. */
            private int firstCell;

            /** Value reported for every flat index other than 0. */
            private int otherCell;

            /** Not needed by these tests; stores nothing and returns {@code null}. */
            @Override
            public AbstractFlatMatrix<Integer> set(int index, Object value) {
                return null;
            }

            /** Stores the value directly, without touching the true counts. */
            @Override
            public AbstractSparseBinaryMatrix setForTest(int index, int value) {
                if (index == 0) {
                    firstCell = value;
                } else {
                    otherCell = value;
                }
                return this;
            }

            /**
             * Stores the value and, when the value is 1, sets the true count
             * of row 0 (for coordinates [0, 0, 0]) or row 1 (anything else) to 1.
             */
            @Override
            public AbstractSparseBinaryMatrix set(int value, int... coordinates) {
                boolean atOrigin = "[0, 0, 0]".equals(Arrays.toString(coordinates));
                if (atOrigin) {
                    firstCell = value;
                } else {
                    otherCell = value;
                }
                if (value == 1) {
                    setTrueCount(atOrigin ? 0 : 1, 1);
                }
                return this;
            }

            /** Not needed by these tests; intentionally does nothing. */
            @Override
            public void rightVecSumAtNZ(int[] inputVector, int[] results, double stimulusThreshold) {
            }

            /** Not needed by these tests; intentionally does nothing. */
            @Override
            public void rightVecSumAtNZ(int[] inputVector, int[] results) {
            }

            /** Not needed by these tests; always returns {@code null}. */
            @Override
            public Object getSlice(int... coordinates) {
                return null;
            }

            /** Returns the first cell for index 0 and the shared cell otherwise. */
            @Override
            public Integer get(int index) {
                if (index == 0) {
                    return firstCell;
                }
                return otherCell;
            }
        };
    }

    /**
     * Passing one coordinate per dimension must be rejected with the
     * documented message; a single leading coordinate yields indexes 0..3.
     */
    @Test
    public void testGetSliceIndexes() {
        AbstractSparseBinaryMatrix matrix = getTestMatrix();

        try {
            matrix.getSliceIndexes(new int[] { 1, 1, 1 });
            fail();
        } catch (Exception expected) {
            assertEquals(
                "This method only returns the array holding the specified maximum index: [2, 2, 2]",
                expected.getMessage());
        }

        assertArrayEquals(new int[] { 0, 1, 2, 3 }, matrix.getSliceIndexes(new int[] { 0 }));
    }

    /**
     * Bulk {@code set(indexes, values, isTest)} must make the values readable
     * through {@code get}, both with the test flag off and with it on.
     */
    @Test
    public void setIndexes() {
        AbstractSparseBinaryMatrix regular = getTestMatrix();
        regular.set(new int[] { 0, 1 }, new int[] { 33, 44 }, false);
        assertEquals(33, (int) regular.get(0));
        assertEquals(44, (int) regular.get(1));

        AbstractSparseBinaryMatrix flagged = getTestMatrix();
        flagged.set(new int[] { 0, 1 }, new int[] { 1, 1 }, true);
        assertEquals(1, (int) flagged.get(0));
        assertEquals(1, (int) flagged.get(1));
    }

    /** {@code clearStatistics(0)} must reset the true count of row 0 to zero. */
    @Test
    public void clearIndexes() {
        AbstractSparseBinaryMatrix matrix = getTestMatrix();
        matrix.set(new int[] { 0, 1 }, new int[] { 1, 1 }, false);
        assertEquals(1, matrix.getTrueCount(0));

        matrix.clearStatistics(0);
        assertEquals(0, matrix.getTrueCount(0));
    }

    /**
     * {@code or} must produce the same true count and sparse-index count
     * whether the argument is another matrix or a Trove list of on bits.
     */
    @Test
    public void testOr() {
        AbstractSparseBinaryMatrix operand = getTestMatrix();
        operand.set(new int[] { 1 }, new int[] { 1 }, true);

        AbstractSparseBinaryMatrix target = getTestMatrix();
        assertNothingSet(target);
        target.or(operand);
        assertOrApplied(target);

        // Now for trove collection
        target = getTestMatrix();
        assertNothingSet(target);
        TIntArrayList onBitList = new TIntArrayList();
        onBitList.add(1);
        target.or(onBitList);
        assertOrApplied(target);
    }

    /** Asserts the state the tests expect of a freshly created matrix. */
    private static void assertNothingSet(AbstractSparseBinaryMatrix matrix) {
        assertEquals(0, matrix.getTrueCount(1));
        assertEquals(0, matrix.getSparseIndices().length);
    }

    /** Asserts the state {@link #testOr()} expects after or-ing in bit 1. */
    private static void assertOrApplied(AbstractSparseBinaryMatrix matrix) {
        assertEquals(1, matrix.getTrueCount(1));
        assertEquals(7, matrix.getSparseIndices().length);
    }

    /**
     * {@code all} must hold between two untouched matrices and must fail once
     * the argument (matrix or Trove list) carries bits the receiver lacks.
     */
    @Test
    public void testAll() {
        AbstractSparseBinaryMatrix receiver = getTestMatrix();
        AbstractSparseBinaryMatrix argument = getTestMatrix();
        assertTrue(receiver.all(argument));

        argument.set(new int[] { 0, 1 }, new int[] { 1, 1 }, true);
        assertFalse(receiver.all(argument));

        // Now with trove
        receiver = getTestMatrix();
        argument = getTestMatrix();
        assertTrue(receiver.all(argument));

        argument.set(new int[] { 0, 1 }, new int[] { 1, 1 }, true);
        TIntArrayList onBitList = new TIntArrayList();
        onBitList.add(1);
        assertFalse(receiver.all(onBitList));
    }

    /**
     * {@code any} must be false for an untouched receiver regardless of the
     * argument, and true for a populated receiver given a Trove list or an
     * {@code int[]} of on bits.
     */
    @Test
    public void testAny() {
        AbstractSparseBinaryMatrix untouched = getTestMatrix();
        AbstractSparseBinaryMatrix populated = getTestMatrix();
        assertFalse(untouched.any(populated));

        populated.set(new int[] { 0, 1 }, new int[] { 1, 1 }, true);
        assertFalse(untouched.any(populated));

        // Now with trove
        untouched = getTestMatrix();
        populated = getTestMatrix();
        assertFalse(untouched.any(populated));

        populated.set(new int[] { 0, 1 }, new int[] { 1, 1 }, true);
        TIntArrayList onBitList = new TIntArrayList();
        onBitList.add(1);
        assertFalse(untouched.any(onBitList));
        assertTrue(populated.any(onBitList));

        int[] onBits = { 0 };
        assertFalse(untouched.any(onBits));
        assertTrue(populated.any(onBits));
    }

    /**
     * A matrix equals itself, never equals an unrelated object, and differs
     * from a matrix whose contents were changed through {@code set}.
     */
    @Test
    public void testEquals() {
        AbstractSparseBinaryMatrix reference = getTestMatrix();
        AbstractSparseBinaryMatrix modified = getTestMatrix();

        assertTrue(reference.equals(reference));
        assertFalse(reference.equals(new Object()));

        modified.set(new int[] { 0, 1 }, new int[] { 1, 1 }, false);
        assertFalse(reference.equals(modified));
    }
}