/*
 * Copyright 2008 Network Engine for Objects in Lund AB [neotechnology.com]
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.graphalgo;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * Utility class that holds implementations of vectors and matrices of doubles,
 * all indexed by integers, together with some essential operations on those.
 * <p>
 * Both containers are sparse: only positions that have been written are
 * stored. The accessors return {@code null} for positions never written,
 * while the arithmetic helpers in this class read such positions as 0.0.
 * @author Patrik Larsson
 */
public class MatrixUtil {
    /**
     * Vector of doubles
     */
    public static class DoubleVector {
        Map<Integer,Double> values = new HashMap<Integer,Double>();

        /**
         * Increment a value in the vector. A position that has no value yet
         * is treated as 0.0 before the increment is applied.
         * @param index
         *            position to change.
         * @param increment
         *            amount to add to the value at that position.
         */
        public void incrementValue(Integer index, double increment) {
            Double currentValue = values.get(index);
            double base = currentValue == null ? 0.0 : currentValue;
            values.put(index, base + increment);
        }

        /**
         * Set a value for a certain index.
         * @param index
         *            position to write.
         * @param value
         *            value to store, replacing any previous one.
         */
        public void set(Integer index, double value) {
            values.put(index, value);
        }

        /**
         * Get a value for a certain index.
         * @param index
         *            position to read.
         * @return The value, or null if nothing is stored at that index.
         */
        public Double get(Integer index) {
            return values.get(index);
        }

        /**
         * Get all indices for which values are stored.
         * @return The indices as a set. This is the key set of the backing
         *         map, not a copy.
         */
        public Set<Integer> getIndices() {
            return values.keySet();
        }

        /**
         * Renders the positions 0 up to the largest stored index, each
         * preceded by a space, followed by a newline. Positions without a
         * stored value are printed as 0.0; negative indices are not printed.
         */
        @Override
        public String toString() {
            int maxIndex = 0;
            for (Integer index : values.keySet()) {
                if (index > maxIndex) {
                    maxIndex = index;
                }
            }
            StringBuilder res = new StringBuilder();
            for (int i = 0; i <= maxIndex; ++i) {
                Double value = values.get(i);
                res.append(' ').append(value == null ? 0.0 : value.doubleValue());
            }
            return res.append('\n').toString();
        }
    }

    /**
     * 2-Dimensional matrix of doubles.
     */
    public static class DoubleMatrix {
        Map<Integer,DoubleVector> rows = new HashMap<Integer,DoubleVector>();

        /**
         * Increment a value at a certain position. The row is created if it
         * does not exist yet, and a missing value is treated as 0.0.
         * @param rowIndex
         *            row of the position.
         * @param columnIndex
         *            column of the position.
         * @param increment
         *            amount to add to the value at that position.
         */
        public void incrementValue(Integer rowIndex, Integer columnIndex, double increment) {
            rowForWrite(rowIndex).incrementValue(columnIndex, increment);
        }

        /**
         * Set a value at a certain position. The row is created if it does
         * not exist yet.
         * @param rowIndex
         *            row of the position.
         * @param columnIndex
         *            column of the position.
         * @param value
         *            value to store.
         */
        public void set(Integer rowIndex, Integer columnIndex, double value) {
            rowForWrite(rowIndex).set(columnIndex, value);
        }

        /**
         * Get the value at a certain position.
         * @param rowIndex
         *            row of the position.
         * @param columnIndex
         *            column of the position.
         * @return The value, or null if the row or the value is missing.
         */
        public Double get(Integer rowIndex, Integer columnIndex) {
            DoubleVector row = rows.get(rowIndex);
            return row == null ? null : row.get(columnIndex);
        }

        /**
         * Gets an entire row as a vector.
         * @param rowIndex
         *            index of the row.
         * @return The row vector (the stored instance, not a copy) or null.
         */
        public DoubleVector getRow(Integer rowIndex) {
            return rows.get(rowIndex);
        }

        /**
         * Inserts or replaces an entire row as a vector.
         * @param rowIndex
         *            index of the row.
         * @param row
         *            vector to store; it is stored as-is, not copied.
         */
        public void setRow(Integer rowIndex, DoubleVector row) {
            rows.put(rowIndex, row);
        }

        /**
         * Concatenates the string form of every stored row, in the iteration
         * order of the backing map. Rows stored as null are skipped.
         */
        @Override
        public String toString() {
            StringBuilder res = new StringBuilder();
            for (DoubleVector row : rows.values()) {
                if (row != null) {
                    res.append(row.toString());
                }
            }
            return res.toString();
        }

        /**
         * @return The number of rows in the matrix.
         */
        public int size() {
            return rows.size();
        }

