/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.math.linear; import rapaio.math.MathTools; import rapaio.core.stat.Mean; import rapaio.core.stat.Variance; import rapaio.data.Numeric; import rapaio.math.linear.dense.*; import rapaio.printer.Printable; import rapaio.sys.WS; import java.io.Serializable; import java.util.Arrays; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.DoubleStream; /** * Real matrix * <p> * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 2/3/16. */ public interface RM extends Serializable, Printable { int rowCount(); int colCount(); double get(int i, int j); void set(int i, int j, double value); void increment(int i, int j, double value); RV mapCol(int i); RV mapRow(int i); default RM mapRows(int... indexes) { return new MappedRM(this, true, indexes); } default RM rangeRows(int start, int end) { int[] rows = new int[end - start]; for (int i = start; i < end; i++) { rows[i - start] = i; } return new MappedRM(this, true, rows); } /** * Builds a new matrix having all columns and all the rows not specified by given indexes * * @param indexes rows which will be removed * @return new mapped matrix containing all rows not specified by indexes */ default RM removeRows(int... indexes) { Set<Integer> rem = Arrays.stream(indexes).boxed().collect(Collectors.toSet()); int[] rows = new int[rowCount() - rem.size()]; int pos = 0; for (int i = 0; i < rowCount(); i++) { if (rem.contains(i)) continue; rows[pos++] = i; } return new MappedRM(this, true, rows); } default RM mapCols(int... indexes) { return new MappedRM(this, false, indexes); } default RM rangeCols(int start, int end) { int[] cols = new int[end - start]; for (int i = start; i < end; i++) { cols[i - start] = i; } return new MappedRM(this, false, cols); } default RM removeCols(int... indexes) { Set<Integer> rem = Arrays.stream(indexes).boxed().collect(Collectors.toSet()); int[] cols = new int[colCount() - rem.size()]; int pos = 0; for (int i = 0; i < colCount(); i++) { if (rem.contains(i)) continue; cols[pos++] = i; } return new MappedRM(this, false, cols); } /** * @return new transposed matrix */ RM t(); default RM dot(RM B) { return MatrixMultiplication.ikjAlgorithm(this, B); } default RM dot(double x) { for (int i = 0; i < rowCount(); i++) { for (int j = 0; j < colCount(); j++) { set(i, j, get(i, j) * x); } } return this; } default RM plus(double x) { for (int i = 0; i < rowCount(); i++) { for (int j = 0; j < colCount(); j++) { increment(i, j, x); } } return this; } default RM plus(RM B) { if ((rowCount() != B.rowCount()) || (colCount() != B.colCount())) throw new IllegalArgumentException(String.format( "Matrices are not conform for addition: [%d x %d] + [%d x %d]", rowCount(), colCount(), B.rowCount(), B.colCount())); for (int i = 0; i < rowCount(); i++) { for (int j = 0; j < colCount(); j++) { increment(i, j, B.get(i, j)); } } return this; } default RM minus(double x) { return plus(-x); } default RM minus(RM B) { if ((rowCount() != B.rowCount()) || (colCount() != B.colCount())) throw new IllegalArgumentException(String.format( "Matrices are not conform for substraction: [%d x %d] + [%d x %d]", rowCount(), colCount(), B.rowCount(), B.colCount())); for (int i = 0; i < rowCount(); i++) { for (int j = 0; j < colCount(); j++) { increment(i, j, -B.get(i, j)); } } return this; } /** * Matrix rank * * @return effective numerical rank, obtained from SVD. */ default int rank() { return new SVDecomposition(this).rank(); } default Mean mean() { Numeric values = Numeric.empty(); for (int i = 0; i < rowCount(); i++) { for (int j = 0; j < colCount(); j++) { values.addValue(get(i, j)); } } return Mean.from(values); } default Variance var() { Numeric values = Numeric.empty(); for (int i = 0; i < rowCount(); i++) { for (int j = 0; j < colCount(); j++) { values.addValue(get(i, j)); } } return Variance.from(values); } /** * Diagonal vector of values */ default RV diag() { RV rv = SolidRV.empty(rowCount()); for (int i = 0; i < rowCount(); i++) { rv.set(i, get(i, i)); } return rv; } default RM scatter() { RM scatter = SolidRM.empty(colCount(), colCount()); RV mean = SolidRV.empty(colCount()); for (int i = 0; i < colCount(); i++) { mean.set(i, mapCol(i).mean().value()); } for (int i = 0; i < rowCount(); i++) { RM row = mapRow(i).asMatrix(); row.minus(mean.asMatrix()); scatter.plus(row.dot(row.t())); } return scatter; } /////////////////////// // other tools /////////////////////// /** * Does not override equals since this is a costly * algorithm and can slow down processing as a side effect. * * @param RM given matrix * @return true if dimension and elements are equal */ default boolean isEqual(RM RM) { return isEqual(RM, 1e-12); } default boolean isEqual(RM RM, double tol) { if (rowCount() != RM.rowCount()) return false; if (colCount() != RM.colCount()) return false; for (int i = 0; i < rowCount(); i++) { for (int j = 0; j < colCount(); j++) { if (!MathTools.eq(get(i, j), RM.get(i, j), tol)) return false; } } return true; } DoubleStream valueStream(); RM solidCopy(); default String summary() { StringBuilder sb = new StringBuilder(); String[][] m = new String[rowCount()][colCount()]; int max = 1; for (int i = 0; i < rowCount(); i++) { for (int j = 0; j < colCount(); j++) { m[i][j] = WS.formatShort(get(i, j)); max = Math.max(max, m[i][j].length() + 1); } } max = Math.max(max, String.format("[,%d]", rowCount()).length()); max = Math.max(max, String.format("[%d,]", colCount()).length()); int hCount = (int) Math.floor(WS.getPrinter().textWidth() / (double) max); int vCount = Math.min(rowCount() + 1, 101); int hLast = 0; while (true) { // take vertical stripes if (hLast >= colCount()) break; int hStart = hLast; int hEnd = Math.min(hLast + hCount, colCount()); int vLast = 0; while (true) { // print rows if (vLast >= rowCount()) break; int vStart = vLast; int vEnd = Math.min(vLast + vCount, rowCount()); for (int i = vStart; i <= vEnd; i++) { for (int j = hStart; j <= hEnd; j++) { if (i == vStart && j == hStart) { sb.append(String.format("%" + (max) + "s| ", "")); continue; } if (i == vStart) { sb.append(String.format("%" + Math.max(1, max - 1) + "d|", j - 1)); continue; } if (j == hStart) { sb.append(String.format("%" + Math.max(1, max - 1) + "d |", i - 1)); continue; } sb.append(String.format("%" + max + "s", m[i - 1][j - 1])); } sb.append("\n"); } vLast = vEnd; } hLast = hEnd; } return sb.toString(); } }