package mikera.matrixx.impl; import mikera.matrixx.AMatrix; import mikera.vectorz.AVector; import mikera.vectorz.Vector; import mikera.vectorz.Vectorz; import mikera.vectorz.util.ErrorMessages; import mikera.vectorz.util.VectorzException; /** * Sparse banded matrix implementation. * * Composed of a list of diagonal bands. * * @author Mike * */ public class BandedMatrix extends ABandedMatrix { private static final long serialVersionUID = -4014060138907872914L; private final int minBand; private final int maxBand; private final AVector[] bands; private final int rowCount; private final int columnCount; private BandedMatrix(int rc, int cc, int minBand, AVector[] bands) { this.rowCount=rc; this.columnCount=cc; this.bands=bands; this.minBand=minBand; this.maxBand=minBand+bands.length-1; } public static BandedMatrix create(AMatrix m) { int rowCount=m.rowCount(); int columnCount=m.columnCount(); int minBand=-m.lowerBandwidth(); int maxBand=m.upperBandwidth(); AVector[] bands=new AVector[maxBand-minBand+1]; for (int i=minBand; i<=maxBand; i++) { bands[i-minBand]=m.getBand(i).clone(); } return new BandedMatrix(rowCount,columnCount,minBand,bands); } public static BandedMatrix create(int rowCount, int columnCount, int minBand, int maxBand) { if (-minBand>=rowCount) minBand=-(rowCount-1); if (maxBand>=columnCount) maxBand=columnCount-1; AVector[] bands=new AVector[maxBand-minBand+1]; for (int i=minBand; i<=maxBand; i++) { bands[i-minBand]=Vector.createLength(bandLength(rowCount,columnCount,i)); } return new BandedMatrix(rowCount,columnCount,minBand,bands); } public static BandedMatrix wrap(int rowCount, int columnCount, int minBand, int maxBand, AVector... bands) { if (bands.length!=(maxBand-minBand+1)) throw new IllegalArgumentException("Wrong number of bands: "+bands.length); for (int i=minBand; i<=maxBand; i++) { AVector b=bands[i-minBand]; if (b.length()!=bandLength(rowCount,columnCount,i)) { throw new IllegalArgumentException("Incorrect length of band "+ i +", was given: "+b.length()); } } return new BandedMatrix(rowCount,columnCount,minBand,bands); } @Override public int upperBandwidthLimit() { return maxBand; } @Override public int lowerBandwidthLimit() { return -minBand; } @Override public BandedMatrix getTranspose() { AVector[] nbands=new AVector[bands.length]; for (int i=0; i<(-minBand+maxBand+1); i++) { nbands[i]=bands[bands.length-1-i]; } return BandedMatrix.wrap(columnCount(), rowCount(), -maxBand, -minBand,nbands); } @Override public AVector getBand(int band) { if ((band>=minBand)&&(band<=maxBand)) return bands[band-minBand]; if ((band>=-rowCount)&&(band<=columnCount)) return Vectorz.createZeroVector(bandLength(band)); throw new IndexOutOfBoundsException(ErrorMessages.invalidBand(this, band)); } @Override public int rowCount() { return rowCount; } @Override public int columnCount() { return columnCount; } @Override public boolean isView() { return true; } @Override public double get(int i, int j) { checkIndex(i,j); return unsafeGet(i,j); } @Override public void set(int i, int j, double value) { getBand(bandIndex(i,j)).set(bandPosition(i,j),value); } @Override public double unsafeGet(int i, int j) { return getBand(bandIndex(i,j)).unsafeGet(bandPosition(i,j)); } @Override public void unsafeSet(int i, int j, double value) { getBand(bandIndex(i,j)).unsafeSet(bandPosition(i,j),value); } public void addAt(int i, int j, double d) { int band=j-i; AVector b=getBand(band); b.addAt(bandIndex(i,j), d); } @Override public BandedMatrix exactClone() { BandedMatrix b=new BandedMatrix(rowCount,columnCount,minBand,bands.clone()); for (int i=minBand; i<=maxBand; i++) { b.bands[i-minBand]=b.bands[i-minBand].exactClone(); } return b; } @Override public void transform(AVector source, AVector dest) { // fast transform is possible! if (!(dest instanceof Vector)) { super.transform(source, dest); } else if ((source instanceof Vector )) { transform ((Vector)source, (Vector)dest); } else { Vector t=(Vector)dest; t.fill(0.0); double[] data=t.getArray(); for (int i=minBand; i<=maxBand; i++) { AVector b=getBand(i); b.addProductToArray(1.0, 0, source, Math.max(i, 0), data, Math.max(-i, 0), bandLength(i)); } } } @Override public void transform(Vector source, Vector dest) { // fast transform is possible! Vector t=dest; t.fill(0.0); double[] data=dest.getArray(); for (int i=minBand; i<=maxBand; i++) { AVector b=getBand(i); b.addProductToArray(1.0, 0, source, Math.max(i, 0), data, Math.max(-i, 0), bandLength(i)); } } @Override public void validate() { super.validate(); if (minBand!=-lowerBandwidthLimit()) throw new VectorzException("Mismatched lower limit: "+minBand); if (maxBand!=upperBandwidthLimit()) throw new VectorzException("Mismatched upper limit: "+maxBand); } }