        /**
         * Returns the row stored at rowIndex, creating and storing an empty
         * row first if there is none.
         */
        private DoubleVector rowForWrite(Integer rowIndex) {
            DoubleVector row = rows.get(rowIndex);
            if (row == null) {
                row = new DoubleVector();
                rows.put(rowIndex, row);
            }
            return row;
        }
    }

    /**
     * Reads a vector position, treating a missing value as 0.0, which is the
     * same convention DoubleVector.incrementValue and toString use.
     */
    private static double valueOrZero(DoubleVector vector, Integer index) {
        Double value = vector.get(index);
        return value == null ? 0.0 : value;
    }

    /**
     * Destructive (in-place) LU-decomposition. Afterwards the positions below
     * the diagonal hold the elimination factors (L without its unit diagonal)
     * and the diagonal and the positions above it hold U.
     * <p>
     * No pivoting is done. The rows are addressed as 0 .. size-1 and every
     * diagonal element used as a pivot must be present; a missing one fails
     * with a NullPointerException and a zero one yields infinite or NaN
     * factors.
     * @param matrix
     *            input
     */
    // TODO: extend to LUP?
    public static void LUDecomposition(DoubleMatrix matrix) {
        int matrixSize = matrix.size();
        for (int i = 0; i < matrixSize - 1; ++i) {
            double pivot = matrix.get(i, i);
            DoubleVector pivotRow = matrix.getRow(i);
            for (int r = i + 1; r < matrixSize; ++r) {
                Double rowStartValue = matrix.get(r, i);
                if (rowStartValue == null || rowStartValue == 0.0) {
                    // Nothing to eliminate in this row.
                    continue;
                }
                double factor = rowStartValue / pivot;
                // Store the factor where the eliminated element was.
                matrix.set(r, i, factor);
                for (Integer c : pivotRow.getIndices()) {
                    if (c > i) {
                        matrix.incrementValue(r, c, -pivotRow.get(c) * factor);
                    }
                }
            }
        }
    }

    /**
     * Solves the linear equation system ax = b.
     * <p>
     * The same preconditions as for {@link #LUDecomposition(DoubleMatrix)}
     * apply to a. Positions of b without a stored value are read as 0.0.
     * @param a
     *            Input matrix. Will be altered in-place.
     * @param b
     *            Input vector. Will be altered in-place.
     * @return the vector x solving the equations. This is the instance passed
     *         in as b.
     */
    public static DoubleVector LinearSolve(DoubleMatrix a, DoubleVector b) {
        LUDecomposition(a);
        // first solve Ly = b ...
        for (int r = 0; r < a.size(); ++r) {
            DoubleVector row = a.getRow(r);
            for (Integer c : row.getIndices()) {
                if (c < r) {
                    b.incrementValue(r, -row.get(c) * valueOrZero(b, c));
                }
            }
        }
        // ... then Ux = y
        for (int r = a.size() - 1; r >= 0; --r) {
            DoubleVector row = a.getRow(r);
            for (Integer c : row.getIndices()) {
                if (c > r) {
                    b.incrementValue(r, -row.get(c) * valueOrZero(b, c));
                }
            }
            b.set(r, valueOrZero(b, r) / row.get(r));
        }
        return b;
    }

    /**
     * Multiplies a matrix and a vector. Positions of the vector without a
     * stored value are read as 0.0, and only rows actually stored in the
     * matrix contribute, whatever their indices are.
     * @param matrix
     *            left operand; not modified.
     * @param vector
     *            right operand; not modified.
     * @return The result as a new vector, with one entry for each stored row
     *         that has at least one value.
     */
    public static DoubleVector multiply(DoubleMatrix matrix, DoubleVector vector) {
        DoubleVector result = new DoubleVector();
        for (Map.Entry<Integer,DoubleVector> entry : matrix.rows.entrySet()) {
            DoubleVector row = entry.getValue();
            if (row == null) {
                continue;
            }
            for (Integer index : row.getIndices()) {
                result.incrementValue(entry.getKey(), row.get(index) * valueOrZero(vector, index));
            }
        }
        return result;
    }

    /**
     * In-place normalization of a vector. A vector of length zero is left
     * unchanged, since dividing by its length would turn every entry into NaN.
     * @param vector
     *            vector to scale to unit length.
     * @return The initial euclidean length of the vector.
     */
    public static double normalize(DoubleVector vector) {
        double sumOfSquares = 0;
        for (Integer index : vector.getIndices()) {
            double d = vector.get(index);
            sumOfSquares += d * d;
        }
        double len = Math.sqrt(sumOfSquares);
        if (len == 0.0) {
            return len;
        }
        // Overwriting existing keys does not disturb the iteration.
        for (Integer index : vector.getIndices()) {
            vector.set(index, vector.get(index) / len);
        }
        return len;
    }
}