package info.kalendra.math;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Accurate floating-point summation and dot-product algorithms built on
 * error-free transformations (TwoSum, FastTwoSum, Split, TwoProduct).
 *
 * <p>Based on: T. Ogita, S. M. Rump and S. Oishi, "Accurate Sum and Dot Product",
 * SIAM Journal on Scientific Computing, Vol. 26, No. 6, pp. 1955-1988.
 *
 * <p>All routines treat an empty input as an empty sum and return 0 for it.
 * None of the routines modify their array arguments.
 *
 * @author Eric Kalendra
 */
public class AccurateMath {

    static final Logger log = LoggerFactory.getLogger(AccurateMath.class);

    /**
     * Relative rounding error unit u = 2^-53 as used in the paper. MATLAB's
     * {@code eps} is the spacing between 1.0 and the next double (2^-52), i.e.
     * twice this value, which is why the two sources appear to disagree.
     */
    static final double epsDouble = 0x1.0p-53;

    /** Underflow unit: the smallest positive (subnormal) double, 2^-1074. */
    static final double etaDouble = Double.MIN_VALUE;

    /** Number of significand bits of a double. */
    static final int T = 53;

    /** Split position used by {@link #split(double)}: ceil(T / 2). */
    static final int S = 27;

    /** Veltkamp splitting factor 2^S + 1. */
    static final double FACTOR = (1L << S) + 1.0;

    // static final double epsFloat = Math.pow(2, -23);

