/** * @author Sunita Sarawagi * @since 1.2 * @version 1.3 */ package iitb.CRF; import java.util.TreeSet; import cern.colt.function.tdouble.DoubleDoubleFunction; import cern.colt.function.tdouble.IntDoubleFunction; import cern.colt.matrix.tdouble.DoubleMatrix1D; import cern.colt.matrix.tdouble.DoubleMatrix2D; import cern.colt.matrix.tdouble.impl.DenseDoubleMatrix1D; import cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D; //this needs to be done to support an efficient sparse implementation //of matrices in the log-space public class LogDenseDoubleMatrix1D extends DenseDoubleMatrix1D { /** * */ private static final long serialVersionUID = 5041548544154364582L; static double map(double val) { if (val == RobustMath.LOG0) return 0; if (val == 0) return Double.MIN_VALUE; return val; } static double reverseMap(double val) { if (val == 0) { return RobustMath.LOG0; } if (val == Double.MIN_VALUE) return 0; return val; } public LogDenseDoubleMatrix1D(int numY) {super(numY);} public DoubleMatrix1D assign(double val) { return super.assign(map(val)); } public void set(int row, double val) { super.set(row,map(val)); } public double get(int row) { return reverseMap(super.get(row)); } public double zSum() { TreeSet<Double> logProbVector = new TreeSet<Double>(); // TODO for (int row = 0; row < size(); row++) { if (getQuick(row) != 0) RobustMath.addNoDups(logProbVector,get(row)); } return RobustMath.logSumExp(logProbVector); } // WARNING: this is only correct for functions that leave the infinity unchanged. public DoubleMatrix1D forEachNonZero(IntDoubleFunction func) { for (int y = 0; y < size(); y++) { if (getQuick(y) != 0) setQuick(y,func.apply(y,get(y))); } return this; } // WARNING: this is only correct for functions that leave the infinity unchanged. public DoubleMatrix1D assign(DoubleMatrix1D v2, DoubleDoubleFunction func) { // TODO.. for (int row = 0; row < size(); row++) { if ((v2.getQuick(row) != 0) || (getQuick(row) != 0)) set(row,func.apply(get(row), v2.get(row))); } return this; } public boolean equals(Object arg) { DoubleMatrix1D mat = (DoubleMatrix1D)arg; for (int row = (int) (size()-1); row >= 0; row--) if (Math.abs(mat.get(row)-get(row))/Math.abs(mat.get(row)) > 0.0001) return false; return true; } }; class LogDenseDoubleMatrix2D extends DenseDoubleMatrix2D { /** * */ private static final long serialVersionUID = 7238191232756809992L; static double map(double val) { return LogSparseDoubleMatrix1D.map(val);} static double reverseMap(double val) { return LogSparseDoubleMatrix1D.reverseMap(val);} LogDenseDoubleMatrix2D(int numR, int numC) {super(numR,numC); } public DoubleMatrix2D assign(double val) { return super.assign(map(val)); } public void set(int row, int column, double val) { super.set(row,column,map(val)); } public double get(int row, int column) { return reverseMap(super.get(row,column)); } public DoubleMatrix1D zMult(DoubleMatrix1D y, DoubleMatrix1D z, double alpha, double beta, boolean transposeA) { return RobustMath.logMult(this,y,z,alpha,beta,transposeA); } };