/** RobustMath.java
*
* @author Sunita Sarawagi
* @version 1.3
*/
package iitb.CRF;
import iitb.CRF.Trainer.SumFunc;
import java.util.TreeSet;
import cern.colt.function.tdouble.DoubleDoubleFunction;
import cern.colt.function.tdouble.IntIntDoubleFunction;
import cern.colt.matrix.tdouble.DoubleMatrix1D;
import cern.colt.matrix.tdouble.DoubleMatrix2D;
public class RobustMath {
public static double LOG0 = -1*Double.MAX_VALUE;
public static double LOG2 = 0.69314718055;
static final double MINUS_LOG_MINVAL=-1*Math.log(Double.MIN_VALUE);
static final double MINUS_LOG_EPSILON = 30; //-1*Math.log(Double.MIN_VALUE);
public static boolean useCache = true;
public static double maxError=Double.NEGATIVE_INFINITY;
public static double maxErrorAtVal = 0;
public static int numInvoke=0;
static class LogExpCache {
static int CUT_OFF = 7;
static int NUM_FINE = 100000;
static int NUM_COARSE = 5000;
static double vals[] = new double[CUT_OFF*NUM_FINE+((int)MINUS_LOG_EPSILON-CUT_OFF)*NUM_COARSE+2];
static {
for(int i = vals.length-1; i >= 0; vals[i--]=-1);
}
static double lookupAddErr(double val) {
numInvoke++;
double retval = lookupAdd(val);
double actual = Math.log(Math.exp(-1*val) + 1.0);
double err = Math.abs(retval-actual);
if (err > maxError) {
maxError=err;
maxErrorAtVal=val;
System.out.println("MaxError " + maxError + " "+val + " "+numInvoke);
}
return retval;
}
static double lookupAdd(double val) {
if (!useCache)
return Math.log(Math.exp(-1*val) + 1.0);
int index = 0;
//assert ((val < MINUS_LOG_EPSILON) && (val > 0));
if (val < CUT_OFF) {
index = (int)Math.rint(val*NUM_FINE);
} else {
index = NUM_FINE*CUT_OFF + (int)Math.rint((val-CUT_OFF)*NUM_COARSE);
}
if (vals[index] < 0) {
vals[index] = Math.log(Math.exp(-1*val) + 1.0);
}
return vals[index];
}
//
// Trial code for linear interpolation-based caching of values...that did not work
static double endpts[]=new double[2];
static double cvals[] = null;//new double[CUT_OFF*NUM_FINE+((int)MINUS_LOG_EPSILON-CUT_OFF)*NUM_COARSE+2];
static double lookupAddWorse(double val) {
if (!useCache)
return Math.log(Math.exp(-1*val) + 1.0);
int index = 0;
//assert ((val < MINUS_LOG_EPSILON) && (val > 0));
if (val < CUT_OFF) {
index = (int)Math.floor(val*NUM_FINE);
} else {
index = NUM_FINE*CUT_OFF + (int)Math.floor((val-CUT_OFF)*NUM_COARSE);
}
for (int k = 0; k < 2; k++) {
double vi=val;
int i1=index+k;
if (i1 < NUM_FINE*CUT_OFF)
vi = i1/(double)NUM_FINE;
else
vi = CUT_OFF + (i1-NUM_FINE*CUT_OFF)/(double)NUM_COARSE;
endpts[k]=vi;
if (cvals[i1] <= 0)
cvals[i1] = Math.log(Math.exp(-1*vi) + 1.0);
}
double a = (val-endpts[0])/(endpts[1]-endpts[0]);
double retval = cvals[index]*a+cvals[index+1]*(1-a);
System.out.println((retval-Math.log(Math.exp(-1*val) + 1.0))+ " "+(lookupAdd(val)-Math.log(Math.exp(-1*val) + 1.0)));
return retval;
}
};
public static double logSumExp(double v1, double v2) {
if (Math.abs(v1-v2) < Double.MIN_VALUE)
return v1 + LOG2;
double vmin = Math.min(v1,v2);
double vmax = Math.max(v1,v2);
if ( vmax > vmin + MINUS_LOG_EPSILON ) {
return vmax;
} else {
return vmax + LogExpCache.lookupAdd(vmax-vmin);
/*
double retval = vmax + Math.log(Math.exp(vmin-vmax) + 1.0);
//System.out.println((vmax-vmin) + " " + (retval-vmax));
return retval;
*/
}
}
static class LogSumExp implements DoubleDoubleFunction {
public double apply(double v1, double v2) {
return logSumExp(v1,v2);
}
};
public static LogSumExp logSumExpFunc = new LogSumExp();
//TODO: Should TreeSet<Double> be replaced with a Trove type?
static void addNoDups(TreeSet<Double> vec, double v) {
Double val = new Double(v);
if (!vec.add(val)) {
vec.remove(val);
addNoDups(vec, val.doubleValue()+LOG2);
}
}
public static double logSumExp(TreeSet<Double> logProbVector) {
while ( logProbVector.size() > 1 ) {
double lp0 = logProbVector.first();
logProbVector.remove(logProbVector.first());
double lp1 = logProbVector.first();
logProbVector.remove(logProbVector.first());
addNoDups(logProbVector,logSumExp(lp0,lp1));
}
if (logProbVector.size() > 0)
return ((Double)logProbVector.first()).doubleValue();
return RobustMath.LOG0;
}
// matrix stuff for the older version..
public static double logSumExp(DoubleMatrix1D logProb) {
TreeSet<Double> logProbVector = new TreeSet<Double>();
for ( int lpx = 0; lpx < logProb.size(); lpx++ )
if (logProb.getQuick(lpx) != RobustMath.LOG0)
addNoDups(logProbVector,logProb.getQuick(lpx));
return logSumExp(logProbVector);
}
public static double logSumExp(double[] ds) {
TreeSet<Double> logProbVector = new TreeSet<Double>();
for ( int lpx = 0; lpx < ds.length; lpx++ )
if (ds[lpx] != RobustMath.LOG0)
addNoDups(logProbVector,ds[lpx]);
return logSumExp(logProbVector);
}
static void logSumExp(DoubleMatrix1D v1, DoubleMatrix1D v2) {
for (int i = 0; i < v1.size(); i++) {
v1.set(i,logSumExp(v1.get(i), v2.get(i)));
}
}
public static double logMinusExp(double v1, double v2) {
if (v1 - Double.MIN_VALUE < v2)
return -1*MINUS_LOG_MINVAL;
// throw new Exception("Cannot take log of negative numbers");
double vmin = v2;
double vmax = v1;
if (vmax > vmin + MINUS_LOG_MINVAL) {
return vmax;
} else {
return vmax + Math.log(1.0 - Math.exp(vmin - vmax));
}
}
static class LogMult implements IntIntDoubleFunction {
DoubleMatrix2D M;
DoubleMatrix1D z;
double lalpha;
boolean transposeA;
DoubleMatrix1D y;
int cnt;
public double apply(int i, int j, double val) {
int r = i;
int c = j;
if (transposeA) {
r = j;
c = i;
}
z.set(r, RobustMath.logSumExp(z.get(r), M.get(i,j)+y.get(c)+lalpha));
return val;
}
};
static LogMult logMult = new LogMult();
public static DoubleMatrix1D logMult(DoubleMatrix2D M, DoubleMatrix1D y, DoubleMatrix1D z, double alpha, double beta, boolean transposeA) {
// z = alpha * A * y + beta*z
double lalpha = 0;
if (alpha != 1)
lalpha = Math.log(alpha);
if (beta != 0) {
if (beta != 1) {
double lbeta = Math.log(beta);
for (int i = 0; i < z.size(); z.set(i,z.get(i)+lbeta),i++);
}
} else {
z.assign(RobustMath.LOG0);
}
// in log domain this becomes:
logMult.M = M;
logMult.z = z;
logMult.lalpha = lalpha;
logMult.transposeA = transposeA;
logMult.y = y;
logMult.cnt=0;
M.forEachNonZero(logMult);
// System.out.println("Matrix "+M.size()+" "+M.columns()+ " "+logMult.cnt);
return z;
}
public static DoubleMatrix1D logMult(DoubleMatrix2D M, DoubleMatrix1D y, DoubleMatrix1D z, double alpha, double beta, boolean transposeA, EdgeGenerator edgeGen) {
// z = alpha * A * y + beta*z
// in log domain this becomes:
double lalpha = 0;
if (alpha != 1)
lalpha = Math.log(alpha);
if (beta != 0) {
if (beta != 1) {
for (int i = 0; i < z.size(); z.set(i,z.get(i)+Math.log(beta)),i++);
}
} else {
z.assign(LOG0);
}
for (int j = 0; j < M.columns(); j++) {
for (int i = (edgeGen==null?j:edgeGen.first(j)); i < M.rows(); i = (edgeGen==null)?i+1:edgeGen.next(j,i)) {
int r = i;
int c = j;
if (transposeA) {
r = j;
c = i;
}
z.setQuick(r, logSumExp(z.getQuick(r), M.getQuick(i,j)+y.get(c)+lalpha));
}
}
return z;
}
static DoubleMatrix1D Mult(DoubleMatrix2D M, DoubleMatrix1D y, DoubleMatrix1D z, double alpha, double beta, boolean transposeA, EdgeGenerator edgeGen) {
// z = alpha * A * y + beta*z
for (int i = 0; i < z.size(); z.set(i,z.get(i)*beta),i++);
for (int j = 0; j < M.columns(); j++) {
for (int i = (edgeGen==null)?j:edgeGen.first(j); i < M.rows(); i = (edgeGen==null)?i+1:edgeGen.next(j,i)) {
int r = i;
int c = j;
if (transposeA) {
r = j;
c = i;
}
z.set(r, z.getQuick(r) + M.getQuick(i,j)*y.getQuick(c)*alpha);
}
}
return z;
}
public static void main(String args[]) {
// double vals[] = new double[]{10.172079, 7.452882, 2.429751, 7.452882, 10.818797, 8.573773, 19.215824};
/*double vals[] = new double[]{2.883626, 1.670196, 0.553112, 1.670196, -0.935964, 1.864568, 2.064754};
TreeSet vec = new TreeSet();
double trueSum = 0;
for (int i = 0; i < vals.length; i++) {
addNoDups(vec,vals[i]);
trueSum += Math.exp(vals[i]);
}
double sum = logSumExp(vec);
*/
System.out.println(logSumExp(Double.parseDouble(args[0]), Double.parseDouble(args[1])));
}
/**
* @param d
* @return
*/
public static double exp(double d) {
if (Double.isInfinite(d) || ((d < 0) && (Math.abs(d) > MINUS_LOG_EPSILON)))
return 0;
//if ((d > 0) && (d < Double.MIN_VALUE))
// return 1;
//System.out.println(d + " " + Math.exp(d));
return Math.exp(d);
}
/**
* @param val
* @return
*/
public static double log(float val) {
return (Math.abs(val-1) < Double.MIN_VALUE)?0:Math.log(val);
}
public static void logMatrixMult(DoubleMatrix2D result, DoubleMatrix2D A, DoubleMatrix2D B, DoubleMatrix1D ri, boolean noMatrixMult) {
DoubleDoubleFunction sumFunc = new SumFunc();
if (noMatrixMult)
result.assign(B);
else {
for (int i = 0; i < A.rows(); i++) {
for (int j = 0; j < B.columns(); j++) {
double value = LOG0;
for (int k = 0; k < B.rows(); k++) {
value = logSumExp(value, A.get(i,k)+B.get(k, j));
}
result.set(i, j, value);
}
}
}
for (int i = 0; i < A.rows(); i++) {
result.viewRow(i).assign(ri, sumFunc);
}
}
};