/*
 * JFlow
 * Created by Tim De Pauw <http://pwnt.be/>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package be.pwnt.jflow;

import java.util.List;
import java.util.Vector;

/**
 * A dense, mutable matrix of {@code double} values whose dimensions are fixed
 * at construction time.
 *
 * <p>Arithmetic operations ({@link #add}, {@link #subtract}, {@link #multiply})
 * never modify their operands; each returns a newly allocated matrix.
 */
public class Matrix {
    /** Cell storage: one inner list per row, each holding one entry per column. */
    private final List<List<Double>> values;

    /**
     * Declared column count. Kept separately from the row lists so that a
     * matrix with zero rows still knows how wide it is.
     */
    private final int columns;

    /**
     * Creates a {@code rows} x {@code columns} matrix.
     *
     * @param rows    number of rows; must not be negative
     * @param columns number of columns; must not be negative
     * @param values  the cell values in row-major order (all of row 0, then all
     *                of row 1, ...), or nothing at all to get a matrix filled
     *                with zeros
     * @throws IllegalArgumentException if a dimension is negative, or if values
     *                                  are supplied but their count is not
     *                                  {@code rows * columns}
     */
    public Matrix(int rows, int columns, double... values) {
        if (rows < 0 || columns < 0) {
            throw new IllegalArgumentException(
                    "Matrix dimensions must not be negative: " + rows + "x" + columns);
        }
        // Widen to long so a huge rows * columns cannot overflow and
        // accidentally match values.length.
        if (values.length > 0 && values.length != (long) rows * columns) {
            throw new IllegalArgumentException(
                    "Expected " + ((long) rows * columns) + " values for a " + rows + "x"
                            + columns + " matrix but got " + values.length);
        }
        this.columns = columns;
        this.values = new Vector<List<Double>>(rows);
        for (int i = 0; i < rows; i++) {
            List<Double> row = new Vector<Double>(columns);
            for (int j = 0; j < columns; j++) {
                // An empty varargs array means "zero-filled".
                row.add(values.length > 0 ? values[i * columns + j] : 0.0);
            }
            this.values.add(row);
        }
    }

    /**
     * Creates an independent copy of {@code matrix}; later changes to either
     * matrix do not affect the other.
     *
     * @param matrix the matrix to copy
     */
    public Matrix(Matrix matrix) {
        this.columns = matrix.columns;
        this.values = new Vector<List<Double>>(matrix.values.size());
        for (List<Double> row : matrix.values) {
            // Double is immutable, so copying the row list is a full copy.
            this.values.add(new Vector<Double>(row));
        }
    }

    /**
     * Returns the number of rows.
     *
     * @return the row count
     */
    public int getRowCount() {
        return values.size();
    }

    /**
     * Returns the number of columns. Unlike looking at the first row, this
     * also works for a matrix that has no rows.
     *
     * @return the column count
     */
    public int getColumnCount() {
        return columns;
    }

    /**
     * Returns the value stored at the given zero-based position.
     *
     * @param row    zero-based row index
     * @param column zero-based column index
     * @return the value at ({@code row}, {@code column})
     * @throws IndexOutOfBoundsException if either index is out of range
     */
    public double getValue(int row, int column) {
        return values.get(row).get(column);
    }

    /**
     * Stores {@code value} at the given zero-based position.
     *
     * @param row    zero-based row index
     * @param column zero-based column index
     * @param value  the value to store
     * @throws IndexOutOfBoundsException if either index is out of range
     */
    public void setValue(int row, int column, double value) {
        values.get(row).set(column, value);
    }

    /**
     * Returns the element-wise sum of this matrix and {@code other}.
     *
     * @param other the matrix to add; must have the same dimensions as this one
     * @return a new matrix holding {@code this + other}
     * @throws IllegalArgumentException if the dimensions differ
     */
    public Matrix add(Matrix other) {
        requireSameDimensions(other);
        Matrix p = new Matrix(getRowCount(), getColumnCount());
        for (int r = 0; r < p.getRowCount(); r++) {
            for (int c = 0; c < p.getColumnCount(); c++) {
                p.setValue(r, c, getValue(r, c) + other.getValue(r, c));
            }
        }
        return p;
    }

    /**
     * Returns the element-wise difference of this matrix and {@code other}.
     *
     * @param other the matrix to subtract; must have the same dimensions as
     *              this one
     * @return a new matrix holding {@code this - other}
     * @throws IllegalArgumentException if the dimensions differ
     */
    public Matrix subtract(Matrix other) {
        requireSameDimensions(other);
        Matrix p = new Matrix(getRowCount(), getColumnCount());
        for (int r = 0; r < p.getRowCount(); r++) {
            for (int c = 0; c < p.getColumnCount(); c++) {
                p.setValue(r, c, getValue(r, c) - other.getValue(r, c));
            }
        }
        return p;
    }

    /**
     * Returns the matrix product {@code this * other}.
     *
     * @param other the right-hand operand; its row count must equal this
     *              matrix's column count
     * @return a new matrix with this matrix's row count and {@code other}'s
     *         column count
     * @throws IllegalArgumentException if the inner dimensions do not match
     */
    public Matrix multiply(Matrix other) {
        if (getColumnCount() != other.getRowCount()) {
            throw new IllegalArgumentException(
                    "Cannot multiply a " + describe(this) + " matrix by a " + describe(other)
                            + " matrix");
        }
        int inner = getColumnCount();
        Matrix p = new Matrix(getRowCount(), other.getColumnCount());
        for (int r = 0; r < p.getRowCount(); r++) {
            for (int c = 0; c < p.getColumnCount(); c++) {
                double sum = 0;
                for (int i = 0; i < inner; i++) {
                    sum += getValue(r, i) * other.getValue(i, c);
                }
                p.setValue(r, c, sum);
            }
        }
        return p;
    }

    /** Throws IllegalArgumentException unless {@code other} has exactly this matrix's shape. */
    private void requireSameDimensions(Matrix other) {
        if (getRowCount() != other.getRowCount()
                || getColumnCount() != other.getColumnCount()) {
            throw new IllegalArgumentException(
                    "Dimension mismatch: " + describe(this) + " vs " + describe(other));
        }
    }

    /** Formats a matrix's shape as "ROWSxCOLUMNS" for use in error messages. */
    private static String describe(Matrix m) {
        return m.getRowCount() + "x" + m.getColumnCount();
    }
}