package mikera.matrixx; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.util.Random; import mikera.arrayz.Arrayz; import mikera.arrayz.INDArray; import mikera.arrayz.NDArray; import mikera.arrayz.TestArrays; import mikera.indexz.Index; import mikera.indexz.Indexz; import mikera.matrixx.impl.BandedMatrix; import mikera.matrixx.impl.BlockDiagonalMatrix; import mikera.matrixx.impl.ColumnMatrix; import mikera.matrixx.impl.IdentityMatrix; import mikera.matrixx.impl.ImmutableMatrix; import mikera.matrixx.impl.LowerTriangularMatrix; import mikera.matrixx.impl.PermutationMatrix; import mikera.matrixx.impl.PermutedMatrix; import mikera.matrixx.impl.QuadtreeMatrix; import mikera.matrixx.impl.RowMatrix; import mikera.matrixx.impl.ScalarMatrix; import mikera.matrixx.impl.SparseColumnMatrix; import mikera.matrixx.impl.SparseRowMatrix; import mikera.matrixx.impl.StridedMatrix; import mikera.matrixx.impl.SubsetMatrix; import mikera.matrixx.impl.UpperTriangularMatrix; import mikera.matrixx.impl.VectorMatrixM3; import mikera.matrixx.impl.VectorMatrixMN; import mikera.matrixx.impl.ZeroMatrix; import mikera.transformz.MatrixTransform; import mikera.transformz.TestTransformz; import mikera.vectorz.AVector; import mikera.vectorz.Vector; import mikera.vectorz.Vector3; import mikera.vectorz.Vectorz; import mikera.vectorz.impl.AxisVector; import mikera.vectorz.ops.Constant; import org.junit.Test; public class TestMatrices { private void doMutationTest(AMatrix m) { if (!m.isFullyMutable()) return; m=m.exactClone(); AMatrix m2=m.exactClone(); assertEquals(m,m2); int rc=m.rowCount(); int cc=m.columnCount(); for (int i=0; i<rc; i++) { for (int j=0; j<cc; j++) { m2.set(i,j,m2.get(i,j)+1.3); assertEquals(m2.get(i,j),m2.getRow(i).get(j),0.0); assertNotSame(m.get(i,j),m2.get(i, j)); } } } private void doTransposeTest(AMatrix m) { AMatrix m2=m.clone(); m2=m2.getTranspose(); assertTrue(m.equalsTranspose(m2)); assertTrue(m2.equalsTranspose(m)); assertEquals(m2, m.getTranspose()); assertEquals(m2, m.getTransposeView()); assertEquals(m2, m.toMatrixTranspose()); m2=m2.getTranspose(); assertEquals(m,m2); assertEquals(m.getTranspose().innerProduct(m),m.transposeInnerProduct(m)); } private void doSquareTransposeTest(AMatrix m) { AMatrix m2=m.clone(); m2.transposeInPlace(); // two different kinds of transpose should produce same result AMatrix tm=m.getTranspose(); assertEquals(tm,m2); assertEquals(m.trace(),tm.trace(),0.0); m2.transposeInPlace(); assertEquals(m,m2); } private void doNotSquareTests(AMatrix m) { try { m.getLeadingDiagonal(); fail(); } catch (Throwable t) { // OK } try { m.checkSquare(); fail(); } catch (Throwable t) { // OK } try { m.inverse(); fail(); } catch (Throwable t) { // OK } try { m.determinant(); fail(); } catch (Throwable t) { // OK } } private void doLeadingDiagonalTests(AMatrix m) { int dims=m.rowCount(); assertEquals(dims,m.columnCount()); AVector v=m.getLeadingDiagonal(); for (int i=0; i<dims; i++) { assertEquals(v.get(i),m.get(i, i),0.0); } } private void doTraceTests(AMatrix m) { assertEquals(m.clone().trace(), m.trace(),0.00001); } private void doMaybeSquareTests(AMatrix m) { if (!m.isSquare()) { assertNotEquals(m.rowCount(),m.columnCount()); doNotSquareTests(m); } else { assertEquals(m.rowCount(),m.columnCount()); assertEquals(m.rowCount(),m.checkSquare()); doSquareTransposeTest(m); doTraceTests(m); doLeadingDiagonalTests(m); } } private void doSwapTest(AMatrix m) { if ((m.rowCount()<2)||(m.columnCount()<2)) return; m=m.clone(); AMatrix m2=m.clone(); m2.swapRows(0, 1); assert(!m2.equals(m)); m2.swapRows(0, 1); assert(m2.equals(m)); m2.swapColumns(0, 1); assert(!m2.equals(m)); m2.swapColumns(0, 1); assert(m2.equals(m)); } void doRandomTests(AMatrix m) { m=m.clone(); Matrixx.fillRandomValues(m); doSwapTest(m); doMutationTest(m); } private void doAddTest(AMatrix m) { if (!m.isFullyMutable()) return; AMatrix m2=m.exactClone(); AMatrix m3=m.exactClone(); m2.add(m); m2.add(m); m3.addMultiple(m, 2.0); assertTrue(m2.epsilonEquals(m3)); } private void doCloneSafeTest(AMatrix m) { if ((m.rowCount()==0)||(m.columnCount()==0)) return; AMatrix m2=m.clone(); m2.set(0,0,Math.PI); assertNotSame(m.get(0,0),m2.get(0,0)); } private void doSubMatrixTest(AMatrix m) { int rc=m.rowCount(); int cc=m.columnCount(); if ((rc<=1)||(cc<=1)) return; AMatrix sm=m.subMatrix(1, rc-1, 1, cc-1); assertEquals(rc-1,sm.rowCount()); assertEquals(cc-1,sm.columnCount()); for (int i=1; i<rc; i++) { for (int j=1; j<cc; j++) { assertEquals(m.get(i,j),sm.get(i-1,j-1),0.0); } } } private void doBoundsTest(AMatrix m) { int rc=m.rowCount(); int cc=m.columnCount(); try { m.get(-1,-1); fail(); } catch (IndexOutOfBoundsException a) {/* OK */} try { m.get(rc,cc); fail(); } catch (IndexOutOfBoundsException a) {/* OK */} try { m.get(0,-1); fail(); } catch (IndexOutOfBoundsException a) {/* OK */} try { m.get(0,cc); fail(); } catch (IndexOutOfBoundsException a) {/* OK */} try { m.get(-1,0); fail(); } catch (IndexOutOfBoundsException a) {/* OK */} try { m.get(rc,0); fail(); } catch (IndexOutOfBoundsException a) {/* OK */} if (m.isFullyMutable()) { m=m.exactClone(); try { m.set(-1,-1,1); fail(); } catch (IndexOutOfBoundsException a) {/* OK */} try { m.set(rc,cc,1); fail(); } catch (IndexOutOfBoundsException a) {/* OK */} } } private void doRowColumnTests(AMatrix m) { assertEquals(m.rowCount(),new MatrixTransform(m).outputDimensions()); assertEquals(m.columnCount(),new MatrixTransform(m).inputDimensions()); m=m.clone(); int rc=m.rowCount(); int cc=m.columnCount(); if ((rc==0)||(cc==0)) return; for (int i=0; i<rc; i++) { AVector row=m.getRow(i); assertEquals(row,m.cloneRow(i)); assertEquals(cc,row.length()); } for (int i=0; i<cc; i++) { AVector col=m.getColumn(i); assertEquals(rc,col.length()); } AVector row=m.getRowView(0); AVector col=m.getColumnView(0); row.set(0,1.77); assertEquals(1.77,m.get(0,0),0.0); col.set(0,0.23); assertEquals(0.23,m.get(0,0),0.0); AVector all=m.asVector(); assertEquals(m.rowCount()*m.columnCount(),all.length()); all.set(0,0.78); assertEquals(0.78,row.get(0),0.0); assertEquals(0.78,col.get(0),0.0); new TestArrays().testArray(row); new TestArrays().testArray(col); new TestArrays().testArray(all); } private void doBandTests(AMatrix m) { int rc=m.rowCount(); int cc=m.columnCount(); int bandMin=-m.rowCount(); int bandMax=m.columnCount(); Matrix mc=m.toMatrix(); // TODO: consider what to do about out-of-range bands? //assertNull(m.getBand(bandMin-1)); //assertNull(m.getBand(bandMax+1)); for (int i=bandMin; i<=bandMax; i++) { AVector b=m.getBand(i); assertEquals(b.length(),m.bandLength(i)); assertEquals(mc.getBand(i),b); // wrapped band test if (rc*cc!=0) assertEquals(Math.max(rc, cc),m.getBandWrapped(i).length()); } assertEquals(m,BandedMatrix.create(m)); } private void doVectorTest(AMatrix m) { m=m.clone(); AVector v=m.asVector(); assertEquals(v,m.toVector()); assertEquals(m.elementSum(),v.elementSum(),0.000001); AMatrix m2=Matrixx.createFromVector(v, m.rowCount(), m.columnCount()); assertEquals(m,m2); assertEquals(v,m2.asVector()); if (v.length()>0) { v.set(0,10.0); assertEquals(10.0,m.get(0,0),0.0); } } void doParseTest(AMatrix m) { if (m.rowCount()==0) return; assertEquals(m,Matrixx.parse(m.toString())); } void doBigComposeTest(AMatrix m) { AMatrix a=Matrixx.createRandomSquareMatrix(m.rowCount()); AMatrix b=Matrixx.createRandomSquareMatrix(m.columnCount()); AMatrix mb=m.compose(b); AMatrix amb=a.compose(mb); AVector v=Vectorz.createUniformRandomVector(b.columnCount()); AVector ambv=a.transform(m.transform(b.transform(v))); assertTrue(amb.transform(v).epsilonEquals(ambv)); } private void testApplyOp(AMatrix m) { if (!m.isFullyMutable()) return; AMatrix c=m.exactClone(); AMatrix d=m.exactClone(); c.asVector().fill(5.0); d.applyOp(Constant.create(5.0)); assertEquals(c,d); assertTrue(d.epsilonEquals(c)); } private void testExactClone(AMatrix m) { AMatrix c=m.exactClone(); AMatrix d=m.clone(); Matrix mc=m.toMatrix(); assertEquals(m,mc); assertEquals(m,c); assertEquals(m,d); } private void testSparseClone(AMatrix m) { AMatrix s=Matrixx.createSparse(m); assertEquals(m,s); AMatrix s2=Matrixx.createSparseRows(m); assertEquals(m,s2); } void doScaleTest(AMatrix m) { if(!m.isFullyMutable()) return; AMatrix m1=m.exactClone(); AMatrix m2=m.clone(); m1.scale(2.0); m2.add(m); assertTrue(m1.epsilonEquals(m2)); m1.scale(0.0); assertTrue(m1.isZero()); } private void doMulTest(AMatrix m) { AVector v=Vectorz.newVector(m.columnCount()); AVector t=Vectorz.newVector(m.rowCount()); m.transform(v, t); AVector t2=m.transform(v); assertEquals(t,t2); assertEquals(t,m.innerProduct(v)); } private void doNDArrayTest(AMatrix m) { NDArray a=NDArray.newArray(m.getShape()); a.set(m); int rc=m.rowCount(); int cc=m.columnCount(); for (int i=0; i<rc; i++) { assertEquals(m.getRow(i),a.slice(i)); } for (int i=0; i<cc; i++) { assertEquals(m.getColumn(i),a.slice(1,i)); } } private void doTriangularTests(AMatrix m) { boolean sym=m.isSymmetric(); boolean diag=m.isDiagonal(); boolean uppt=m.isUpperTriangular(); boolean lowt=m.isLowerTriangular(); if (diag) { assertTrue(sym); assertTrue(uppt); assertTrue(lowt); } if (sym) { assertTrue(m.isSquare()); assertEquals(m,m.getTranspose()); } if (sym&uppt&lowt) { assertTrue(diag); } if (uppt) { assertTrue(m.getTranspose().isLowerTriangular()); } if (lowt) { assertTrue(m.getTranspose().isUpperTriangular()); } } void doGenericTests(AMatrix m) { m.validate(); testApplyOp(m); testExactClone(m); testSparseClone(m); doTransposeTest(m); doTriangularTests(m); doVectorTest(m); doParseTest(m); doNDArrayTest(m); doScaleTest(m); doBoundsTest(m); doMulTest(m); doAddTest(m); doRowColumnTests(m); doBandTests(m); doCloneSafeTest(m); doMutationTest(m); doMaybeSquareTests(m); doRandomTests(m); doBigComposeTest(m); doSubMatrixTest(m); TestTransformz.doITransformTests(new MatrixTransform(m)); new TestArrays().testArray(m); } @Test public void g_ZeroMatrix() { // zero matrices doGenericTests(Matrixx.createImmutableZeroMatrix(3, 2)); doGenericTests(Matrixx.createImmutableZeroMatrix(5, 5)); doGenericTests(Matrixx.createImmutableZeroMatrix(3, 3).reorder(new int[] {2,0,1})); doGenericTests(Matrixx.createImmutableZeroMatrix(1, 7)); doGenericTests(Matrixx.createImmutableZeroMatrix(1, 0)); doGenericTests(Matrixx.createImmutableZeroMatrix(0, 1)); doGenericTests(Matrixx.createImmutableZeroMatrix(0, 0)); } @Test public void g_PrimitiveMatrix() { // specialised 3x3 matrix Matrix33 m33=new Matrix33(); randomise(m33); doGenericTests(m33); // specialised 2*2 matrix Matrix22 m22=new Matrix22(); randomise(m22); doGenericTests(m22); // specialised 1*1 matrix Matrix11 m11=new Matrix11(); randomise(m11); doGenericTests(m11); } @Test public void g_PermutedMatrix() { // general M*N matrix VectorMatrixMN mmn=new VectorMatrixMN(6 ,7); randomise(mmn); // permuted matrix PermutedMatrix pmm=new PermutedMatrix(mmn, Indexz.createRandomPermutation(mmn.rowCount()), Indexz.createRandomPermutation(mmn.columnCount())); doGenericTests(pmm); doGenericTests(pmm.subMatrix(1, 4, 1, 5)); } @Test public void g_VectorMatrix() { // specialised Mx3 matrix VectorMatrixM3 mm3=new VectorMatrixM3(10); randomise(mm3); doGenericTests(mm3); doGenericTests(mm3.subMatrix(1, 1, 1, 1)); // general M*N matrix VectorMatrixMN mmn=new VectorMatrixMN(6 ,7); randomise(mmn); doGenericTests(mmn); doGenericTests(mmn.subMatrix(1, 4, 1, 5)); // small 2*2 matrix mmn=new VectorMatrixMN(2,2); doGenericTests(mmn); // 1x0 matrix should work mmn=new VectorMatrixMN(1 ,0); doGenericTests(mmn); // square M*M matrix mmn=new VectorMatrixMN(6 ,6); doGenericTests(mmn); } private static long seed; private void randomise(INDArray m) { Arrayz.fillNormal(m, seed++); } @Test public void g_Matrix() { // general M*N matrix VectorMatrixMN mmn=new VectorMatrixMN(6 ,7); randomise(mmn); Matrix am1=new Matrix(Matrix33.createScaleMatrix(4.0)); doGenericTests(am1); Matrix am2=new Matrix(mmn); doGenericTests(am2); } @Test public void g_DenseColumnMatrix() { Matrix am1=new Matrix(Matrix33.createScaleMatrix(Math.PI)); doGenericTests(am1.getTranspose()); Matrix am2=new Matrix(new VectorMatrixMN(6 ,7)); randomise(am2); doGenericTests(am2.getTranspose()); } @Test public void g_SubsetMatrix() { doGenericTests(SubsetMatrix.create(Index.of(0,1,2),3)); doGenericTests(SubsetMatrix.create(Index.of(0,1,3,10),12)); doGenericTests(SubsetMatrix.create(Index.of(0,3,2,1),4)); } @Test public void g_ScalarMatrix() { doGenericTests(ScalarMatrix.create(1,3.0)); doGenericTests(ScalarMatrix.create(3,Math.E)); doGenericTests(ScalarMatrix.create(5,0.0)); doGenericTests(ScalarMatrix.create(5,2.0).subMatrix(1, 3, 1, 3)); } @Test public void g_RowMatrix() { doGenericTests(new RowMatrix(Vector.of(1,2,3,4,7))); doGenericTests(new RowMatrix(Vector3.of(1,2,3))); } @Test public void g_ColumnMatrix() { doGenericTests(new ColumnMatrix(Vector.of(1,2,3,4,5))); doGenericTests(new ColumnMatrix(Vector3.of(1,2,3))); } @Test public void g_StridedMatrix() { StridedMatrix strm=StridedMatrix.create(1, 1); doGenericTests(strm); strm=StridedMatrix.create(Matrixx.createRandomMatrix(3, 4)); doGenericTests(strm); strm=StridedMatrix.wrap(Matrix.create(Matrixx.createRandomMatrix(3, 3))); doGenericTests(strm); } @Test public void g_PermutationMatrix() { doGenericTests(PermutationMatrix.create(0,1,2)); doGenericTests(PermutationMatrix.create(4,2,3,1,0)); doGenericTests(PermutationMatrix.create(Indexz.createRandomPermutation(8))); doGenericTests(PermutationMatrix.create(Indexz.createRandomPermutation(6)).subMatrix(1,3,2,4)); } @Test public void g_BandedMatrix() { doGenericTests(BandedMatrix.create(3, 3, -2, 2)); doGenericTests(BandedMatrix.create(Matrixx.createRandomMatrix(2, 2))); doGenericTests(BandedMatrix.wrap(3, 4, 0, 0,Vector.of(1,2,3))); } @Test public void g_QuadtreeMatrix() { doGenericTests(QuadtreeMatrix.create(new Matrix22(1,0,0,2) ,ZeroMatrix.create(2, 1) ,ZeroMatrix.create(1, 2) ,Matrixx.createScaleMatrix(1, 3))); } @Test public void g_SparseMatrix() { doGenericTests(SparseRowMatrix.create(Vector.of(0,1,-Math.E),null,null,AxisVector.create(2, 3))); doGenericTests(SparseRowMatrix.create(Matrixx.createRandomSquareMatrix(3))); doGenericTests(SparseColumnMatrix.create(Vector.of(0,1,-Math.PI),null,null,AxisVector.create(2, 3))); doGenericTests(SparseColumnMatrix.create(Matrixx.createRandomSquareMatrix(4))); } @Test public void g_TriangularMatrix() { doGenericTests(UpperTriangularMatrix.createFrom(Matrixx.createRandomSquareMatrix(1))); doGenericTests(UpperTriangularMatrix.createFrom(Matrixx.createRandomSquareMatrix(4))); doGenericTests(UpperTriangularMatrix.createFrom(Matrixx.createRandomMatrix(4,3))); doGenericTests(UpperTriangularMatrix.createFrom(Matrixx.createRandomMatrix(2,3))); doGenericTests(LowerTriangularMatrix.createFrom(Matrixx.createRandomSquareMatrix(1))); doGenericTests(LowerTriangularMatrix.createFrom(Matrixx.createRandomSquareMatrix(4))); doGenericTests(LowerTriangularMatrix.createFrom(Matrixx.createRandomMatrix(4,3))); doGenericTests(LowerTriangularMatrix.createFrom(Matrixx.createRandomMatrix(2,3))); } @Test public void g_ImmutableMatrix() { doGenericTests(new ImmutableMatrix(Matrixx.createRandomMatrix(4, 5))); doGenericTests(new ImmutableMatrix(Matrixx.createRandomMatrix(3, 3))); } @Test public void g_BlockDiagonalMatrix() { doGenericTests(BlockDiagonalMatrix.create(IdentityMatrix.create(2),Matrixx.createRandomSquareMatrix(2))); } }