package mikera.matrixx.impl; import mikera.matrixx.AMatrix; import mikera.vectorz.AVector; import mikera.vectorz.Vector; import mikera.vectorz.impl.ADenseArrayVector; import mikera.vectorz.util.DoubleArrays; import mikera.vectorz.util.ErrorMessages; import mikera.vectorz.util.VectorzException; /** * Specialised diagonal matrix class, with dense double[] array storage for the leading diagonal only. * * Not fully mutable - only the diagonal values can be changed * * @author Mike */ public final class DiagonalMatrix extends ADiagonalMatrix { private static final long serialVersionUID = -6721785163444613243L; final double[] data; private final Vector lead; public DiagonalMatrix(int dimensions) { super(dimensions); data=new double[dimensions]; lead=Vector.wrap(data); } private DiagonalMatrix(double... values) { super(values.length); data=values; lead=Vector.wrap(data); } private DiagonalMatrix(Vector values) { super(values.length()); data=values.getArray(); lead=values; } public static DiagonalMatrix createDimensions(int dims) { return new DiagonalMatrix(dims); } public static DiagonalMatrix create(double... values) { int dimensions=values.length; double[] data=new double[dimensions]; System.arraycopy(values, 0, data, 0, dimensions); return new DiagonalMatrix(data); } public static DiagonalMatrix create(AVector v) { return wrap(v.toDoubleArray()); } public static DiagonalMatrix create(AMatrix m) { if (!m.isDiagonal()) throw new IllegalArgumentException("Source is not a diagonal matrix!"); return wrap(m.getLeadingDiagonal().toDoubleArray()); } public static DiagonalMatrix wrap(double[] data) { return new DiagonalMatrix(data); } public static DiagonalMatrix wrap(Vector data) { return new DiagonalMatrix(data); } @Override public double trace() { double result=0.0; for (int i=0; i<dimensions; i++) { result+=data[i]; } return result; } @Override public double diagonalProduct() { double result=1.0; for (int i=0; i<dimensions; i++) { result*=data[i]; } return result; } @Override public double elementSum() { return lead.elementSum(); } @Override public long nonZeroCount() { return lead.nonZeroCount(); } @Override public double get(int row, int column) { if (row!=column) { if ((row<0)||(row>=dimensions)) throw new IndexOutOfBoundsException(ErrorMessages.position(row,column)); return 0.0; } return data[row]; } @Override public double unsafeGet(int row, int column) { if (row!=column) return 0.0; return data[row]; } @Override public void set(int row, int column, double value) { if (row!=column) { if (value!=0.0) throw new UnsupportedOperationException(ErrorMessages.notFullyMutable(this, row, column)); } else { data[row]=value; } } @Override public void unsafeSet(int row, int column, double value) { data[row]=value; } @Override public boolean isMutable() { return true; } @Override public boolean isFullyMutable() { return dimensions<=1; } @Override public void multiply(double factor) { lead.multiply(factor); } @Override public DiagonalMatrix multiplyCopy(double factor) { double[] newData=DoubleArrays.copyOf(data); DoubleArrays.multiply(newData, factor); return wrap(newData); } @Override public double calculateElement(int i, AVector v) { return data[i]*v.unsafeGet(i); } @Override public double calculateElement(int i, Vector v) { return data[i]*v.unsafeGet(i); } @Override public void transform(Vector source, Vector dest) { int rc = rowCount(); int cc = rc; if (source.length()!=cc) throw new IllegalArgumentException(ErrorMessages.wrongSourceLength(source)); if (dest.length()!=rc) throw new IllegalArgumentException(ErrorMessages.wrongDestLength(dest)); double[] sdata=source.getArray(); double[] ddata=dest.getArray(); for (int i = 0; i < rc; i++) { ddata[i]=sdata[i]*this.data[i]; } } @Override public void transformInPlace(AVector v) { if (v instanceof ADenseArrayVector) { transformInPlace((ADenseArrayVector) v); return; } if (v.length()!=dimensions) throw new IllegalArgumentException("Wrong length vector: "+v.length()); for (int i=0; i<dimensions; i++) { v.unsafeSet(i,v.unsafeGet(i)*data[i]); } } @Override public void transformInPlace(ADenseArrayVector v) { double[] dest=v.getArray(); int offset=v.getArrayOffset(); DoubleArrays.arraymultiply(data, 0, dest, offset, dimensions); } @Override public boolean isIdentity() { for (int i=0; i<dimensions; i++) { if (data[i]!=1.0) return false; } return true; } @Override public boolean isBoolean() { return DoubleArrays.isBoolean(data, 0, dimensions); } @Override public boolean isZero() { return DoubleArrays.isZero(data); } @Override public DiagonalMatrix clone() { DiagonalMatrix m=new DiagonalMatrix(data); return m; } @Override public double determinant() { return DoubleArrays.elementProduct(data, 0, dimensions); } @Override public DiagonalMatrix inverse() { double[] newData=new double[dimensions]; DoubleArrays.reciprocal(newData); return new DiagonalMatrix(newData); } @Override public double getDiagonalValue(int i) { return data[i]; } @Override public double unsafeGetDiagonalValue(int i) { return data[i]; } @Override public Vector getLeadingDiagonal() { return lead; } @Override public AMatrix innerProduct(AMatrix a) { if (a instanceof ADiagonalMatrix) { return innerProduct((ADiagonalMatrix) a); } return super.innerProduct(a); } public AMatrix innerProduct(ADiagonalMatrix a) { if (!(a instanceof DiagonalMatrix)) return a.innerProduct(this); if (!(dimensions==a.dimensions)) throw new IllegalArgumentException(ErrorMessages.mismatch(this, a)); DiagonalMatrix result=DiagonalMatrix.create(this.data); result.lead.multiply(a.getLeadingDiagonal()); return result; } @Override public DiagonalMatrix exactClone() { return DiagonalMatrix.create(data); } @Override public void validate() { if (dimensions!=data.length) throw new VectorzException("dimension mismatch: "+dimensions); super.validate(); } }