    /**
     * Validates that both vectors have the same length.
     *
     * @return the common length
     * @throws IllegalArgumentException if the lengths differ
     */
    private static int checkedLength(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                    "x and y must have the same length (x.length=" + x.length
                            + ", y.length=" + y.length + ")");
        }
        return x.length;
    }

    /**
     * Dot product algorithm in K-fold working precision, K &ge; 3.
     *
     * @param x first vector
     * @param y second vector, same length as {@code x}
     * @param k precision multiplier; intended to be &gt;= 3 (not enforced)
     * @return the dot product, 0 for empty vectors
     * @throws IllegalArgumentException if the lengths differ
     */
    public double dotK(double[] x, double[] y, int k) {
        int n = checkedLength(x, y);
        if (n == 0) {
            return 0.0;
        }
        // r[0..n-1]: product errors, r[n..2n-2]: summation errors, r[2n-1]: running sum
        double[] r = new double[2 * n];
        double[] prod = twoProduct(x[0], y[0]);
        double p = prod[0];
        r[0] = prod[1];
        for (int i = 1; i < n; i++) {
            prod = twoProduct(x[i], y[i]);
            r[i] = prod[1];
            double[] acc = twoSum(p, prod[0]);
            p = acc[0];
            r[n + i - 1] = acc[1];
        }
        r[2 * n - 1] = p;
        return sumK(r, k - 1);
    }

    /**
     * Dot product in twice the working precision with error bound including underflow.
     *
     * @param x first vector
     * @param y second vector, same length as {@code x}
     * @return a two-element array {@code {res, err}}: the computed dot product
     *         and the error bound computed by the paper's formula
     * @throws IllegalArgumentException if the lengths differ
     * @throws RuntimeException if {@code 2 * n * eps >= 1}, in which case the
     *         bound cannot be computed
     */
    public double[] dot2Err(double[] x, double[] y) {
        int n = checkedLength(x, y);
        // Evaluated in double so that 2 * n cannot overflow int.
        if (2.0 * n * epsDouble >= 1) {
            throw new RuntimeException("Inclusion failed (error from paper)");
        }
        double p = 0.0; // high-order running sum
        double s = 0.0; // accumulated low-order corrections
        double e = 0.0; // accumulated absolute value of the corrections
        if (n > 0) {
            double[] prod = twoProduct(x[0], y[0]);
            p = prod[0];
            s = prod[1];
            e = Math.abs(s);
            for (int i = 1; i < n; i++) {
                prod = twoProduct(x[i], y[i]);
                double[] acc = twoSum(p, prod[0]);
                p = acc[0];
                double t = acc[1] + prod[1];
                s += t;
                e += Math.abs(t);
            }
        }
        double res = p + s;
        double delta = (n * epsDouble) / (1 - 2.0 * n * epsDouble);
        double alpha = epsDouble * Math.abs(res) + (delta * e + 3 * etaDouble / epsDouble);
        double err = alpha / (1 - 2 * epsDouble);
        return new double[] {res, err};
    }

    /**
     * XBLAS quadruple precision dot product.
     * 37n flops; accurate results for condition numbers up to some 10^32.
     *
     * @param x first vector
     * @param y second vector, same length as {@code x}
     * @return the dot product, 0 for empty vectors
     * @throws IllegalArgumentException if the lengths differ
     */
    public double dotXBLASS(double[] x, double[] y) {
        int n = checkedLength(x, y);
        double s = 0.0; // high part of the double-double accumulator
        double t = 0.0; // low part of the double-double accumulator
        for (int i = 0; i < n; i++) {
            double[] prod = twoProduct(x[i], y[i]);
            double[] high = twoSum(s, prod[0]);
            double[] low = twoSum(t, prod[1]);
            double s2 = high[1] + low[0];
            double[] renorm = fastTwoSum(high[0], s2);
            double t2 = low[1] + renorm[1];
            double[] result = fastTwoSum(renorm[0], t2);
            s = result[0];
            t = result[1];
        }
        return s;
    }

    /**
     * Equivalent formulation of dot2 that keeps the indexed recurrences
     * {@code p[i]} and {@code s[i]} explicit. 25n flops.
     *
     * @param x first vector
     * @param y second vector, same length as {@code x}
     * @return the dot product, 0 for empty vectors
     * @throws IllegalArgumentException if the lengths differ
     */
    public double dot2s(double[] x, double[] y) {
        int n = checkedLength(x, y);
        if (n == 0) {
            return 0.0;
        }
        double[] p = new double[n];
        double[] s = new double[n];
        double[] prod = twoProduct(x[0], y[0]);
        p[0] = prod[0];
        s[0] = prod[1];
        for (int i = 1; i < n; i++) {
            prod = twoProduct(x[i], y[i]);
            double[] acc = twoSum(p[i - 1], prod[0]);
            p[i] = acc[0];
            s[i] = s[i - 1] + (acc[1] + prod[1]);
        }
        return p[n - 1] + s[n - 1];
    }

    /**
     * Dot product in twice the working precision.
     *
     * @param x first vector
     * @param y second vector, same length as {@code x}
     * @return the dot product, 0 for empty vectors
     * @throws IllegalArgumentException if the lengths differ
     */
    public double dot2(double[] x, double[] y) {
        int n = checkedLength(x, y);
        if (n == 0) {
            return 0.0;
        }
        double[] prod = twoProduct(x[0], y[0]);
        double p = prod[0];
        double s = prod[1];
        for (int i = 1; i < n; i++) {
            prod = twoProduct(x[i], y[i]);
            double[] acc = twoSum(p, prod[0]);
            p = acc[0];
            s += (acc[1] + prod[1]);
        }
        return p + s;
    }

    /**
     * A first dot product algorithm: transforms every product into a
     * (value, error) pair and sums the 2n terms with {@link #sumK(double[], int)}.
     *
     * @param x first vector
     * @param y second vector, same length as {@code x}
     * @param k precision multiplier passed to {@code sumK}
     * @return the dot product, 0 for empty vectors
     * @throws IllegalArgumentException if the lengths differ
     */
    public double dot(double[] x, double[] y, int k) {
        int n = checkedLength(x, y);
        double[] r = new double[2 * n];
        for (int i = 0; i < n; i++) {
            double[] prod = twoProduct(x[i], y[i]);
            r[i] = prod[0];
            r[n + i] = prod[1];
        }
        return sumK(r, k);
    }

    /**
     * Equivalent formulation of sumK - vertical mode.
     *
     * <p>Taken literally, the initialisation phase leaves its last value in the
     * tail accumulator although that value is also stored in {@code q}, so it
     * would be counted twice. This is obvious for k = 2 and only a rounding-level
     * error for k &ge; 3. The tail accumulator is therefore started at zero for
     * every k.
     *
     * <p>If {@code min(k, p.length) < 2} there is nothing to compensate and the
     * plain recursive sum is returned.
     *
     * @param p summands
     * @param k precision multiplier
     * @return the sum of {@code p}, 0 for an empty array
     */
    public double sumKvert(double[] p, int k) {
        int n = p.length;
        int kUse = Math.min(k, n);
        if (kUse < 2) {
            return sum(p);
        }
        int levels = kUse - 1;
        double[] q = new double[levels];

        // Phase 1: load the first K-1 summands into the accumulator cascade.
        for (int i = 0; i < levels; i++) {
            double carry = p[i];
            for (int j = 0; j < i; j++) {
                double[] t = twoSum(q[j], carry);
                q[j] = t[0];
                carry = t[1];
            }
            q[i] = carry;
        }

        // Everything seen so far lives in q, so the tail starts empty.
        double s = 0.0;

        // Phase 2: push each remaining summand through all K-1 levels.
        for (int i = levels; i < n; i++) {
            double alpha = p[i];
            for (int j = 0; j < levels; j++) {
                double[] t = twoSum(q[j], alpha);
                q[j] = t[0];
                alpha = t[1];
            }
            s += alpha;
        }

        // Phase 3: fold the accumulators into the last level.
        for (int j = 0; j < levels - 1; j++) {
            double alpha = q[j];
            for (int m = j + 1; m < levels; m++) {
                double[] t = twoSum(q[m], alpha);
                q[m] = t[0];
                alpha = t[1];
            }
            s += alpha;
        }

        log.debug("s: {}", s);
        log.debug("q[kUse-1-1]: {}", q[levels - 1]);
        return s + q[levels - 1];
    }

    /**
     * Summation as in K-fold precision by (K - 1)-fold error-free vector transformation.
     *
     * @param p summands (not modified)
     * @param k precision multiplier; for {@code k <= 1} this is the plain sum
     * @return the sum of {@code p}, 0 for an empty array
     */
    public double sumK(double[] p, int k) {
        // vecSum always returns a fresh array, so no defensive copy is needed.
        double[] work = p;
        for (int pass = 1; pass < k; pass++) {
            work = vecSum(work);
        }
        return sum(work);
    }

    /**
     * XBLAS quadruple precision summation.
     *
     * @param p summands
     * @return the sum of {@code p}, 0 for an empty array
     */
    public double sumXBLAS(double[] p) {
        double s = 0.0; // high part
        double t = 0.0; // low part
        for (int i = 0; i < p.length; i++) {
            double[] step = twoSum(s, p[i]);
            double low = step[1] + t;
            double[] renorm = fastTwoSum(step[0], low);
            s = renorm[0];
            t = renorm[1];
        }
        return s;
    }

    /**
     * Summation in twice the working precision: one {@link #vecSum(double[])}
     * pass followed by a plain sum.
     *
     * @param p summands (not modified)
     * @return the sum of {@code p}, 0 for an empty array
     */
    public double sum2(double[] p) {
        return sum(vecSum(p));
    }

    /**
     * Error-free vector transformation: chains {@link #twoSum(double, double)}
     * over neighbouring elements so that the last element holds the running
     * floating-point sum and the preceding elements hold the rounding errors.
     *
     * @param p input vector (not modified)
     * @return a new array of the same length; empty if {@code p} is empty
     */
    public double[] vecSum(double[] p) {
        double[] out = Arrays.copyOf(p, p.length);
        for (int i = 1; i < out.length; i++) {
            double[] t = twoSum(out[i], out[i - 1]);
            out[i] = t[0];
            out[i - 1] = t[1];
        }
        return out;
    }

    /**
     * Plain left-to-right recursive summation.
     *
     * @param p summands
     * @return the sum of {@code p}, 0 for an empty array
     */
    public double sum(double[] p) {
        double total = 0.0;
        for (double value : p) {
            total += value;
        }
        return total;
    }

    /**
     * Prints the vector on one line to standard output, elements separated by
     * {@code ", "}. An empty vector prints an empty line.
     *
     * @param p vector to print
     */
    public void printVec(double[] p) {
        StringBuilder line = new StringBuilder();
        for (int i = 0; i < p.length; i++) {
            if (i > 0) {
                line.append(", ");
            }
            line.append(p[i]);
        }
        System.out.println(line);
    }

    /**
     * Summation in twice the working precision, scalar formulation: the
     * rounding errors of the running sum are accumulated separately and added
     * at the end.
     *
     * @param p summands
     * @return the sum of {@code p}, 0 for an empty array
     */
    public double sum2s(double[] p) {
        if (p.length == 0) {
            return 0.0;
        }
        double pi = p[0];
        double sigma = 0.0;
        for (int i = 1; i < p.length; i++) {
            double[] t = twoSum(pi, p[i]);
            pi = t[0];
            sigma += t[1];
        }
        return pi + sigma;
    }

    /**
     * Error-free transformation of the product of two floating point numbers.
     *
     * <p>Returns {@code {x, y}} with {@code x = fl(a * b)} and {@code y} the
     * rounding error of that product. Barring overflow and underflow,
     * {@code x + y} equals {@code a * b} exactly as a real number. Evaluating
     * {@code x + y} in double arithmetic rounds back to {@code x}, so the
     * correction cannot be seen by simply printing the sum.
     *
     * @param a first factor
     * @param b second factor
     * @return two-element array {@code {x, y}}
     */
    public double[] twoProduct(double a, double b) {
        double x = a * b;
        double[] splita = split(a);
        double[] splitb = split(b);
        double y = (splita[1] * splitb[1])
                - (((x - splita[0] * splitb[0]) - splita[1] * splitb[0]) - splita[0] * splitb[1]);
        // This runs inside every dot-product loop; skip the diagnostic
        // arithmetic and boxing entirely unless debug logging is on.
        if (log.isDebugEnabled()) {
            log.debug("twoProduct: x: {}", x);
            log.debug("twoProduct: splita[0]: {}", splita[0]);
            log.debug("twoProduct: splita[1]: {}", splita[1]);
            log.debug("twoProduct: splitb[0]: {}", splitb[0]);
            log.debug("twoProduct: splitb[1]: {}", splitb[1]);
            log.debug("twoProduct: splita[1] * splitb[1]: {}", splita[1] * splitb[1]);
            log.debug("twoProduct: x - splita[0]*splitb[0]: {}", x - splita[0] * splitb[0]);
            log.debug("twoProduct: splita[1]*splitb[0]: {}", splita[1] * splitb[0]);
            log.debug("twoProduct: splita[0]*splitb[1]: {}", splita[0] * splitb[1]);
            log.debug("twoProduct: splita[0]*splitb[0]: {}", splita[0] * splitb[0]);
            log.debug("twoProduct: sumofsplit: {}",
                    (splita[0] * splitb[0]) + (splita[1] * splitb[0])
                            + (splita[0] * splitb[1]) + (splita[1] * splitb[1]));
            log.debug("twoProduct: y: {}", y);
        }
        return new double[] {x, y};
    }

    /**
     * Error-free split of a double into a high and a low part using the
     * factor 2^27 + 1, such that {@code a = x + y}. The intermediate product
     * {@code FACTOR * a} overflows for very large {@code |a|}.
     *
     * @param a value to split
     * @return two-element array {@code {x, y}} (high part, low part)
     */
    public double[] split(double a) {
        double c = FACTOR * a;
        double x = c - (c - a);
        double y = a - x;
        return new double[] {x, y};
    }

    /**
     * Error-free transformation of the sum of two floating point numbers;
     * no ordering of {@code a} and {@code b} is required.
     *
     * @param a first summand
     * @param b second summand
     * @return two-element array {@code {x, y}} with {@code x = fl(a + b)} and
     *         {@code y} the rounding error of that sum
     */
    public double[] twoSum(double a, double b) {
        double x = a + b;
        double z = x - a;
        double y = (a - (x - z)) + (b - z);
        return new double[] {x, y};
    }

    /**
     * {@link #fastTwoSum(double, double)} for arguments in arbitrary order:
     * passes the argument with the larger magnitude first.
     *
     * @param a first summand
     * @param b second summand
     * @return two-element array {@code {x, y}} as for {@code fastTwoSum}
     */
    public double[] fastTwoSumUnOrdered(double a, double b) {
        if (Math.abs(a) >= Math.abs(b)) {
            return fastTwoSum(a, b);
        }
        // |b| > |a|: reverse the order
        return fastTwoSum(b, a);
    }

    /**
     * Error-free transformation of a sum using three operations. The error
     * term is only reliable when {@code |a| >= |b|}; use
     * {@link #fastTwoSumUnOrdered(double, double)} otherwise.
     *
     * @param a summand with the larger magnitude
     * @param b summand with the smaller magnitude
     * @return two-element array {@code {x, y}} with {@code x = fl(a + b)}
     */
    public double[] fastTwoSum(double a, double b) {
        double x = a + b;
        double y = (a - x) + b;
        return new double[] {x, y};
    }

    /**
     * Small manual demonstration that prints the results of the algorithms
     * for a few fixed inputs.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        AccurateMath testob = new AccurateMath();
        double a = 100.00000001;
        double b = 800.0002;
        double[] result;
        double temp;
        testob.printVec(testob.split(a));
        System.out.println("sum of split a: " + testob.sum(testob.split(a)));
        testob.printVec(testob.split(b));
        System.out.println("sum of split b: " + testob.sum(testob.split(b)));
        result = testob.twoProduct(a, b);
        System.out.printf("x: %30.18f \n", result[0]);
        System.out.printf("y: %30.18f \n", result[1]);
        temp = result[0] + result[1];
        System.out.printf("x + y: %30.18f \n", temp);
        System.out.printf("a * b: %30.18f \n", (a * b));
        double[] array = new double[] {1.1, 2.2, 3.3, 4.4, 5.5};
        temp = testob.sum2s(array);
        System.out.println("Sum from sum2s: " + temp);
        System.out.println("Sum from basic sum: " + testob.sum(array));
        System.out.println("Sum from sumXBLAS: " + testob.sumXBLAS(array));
        System.out.println("Sum from sumK: " + testob.sumK(array, 3));
        System.out.println("Sum from sumKvert: " + testob.sumKvert(array, 2));
        testob.printVec(array);
        result = testob.vecSum(array);
        testob.printVec(result);
        System.out.println("Dot array,array: " + testob.dot(array, array, 2));
        System.out.println("Dot2 array,array: " + testob.dot2(array, array));
        testob.printVec(testob.dot2Err(array, array));
        System.out.println("Dot2s array,array: " + testob.dot2s(array, array));
        System.out.println("dotXBLASS array,array: " + testob.dotXBLASS(array, array));
        System.out.println("dotK array,array: " + testob.dotK(array, array, 3));
    }
}