package abra.cadabra; import java.util.Random; /** * Basic Fisher's exact test implementation. * * Caches factorial values in log space across calculations, thus speeding things up (a bit). * * @author lmose */ public class FishersExactTest { private static int MAX_SIZE = 5000; // Cache of factorial values in log space private static double[] factorialCache = new double[MAX_SIZE+1]; static { init(); } private static void init() { for (int i=1; i<=MAX_SIZE; i++) { factorialCache[i] = factorialCache[i-1] + Math.log(i); } } public double oneTailedTest(int normalRef, int normalAlt, int tumorRef, int tumorAlt) { int row1Col1 = normalRef; int row1Col2 = normalAlt; int row2Col1 = tumorRef; int row2Col2 = tumorAlt; int n = row1Col1 + row1Col2 + row2Col1 + row2Col2; if (n > MAX_SIZE) { double scale = (double) MAX_SIZE / (double) n; row1Col1 = (int) (row1Col1 * scale); row1Col2 = (int) (row1Col2 * scale); row2Col1 = (int) (row2Col1 * scale); row2Col2 = (int) (row2Col2 * scale); n = row1Col1 + row1Col2 + row2Col1 + row2Col2; } int row1Sum = row1Col1 + row1Col2; int row2Sum = row2Col1 + row2Col2; int col1Sum = row1Col1 + row2Col1; int col2Sum = row1Col2 + row2Col2; double numerator = factorialCache[row1Sum] + factorialCache[row2Sum] + factorialCache[col1Sum] + factorialCache[col2Sum]; double pObserved = getPForTable(row1Col1, row1Col2, row2Col1, row2Col2, n, numerator); double pValue = pObserved; while (row1Col2 > 0 && row2Col1 > 0) { row1Col1++; row1Col2--; row2Col1--; row2Col2++; double nextP = getPForTable(row1Col1, row1Col2, row2Col1, row2Col2, n, numerator); if (nextP <= pObserved) { pValue += nextP; } } // Cap p-value at 1 to guard against rounding errors return Math.min(pValue, 1.0); } //TODO: Extract shared code public double twoTailedTest(int normalRef, int normalAlt, int tumorRef, int tumorAlt) { int row1Col1 = normalRef; int row1Col2 = normalAlt; int row2Col1 = tumorRef; int row2Col2 = tumorAlt; int n = row1Col1 + row1Col2 + row2Col1 + row2Col2; if (n > MAX_SIZE) { double scale = (double) MAX_SIZE / (double) n; row1Col1 = (int) (row1Col1 * scale); row1Col2 = (int) (row1Col2 * scale); row2Col1 = (int) (row2Col1 * scale); row2Col2 = (int) (row2Col2 * scale); n = row1Col1 + row1Col2 + row2Col1 + row2Col2; } int row1Col1Start = row1Col1; int row1Col2Start = row1Col2; int row2Col1Start = row2Col1; int row2Col2Start = row2Col2; int row1Sum = row1Col1 + row1Col2; int row2Sum = row2Col1 + row2Col2; int col1Sum = row1Col1 + row2Col1; int col2Sum = row1Col2 + row2Col2; double numerator = factorialCache[row1Sum] + factorialCache[row2Sum] + factorialCache[col1Sum] + factorialCache[col2Sum]; double pObserved = getPForTable(row1Col1, row1Col2, row2Col1, row2Col2, n, numerator); double pValue = pObserved; while (row1Col2 > 0 && row2Col1 > 0) { row1Col1++; row1Col2--; row2Col1--; row2Col2++; double nextP = getPForTable(row1Col1, row1Col2, row2Col1, row2Col2, n, numerator); if (nextP <= pObserved) { pValue += nextP; } } // Now the other way... row1Col1 = row1Col1Start; row1Col2 = row1Col2Start; row2Col1 = row2Col1Start; row2Col2 = row2Col2Start; while (row1Col1 > 0 && row2Col2 > 0) { row1Col1--; row1Col2++; row2Col1++; row2Col2--; double nextP = getPForTable(row1Col1, row1Col2, row2Col1, row2Col2, n, numerator); if (nextP <= pObserved) { pValue += nextP; } } // Cap p-value at 1 to guard against rounding errors return Math.min(pValue, 1.0); } private double getPForTable(int r1c1, int r1c2, int r2c1, int r2c2, int n, double numerator) { //TODO: Remove this as optimization if ((r1c1 + r1c2 + r2c1 + r2c2) != n) throw new IllegalArgumentException("Invalid contigency table"); double denominator = factorialCache[r1c1] + factorialCache[r1c2] + factorialCache[r2c1] + factorialCache[r2c2] + factorialCache[n]; return Math.exp(numerator - denominator); } private static int nextRand(Random r) { return r.nextInt(10000); } public static void main(String[] args) { // int nr = 1500; int na = 110; int tr = 1400; int ta = 1100; int nr = 709; int na = 20; int tr = 711; int ta = 85; FishersExactTest t = new FishersExactTest(); double p = t.oneTailedTest(nr, na, tr, ta); // Random r = new Random(); // // long s = System.currentTimeMillis(); // // for (int i=0; i<10000; i++) { // double p = t.oneTailedTest(nr + nextRand(r), na + nextRand(r), tr + nextRand(r), ta + nextRand(r)); // if (i%1000 == 0) { // System.out.println(p); // } // } // // long e = System.currentTimeMillis(); System.out.println("p: " + p); System.out.println("phred: " + (-10 * Math.log10(p))); // // System.out.println(e-s); p = t.twoTailedTest(nr, na, tr, ta); System.out.println("2 tailed: " + p); } }