/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.math.linear;
import org.junit.Test;
import rapaio.core.distributions.Normal;
import rapaio.core.distributions.Uniform;
import rapaio.math.linear.dense.MatrixMultiplication;
import rapaio.math.linear.dense.SolidRM;
import rapaio.util.Util;
import java.util.stream.IntStream;
public class MatrixMultiplicationTest {
@Test
public void basicTestMM() {
RM A = SolidRM.copy(3, 4,
2.3, 1.2, 1, 7,
19, 0, -1, 2,
2, 3, 4, 5
);
RM B = SolidRM.copy(4, 5,
1, 2, 3, 4, 5,
1.1, 12, 23, 4, 15,
1.2, 2.2, 23, 4, 5,
1.3, 2.3, 3, 14, 25
);
// A.printSummary();
// B.printSummary();
A.dot(B).printSummary();
// MatrixMultiplication.strassen(A, B).printSummary();
}
@Test
public void largeMatrices() {
int N = 1_000;
double p = 0.7;
RM A = SolidRM.empty(N, N);
RM B = SolidRM.empty(N, N);
Normal norm = new Normal(1, 12);
Uniform unif = new Uniform(0, 1);
for (int i = 0; i < A.rowCount(); i++) {
for (int j = 0; j < A.colCount(); j++) {
if (unif.sampleNext() > p)
A.set(i, j, norm.sampleNext());
}
}
for (int i = 0; i < B.rowCount(); i++) {
for (int j = 0; j < B.colCount(); j++) {
if (unif.sampleNext() > p)
B.set(i, j, norm.sampleNext());
}
}
int[] range = IntStream.range(N - 100, N).toArray();
// Util.measure(() -> MatrixMultiplication.ijkAlgorithm(A,B).mapRows(range).mapCols(range).printSummary());
// Util.measure(() -> MatrixMultiplication.ikjAlgorithm(A, B).mapRows(range).mapCols(range).printSummary());
// Util.measure(() -> MatrixMultiplication.tiledAlgorithm(A, B).mapRows(range).mapCols(range).printSummary());
// RM C1 = Util.measure(() -> MatrixMultiplication.ijkParallel(A, B).mapRows(range).mapCols(range));
RM C2 = Util.measure(() -> MatrixMultiplication.ikjParallel(A, B).mapRows(range).mapCols(range));
// Assert.assertTrue(C1.isEqual(C2));
}
}