package mikera.matrixx.impl;
import java.util.Arrays;
import mikera.arrayz.ISparse;
import mikera.matrixx.AMatrix;
import mikera.vectorz.util.IntArrays;
/**
* Class representing a square, block diagonal matrix.
*
* Each block on the main diagonal must be a square matrix, but need not itself be diagonal.
*
* @author Mike
*
*/
public class BlockDiagonalMatrix extends ABlockMatrix implements ISparse {
private static final long serialVersionUID = -8569790012901451992L;
private final AMatrix[] mats;
private final int[] sizes;
private final int[] offsets;
private final int blockCount;
private final int size;
private BlockDiagonalMatrix(AMatrix[] newMats) {
blockCount=newMats.length;
mats=newMats;
sizes=new int[blockCount];
offsets=new int[blockCount+1];
int totalSize=0;
for (int i=0; i<blockCount; i++) {
int size=mats[i].rowCount();
sizes[i]=size;
offsets[i]=totalSize;
totalSize+=size;
}
this.size=totalSize;
offsets[blockCount]=size;
}
public static BlockDiagonalMatrix create(AMatrix... blocks) {
return new BlockDiagonalMatrix(blocks.clone());
}
@Override
public boolean isFullyMutable() {
return false;
}
@Override
public boolean isMutable() {
for (int i=0; i<blockCount; i++) {
if (mats[i].isMutable()) return true;
}
return true;
}
@Override
public AMatrix getBlock(int rowBlock, int colBlock) {
if (rowBlock!=colBlock) return ZeroMatrix.create(getBlockRowCount(rowBlock), getBlockColumnCount(colBlock));
return mats[rowBlock];
}
public int getBlockColumnStart(int colBlock) {
return offsets[colBlock];
}
public int getBlockRowStart(int rowBlock) {
return offsets[rowBlock];
}
@Override
public int getBlockColumnCount(int colBlock) {
return sizes[colBlock];
}
@Override
public int getBlockRowCount(int rowBlock) {
return sizes[rowBlock];
}
@Override
public int getColumnBlockIndex(int col) {
if ((col<0)||(col>=size)) throw new IndexOutOfBoundsException("Column: "+ col);
int i=IntArrays.indexLookup(offsets, col);
if (i<0) throw new IndexOutOfBoundsException("Column: "+ col);
return i;
}
@Override
public int getRowBlockIndex(int row) {
if ((row<0)||(row>=size)) throw new IndexOutOfBoundsException("Row: "+ row);
int i=IntArrays.indexLookup(offsets, row);
if (i<0) throw new IndexOutOfBoundsException("Row: "+ row);
return i;
}
@Override
public int rowCount() {
return size;
}
@Override
public int columnCount() {
return size;
}
@Override
public double get(int row, int column) {
int bi=getRowBlockIndex(row);
int bj=getColumnBlockIndex(column);
if (bi!=bj) return 0.0;
int i=row-offsets[bi];
int j=column-offsets[bi];
return mats[bi].unsafeGet(i, j);
}
@Override
public void set(int row, int column, double value) {
int bi=getRowBlockIndex(row);
int bj=getColumnBlockIndex(column);
if (bi!=bj) throw new UnsupportedOperationException("Block Diagonal Matrix immutable at this position");
int i=row-offsets[bi];
int j=column-offsets[bi];
mats[bi].unsafeSet(i, j, value);
}
@Override
public AMatrix exactClone() {
AMatrix[] newMats=mats.clone();
for (int i=0; i<blockCount; i++) {
newMats[i]=newMats[i].exactClone();
}
return new BlockDiagonalMatrix(newMats);
}
@Override
public int columnBlockCount() {
return blockCount;
}
@Override
public int rowBlockCount() {
return blockCount;
}
@Override
public void copyColumnTo(int col, double[] dest, int destOffset) {
int i=getColumnBlockIndex(col);
int si=offsets[i];
int di=offsets[i+1];
Arrays.fill(dest, destOffset, si+destOffset, 0.0);
mats[i].copyColumnTo(col-si, dest, destOffset+si);
Arrays.fill(dest, di+destOffset, size+destOffset, 0.0);
}
@Override
public void copyRowTo(int row, double[] dest, int destOffset) {
int i=getRowBlockIndex(row);
int si=offsets[i];
int di=offsets[i+1];
Arrays.fill(dest, destOffset, si+destOffset, 0.0);
mats[i].copyRowTo(row-si, dest, destOffset+si);
Arrays.fill(dest, di+destOffset, size+destOffset, 0.0);
}
@Override
public double density() {
long nzero=0;
for (int i=0; i<blockCount; i++) {
nzero+=mats[i].nonZeroCount();
}
return nzero/((double)elementCount());
}
@Override
public boolean hasUncountable() {
for(int i=0; i<blockCount; i++) {
if (mats[i].hasUncountable()) {
return true;
}
}
return false;
}
/**
* Returns the sum of all the elements raised to a specified power
* @return
*/
@Override
public double elementPowSum(double p) {
double result = 0;
for(int i=0; i<blockCount; i++) {
result += mats[i].elementPowSum(p);
}
return result;
}
/**
* Returns the sum of the absolute values of all the elements raised to a specified power
* @return
*/
@Override
public double elementAbsPowSum(double p) {
double result = 0;
for(int i=0; i<blockCount; i++) {
result += mats[i].elementAbsPowSum(p);
}
return result;
}
}