/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.math.linear.dense; import rapaio.math.linear.RM; import java.util.stream.IntStream; /** * This class offers different algorithms for matrix multiplication. * * @author Martin Thoma */ public class MatrixMultiplication { static int LEAF_SIZE = 256; public static RM ijkAlgorithm(RM A, RM B) { // initialise C RM C = SolidRM.empty(A.rowCount(), B.colCount()); for (int i = 0; i < A.rowCount(); i++) { for (int j = 0; j < B.colCount(); j++) { for (int k = 0; k < A.colCount(); k++) { C.increment(i, j, A.get(i, k) * B.get(k, j)); } } } return C; } public static RM ijkParallel(RM A, RM B) { // initialise C RM C = SolidRM.empty(A.rowCount(), B.colCount()); IntStream.range(0, A.rowCount()).parallel().forEach(i -> { for (int j = 0; j < B.colCount(); j++) { for (int k = 0; k < A.colCount(); k++) { C.increment(i, j, A.get(i, k) * B.get(k, j)); } } }); return C; } public static RM ikjAlgorithm(RM A, RM B) { // initialise C RM C = SolidRM.empty(A.rowCount(), B.colCount()); for (int i = 0; i < A.rowCount(); i++) { for (int k = 0; k < A.colCount(); k++) { if (A.get(i, k) == 0) continue; for (int j = 0; j < B.colCount(); j++) { C.increment(i, j, A.get(i, k) * B.get(k, j)); } } } return C; } public static RM ikjParallel(RM A, RM B) { // initialise C RM C = SolidRM.empty(A.rowCount(), B.colCount()); IntStream.range(0, A.rowCount()).parallel().forEach(i -> { for (int k = 0; k < A.colCount(); k++) { if (A.get(i, k) == 0) continue; for (int j = 0; j < B.colCount(); j++) { C.increment(i, j, A.get(i, k) * B.get(k, j)); } } }); return C; } public static RM tiledAlgorithm(RM A, RM B) { RM C = SolidRM.empty(A.rowCount(), B.colCount()); // Pick a tile size T = Θ(√M) int T = 256; // For I from 1 to n in steps of T: for (int I = 0; I < A.rowCount(); I += T) { // For J from 1 to p in steps of T: for (int J = 0; J < B.colCount(); J += T) { // For K from 1 to m in steps of T: for (int K = 0; K < A.colCount(); K += T) { // Multiply AI:I+T, K:K+T and BK:K+T, J:J+T into CI:I+T, J:J+T, that is: // For i from I to min(I + T, n): for (int i = I; i < Math.min(I + T, A.rowCount()); i++) { // For j from J to min(J + T, p): for (int j = J; j < Math.min(J + T, B.colCount()); j++) { // Let sum = 0 double sum = 0; // For k from K to min(K + T, m): for (int k = K; k < Math.min(K + T, A.colCount()); k++) { // Set sum ← sum + Aik × Bkj sum += A.get(i, k) * B.get(k, j); } // Set Cij ← sum C.increment(i, j, sum); } } } } } return C; } private static RM add(RM A, RM B) { RM C = SolidRM.empty(A.rowCount(), A.colCount()); for (int i = 0; i < A.rowCount(); i++) { for (int j = 0; j < A.colCount(); j++) { C.set(i, j, A.get(i, j) + B.get(i, j)); } } return C; } private static RM subtract(RM A, RM B) { int n = A.rowCount(); RM C = SolidRM.empty(n, n); for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { C.set(i, j, A.get(i, j) - B.get(i, j)); } } return C; } private static int nextPowerOfTwo(int n) { int log2 = (int) Math.ceil(Math.log(n) / Math.log(2)); return (int) Math.pow(2, log2); } /* public static RM strassen(RM A, RM B) { // Make the matrices bigger so that you can apply the strassen // algorithm recursively without having to deal with odd // matrix sizes int n = Math.max(A.rowCount(), Math.max(A.colCount(), B.colCount())); int m = nextPowerOfTwo(n); RM APrep = SolidRM.empty(m, m); RM BPrep = SolidRM.empty(m, m); for (int i = 0; i < A.rowCount(); i++) { for (int j = 0; j < A.colCount(); j++) { APrep.set(i, j, A.get(i, j)); } } for (int i = 0; i < B.rowCount(); i++) { for (int j = 0; j < B.colCount(); j++) { BPrep.set(i, j, B.get(i, j)); } } RM CPrep = strassenR(APrep, BPrep); RM C = SolidRM.empty(A.rowCount(), B.colCount()); for (int i = 0; i < A.rowCount(); i++) { for (int j = 0; j < B.colCount(); j++) { C.set(i, j, CPrep.get(i, j)); } } return C; } private static RM strassenR(RM A, RM B) { int n = A.rowCount(); if (n <= LEAF_SIZE) { return ikjAlgorithm(A, B); } else { // initializing the new sub-matrices int newSize = n / 2; // RM a11 = Linear.newRMEmpty(newSize, newSize); // RM a12 = Linear.newRMEmpty(newSize, newSize); // RM a21 = Linear.newRMEmpty(newSize, newSize); // RM a22 = Linear.newRMEmpty(newSize, newSize); // // RM b11 = Linear.newRMEmpty(newSize, newSize); // RM b12 = Linear.newRMEmpty(newSize, newSize); // RM b21 = Linear.newRMEmpty(newSize, newSize); // RM b22 = Linear.newRMEmpty(newSize, newSize); // // RM aResult; // RM bResult; // dividing the matrices in 4 sub-matrices: // for (int i = 0; i < newSize; i++) { // for (int j = 0; j < newSize; j++) { // a11.set(i, j, A.get(i, j)); // top left // a12.set(i, j, A.get(i, j + newSize)); // top right // a21.set(i, j, A.get(i + newSize, j)); // bottom left // a22.set(i, j, A.get(i + newSize, j + newSize)); // bottom right // // b11.set(i, j, B.get(i, j)); // top left // b12.set(i, j, B.get(i, j + newSize)); // top right // b21.set(i, j, B.get(i + newSize, j)); // bottom left // b22.set(i, j, B.get(i + newSize, j + newSize)); // bottom right // } // } RM a11 = A.rangeRows(0, newSize).rangeCols(0, newSize); RM a12 = A.rangeRows(0, newSize).rangeCols(newSize, 2 * newSize); RM a21 = A.rangeRows(newSize, 2 * newSize).rangeCols(0, newSize); RM a22 = A.rangeRows(newSize, 2 * newSize).rangeCols(newSize, 2 * newSize); RM b11 = B.rangeRows(0, newSize).rangeCols(0, newSize); RM b12 = B.rangeRows(0, newSize).rangeCols(newSize, 2 * newSize); RM b21 = B.rangeRows(newSize, 2 * newSize).rangeCols(0, newSize); RM b22 = B.rangeRows(newSize, 2 * newSize).rangeCols(newSize, 2 * newSize); RM aResult; RM bResult; // Calculating p1 to p7: aResult = add(a11, a22); bResult = add(b11, b22); RM p1 = strassenR(aResult, bResult); // p1 = (a11+a22) * (b11+b22) aResult = add(a21, a22); // a21 + a22 RM p2 = strassenR(aResult, b11); // p2 = (a21+a22) * (b11) bResult = subtract(b12, b22); // b12 - b22 RM p3 = strassenR(a11, bResult); // p3 = (a11) * (b12 - b22) bResult = subtract(b21, b11); // b21 - b11 RM p4 = strassenR(a22, bResult); // p4 = (a22) * (b21 - b11) aResult = add(a11, a12); // a11 + a12 RM p5 = strassenR(aResult, b22); // p5 = (a11+a12) * (b22) aResult = subtract(a21, a11); // a21 - a11 bResult = add(b11, b12); // b11 + b12 RM p6 = strassenR(aResult, bResult); // p6 = (a21-a11) * (b11+b12) aResult = subtract(a12, a22); // a12 - a22 bResult = add(b21, b22); // b21 + b22 RM p7 = strassenR(aResult, bResult); // p7 = (a12-a22) * (b21+b22) // calculating c21, c21, c11 e c22: RM c12 = add(p3, p5); // c12 = p3 + p5 RM c21 = add(p2, p4); // c21 = p2 + p4 aResult = add(p1, p4); // p1 + p4 bResult = add(aResult, p7); // p1 + p4 + p7 RM c11 = subtract(bResult, p5); // c11 = p1 + p4 - p5 + p7 aResult = add(p1, p3); // p1 + p3 bResult = add(aResult, p6); // p1 + p3 + p6 RM c22 = subtract(bResult, p2); // c22 = p1 + p3 - p2 + p6 // Grouping the results obtained in a single matrix: RM C = RM.empty(n, n); for (int i = 0; i < newSize; i++) { for (int j = 0; j < newSize; j++) { C.set(i, j, c11.get(i, j)); C.set(i, j + newSize, c12.get(i, j)); C.set(i + newSize, j, c21.get(i, j)); C.set(i + newSize, j + newSize, c22.get(i, j)); } } return C; } } */ }