package com.compomics.util.math;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import org.apache.commons.math.util.FastMath;
/**
* Class used to perform basic mathematical functions.
*
* @author Marc Vaudel
*/
public class BasicMathFunctions {

    /**
     * Largest n for which n! still fits in a long.
     */
    private static final int MAX_LONG_FACTORIAL = 20;
    /**
     * Precomputed factorials of 0 to 20. The table is filled once at class
     * initialization and never modified afterwards, so it can be read from
     * any thread without locking.
     */
    private static final long[] FACTORIALS = buildFactorialTable();

    /**
     * Builds the table of all factorials fitting in a long.
     *
     * @return an array where the value at index i is i!
     */
    private static long[] buildFactorialTable() {
        long[] table = new long[MAX_LONG_FACTORIAL + 1];
        table[0] = 1L; // 0! = 1 by definition
        for (int i = 1; i <= MAX_LONG_FACTORIAL; i++) {
            table[i] = table[i - 1] * i;
        }
        return table;
    }

    /**
     * Returns n! as a long.
     *
     * @param n a given integer between 0 and 20 included
     *
     * @return the corresponding factorial, 1 for n = 0
     *
     * @throws ArithmeticException if n is negative
     * @throws IllegalArgumentException if n is higher than 20, i.e. n! does
     * not fit in a long
     */
    public static Long factorial(Integer n) {
        int value = n;
        // The sign must be tested before the upper bound, otherwise negative
        // values are treated as valid small inputs.
        if (value < 0) {
            throw new ArithmeticException("Attempting to calculate the factorial of a negative number.");
        }
        if (value > MAX_LONG_FACTORIAL) {
            throw new IllegalArgumentException("Factorial only implemented for n <= 20. Reached the maximal capacity of a long.");
        }
        return FACTORIALS[value];
    }

    /**
     * Returns n!/k!, i.e. the product (k+1) * (k+2) * ... * n, null if it
     * cannot fit in a long.
     *
     * @param n a given integer
     * @param k a given integer, 0 <= k <= n
     *
     * @return n!/k!, 1 if n equals k, null if the result overflows a long
     *
     * @throws ArithmeticException if n is lower than k or if k is negative
     */
    public static Long factorial(Integer n, Integer k) {
        int upper = n;
        int lower = k;
        if (upper < lower) {
            throw new ArithmeticException("n < k in n!/k!.");
        }
        if (lower < 0) {
            throw new ArithmeticException("Attempting to calculate the factorial of a negative number.");
        }
        long result = 1L;
        // A long counter avoids wrapping around when n is Integer.MAX_VALUE.
        // The loop is short in practice: the product overflows after a few
        // factors and the method returns.
        for (long factor = (long) lower + 1; factor <= upper; factor++) {
            if (result > Long.MAX_VALUE / factor) {
                return null;
            }
            result *= factor;
        }
        return result;
    }

    /**
     * Returns the number of k-combinations in a set of n elements. If n!/k!
     * cannot fit in a long, null is returned, use BigDecimal instead (see
     * BigFunctions).
     *
     * @param k the number of k-combinations
     * @param n the number of elements
     *
     * @return the number of k-combinations in a set of n elements, null if
     * the intermediate value n!/k! overflows a long
     *
     * @throws IllegalArgumentException if k is higher than n or negative
     */
    public static Long getCombination(int k, int n) {
        if (k == 0 || k == n) {
            return 1L;
        }
        if (k > n) {
            throw new IllegalArgumentException("k > n in combination.");
        }
        if (k < 0) {
            throw new IllegalArgumentException("k < 0 in combination.");
        }
        Long nOverK = factorial(n, k);
        if (nOverK == null) {
            return null;
        }
        // Here k >= 1, so n!/k! is a product of (n - k) factors all >= 2. It
        // can only fit in a long when n - k <= 20, hence (n - k)! is always
        // available at this point.
        return nOverK / factorial(n - k);
    }

    /**
     * Returns the median of an array which is already sorted and not empty.
     *
     * @param sorted the sorted values
     *
     * @return the middle value, or the mean of the two middle values when the
     * number of values is even
     */
    private static double medianOfSorted(double[] sorted) {
        int middle = sorted.length / 2;
        if (sorted.length % 2 == 1) {
            return sorted[middle];
        }
        return (sorted[middle] + sorted[middle - 1]) / 2;
    }

    /**
     * Method to estimate the median. The input array is left untouched, the
     * values are sorted in a copy.
     *
     * @param ratios array of double
     *
     * @return median of the input
     *
     * @throws IllegalArgumentException if the input is null or empty
     */
    public static double median(double[] ratios) {
        if (ratios == null) {
            throw new IllegalArgumentException("Attempting to estimate the median of a null object.");
        }
        if (ratios.length == 0) {
            throw new IllegalArgumentException("Attempting to estimate the median of an empty list.");
        }
        double[] sorted = ratios.clone();
        Arrays.sort(sorted);
        return medianOfSorted(sorted);
    }

    /**
     * Method to estimate the median.
     *
     * @param input ArrayList of double
     *
     * @return median of the input
     *
     * @throws IllegalArgumentException if the input is null or empty
     */
    public static double median(ArrayList<Double> input) {
        return percentile(input, 0.5);
    }

    /**
     * Method to estimate the median of a sorted list.
     *
     * @param input ArrayList of double, sorted in ascending order
     *
     * @return median of the input
     *
     * @throws IllegalArgumentException if the input is null or empty
     */
    public static double medianSorted(ArrayList<Double> input) {
        return percentileSorted(input, 0.5);
    }

    /**
     * Verifies that a percentile is expressed as a fraction between 0 and 1.
     *
     * @param percentile the percentile to check
     *
     * @throws IllegalArgumentException if the percentile is lower than 0 or
     * higher than 1
     */
    private static void checkPercentile(double percentile) {
        if (percentile < 0 || percentile > 1) {
            throw new IllegalArgumentException("Incorrect input for percentile: " + percentile + ". Input must be between 0 and 1.");
        }
    }

    /**
     * Returns the desired percentile in a given array of double. If the
     * percentile is between two values a linear interpolation is done. The
     * input array is left untouched, the values are sorted in a copy.
     *
     * @param input the input array
     * @param percentile the desired percentile. 0.01 returns the first
     * percentile. 0.5 returns the median.
     *
     * @return the desired percentile
     *
     * @throws IllegalArgumentException if the percentile is not between 0 and
     * 1 or if the input is empty
     */
    public static double percentile(double[] input, double percentile) {
        checkPercentile(percentile);
        int length = input.length;
        if (length == 0) {
            throw new IllegalArgumentException("Attempting to estimate the percentile of an empty list.");
        }
        if (length == 1) {
            return input[0];
        }
        double[] sorted = input.clone();
        Arrays.sort(sorted);
        // Fractional position of the percentile in the sorted values.
        double position = percentile * (length - 1);
        int index = (int) position;
        double lowerValue = sorted[index];
        double fraction = position - index;
        if (index == length - 1 || fraction == 0) {
            return lowerValue;
        }
        return lowerValue + fraction * (sorted[index + 1] - lowerValue);
    }

    /**
     * Returns the desired percentile in a given list of double. If the
     * percentile is between two values a linear interpolation is done. Note:
     * When calculating multiple percentiles on the same list, it is advised to
     * sort it and use percentileSorted. The input list is left untouched, the
     * values are sorted in a copy.
     *
     * @param input the input list
     * @param percentile the desired percentile. 0.01 returns the first
     * percentile. 0.5 returns the median.
     *
     * @return the desired percentile
     *
     * @throws IllegalArgumentException if the input is null or empty, or if
     * the percentile is not between 0 and 1
     */
    public static double percentile(ArrayList<Double> input, double percentile) {
        if (input == null) {
            throw new IllegalArgumentException("Attempting to estimate the percentile of a null object.");
        }
        if (input.isEmpty()) {
            throw new IllegalArgumentException("Attempting to estimate the percentile of an empty list.");
        }
        ArrayList<Double> sortedInput = new ArrayList<Double>(input);
        Collections.sort(sortedInput);
        return percentileSorted(sortedInput, percentile);
    }

    /**
     * Returns the desired percentile in a given list of double. If the
     * percentile is between two values a linear interpolation is done. The list
     * must be sorted prior to submission.
     *
     * @param input the input list, sorted in ascending order
     * @param percentile the desired percentile. 0.01 returns the first
     * percentile. 0.5 returns the median.
     *
     * @return the desired percentile
     *
     * @throws IllegalArgumentException if the percentile is not between 0 and
     * 1, or if the input is null or empty
     */
    public static double percentileSorted(ArrayList<Double> input, double percentile) {
        checkPercentile(percentile);
        if (input == null) {
            throw new IllegalArgumentException("Attempting to estimate the percentile of a null object.");
        }
        int length = input.size();
        if (length == 0) {
            throw new IllegalArgumentException("Attempting to estimate the percentile of an empty list.");
        }
        if (length == 1) {
            return input.get(0);
        }
        // Fractional position of the percentile in the sorted values.
        double position = percentile * (length - 1);
        int index = (int) position;
        double lowerValue = input.get(index);
        double fraction = position - index;
        if (index == length - 1 || fraction == 0) {
            return lowerValue;
        }
        return lowerValue + fraction * (input.get(index + 1) - lowerValue);
    }

    /**
     * Method estimating the median absolute deviation. The input array is left
     * untouched.
     *
     * @param ratios array of doubles
     *
     * @return the mad of the input
     *
     * @throws IllegalArgumentException if the input is null or empty
     */
    public static double mad(double[] ratios) {
        double med = median(ratios);
        double[] deviations = new double[ratios.length];
        for (int i = 0; i < ratios.length; i++) {
            deviations[i] = Math.abs(ratios[i] - med);
        }
        // The deviations array is local, it can be sorted in place.
        Arrays.sort(deviations);
        return medianOfSorted(deviations);
    }

    /**
     * Method estimating the median absolute deviation.
     *
     * @param ratios list of doubles
     *
     * @return the mad of the input
     *
     * @throws IllegalArgumentException if the input is null or empty
     */
    public static double mad(ArrayList<Double> ratios) {
        double med = median(ratios);
        double[] deviations = new double[ratios.size()];
        for (int i = 0; i < deviations.length; i++) {
            deviations[i] = Math.abs(ratios.get(i) - med);
        }
        // The deviations array is local, it can be sorted in place.
        Arrays.sort(deviations);
        return medianOfSorted(deviations);
    }

    /**
     * Returns the log of the input in the desired base.
     *
     * @param input the input
     * @param base the log base
     *
     * @return the log value of the input in the desired base.
     *
     * @throws IllegalArgumentException if the base is lower than or equal to 0
     */
    public static double log(double input, double base) {
        if (base <= 0) {
            throw new IllegalArgumentException("Attempting to compute logarithm of base " + base + ".");
        }
        // No shared cache for the logarithm of the base: a cache made of two
        // separately written static fields can pair a base with the log of
        // another base when several threads use different bases.
        return Math.log(input) / Math.log(base);
    }

    /**
     * Convenience method returning the standard deviation of a list of doubles,
     * normalized by (size - 1). Returns 0 if the list is null or of size
     * &lt; 2.
     *
     * @param input input list
     *
     * @return the corresponding standard deviation
     */
    public static double std(ArrayList<Double> input) {
        if (input == null || input.size() < 2) {
            return 0;
        }
        double mean = mean(input);
        double squaredDeviations = 0;
        for (Double x : input) {
            double deviation = x - mean;
            squaredDeviations += deviation * deviation;
        }
        return Math.sqrt(squaredDeviations / (input.size() - 1));
    }

    /**
     * Convenience method returning the mean of a list of doubles. The result
     * is NaN for an empty list.
     *
     * @param input input list
     *
     * @return the corresponding mean
     */
    public static double mean(ArrayList<Double> input) {
        return sum(input) / input.size();
    }

    /**
     * Convenience method returning the sum of a list of doubles.
     *
     * @param input input list
     *
     * @return the corresponding sum, 0 for an empty list
     */
    public static double sum(ArrayList<Double> input) {
        double result = 0;
        for (Double x : input) {
            result += x;
        }
        return result;
    }

    /**
     * Verifies that two series can be compared in a correlation analysis.
     *
     * @param series1 first series to compare
     * @param series2 second series to compare
     *
     * @throws IllegalArgumentException if the series are of different sizes or
     * contain less than two values
     */
    private static void checkSeriesForCorrelation(ArrayList<Double> series1, ArrayList<Double> series2) {
        if (series1.size() != series2.size()) {
            throw new IllegalArgumentException("Series must be of same size for correlation analysis (series 1: " + series1.size() + " elements, series 2: " + series2.size() + " elements).");
        }
        int n = series1.size();
        if (n <= 1) {
            throw new IllegalArgumentException("At least two values are required for the estimation of correlation factors (" + n + " elements).");
        }
    }

    /**
     * Sums the products of the deviations of two series to their centers and
     * scales the sum by the product of the spreads and by (n - 1).
     *
     * @param series1 first series, same size as the second one
     * @param series2 second series
     * @param center1 the center (e.g. mean or median) of the first series
     * @param center2 the center of the second series
     * @param spread1 the spread (e.g. standard deviation) of the first series
     * @param spread2 the spread of the second series
     *
     * @return the scaled sum, 1 if both spreads are 0
     */
    private static double scaledCrossProduct(ArrayList<Double> series1, ArrayList<Double> series2,
            double center1, double center2, double spread1, double spread2) {
        if (spread1 == 0 && spread2 == 0) {
            return 1;
        }
        // When only one spread is null the other one is used for both series
        // in order to avoid dividing by 0.
        if (spread1 == 0) {
            spread1 = spread2;
        } else if (spread2 == 0) {
            spread2 = spread1;
        }
        int n = series1.size();
        double corr = 0;
        for (int i = 0; i < n; i++) {
            corr += (series1.get(i) - center1) * (series2.get(i) - center2);
        }
        corr = corr / (spread1 * spread2);
        corr = corr / (n - 1);
        return corr;
    }

    /**
     * Returns the Pearson correlation r between series1 and series2, estimated
     * from the means and the (n - 1)-normalized standard deviations of the
     * series. Returns 1 if both standard deviations are 0.
     *
     * @param series1 first series to compare
     * @param series2 second series to compare
     *
     * @return the Pearson correlation factor
     *
     * @throws IllegalArgumentException if the series are of different sizes or
     * contain less than two values
     */
    public static double getCorrelation(ArrayList<Double> series1, ArrayList<Double> series2) {
        checkSeriesForCorrelation(series1, series2);
        return scaledCrossProduct(series1, series2,
                mean(series1), mean(series2),
                std(series1), std(series2));
    }

    /**
     * Returns a robust version of the Pearson correlation r between series1
     * and series2. Here the correlation factor is estimated using median and
     * percentile distance (half of the distance between the 15.9th and the
     * 84.1th percentiles) instead of mean and standard deviation. Returns 1
     * if both percentile distances are 0.
     *
     * @param series1 the first series to inspect
     * @param series2 the second series to inspect
     *
     * @return a robust version of the Pearson correlation factor
     *
     * @throws IllegalArgumentException if the series are of different sizes or
     * contain less than two values
     */
    public static double getRobustCorrelation(ArrayList<Double> series1, ArrayList<Double> series2) {
        checkSeriesForCorrelation(series1, series2);
        // Every series is copied and sorted only once, all percentiles are
        // then read from the sorted copy. The inputs are left untouched.
        ArrayList<Double> sorted1 = new ArrayList<Double>(series1);
        Collections.sort(sorted1);
        ArrayList<Double> sorted2 = new ArrayList<Double>(series2);
        Collections.sort(sorted2);
        double spread1 = (percentileSorted(sorted1, 0.841) - percentileSorted(sorted1, 0.159)) / 2;
        double spread2 = (percentileSorted(sorted2, 0.841) - percentileSorted(sorted2, 0.159)) / 2;
        double median1 = percentileSorted(sorted1, 0.5);
        double median2 = percentileSorted(sorted2, 0.5);
        return scaledCrossProduct(series1, series2, median1, median2, spread1, spread2);
    }

    /**
     * Checks that a probability is between 0 and 1 and throws an
     * IllegalArgumentException otherwise. 0.5 represents a probability of 50%.
     *
     * @param p the probability
     *
     * @throws IllegalArgumentException if p is lower than 0 or higher than 1
     */
    public static void checkProbabilityRange(double p) {
        if (p < 0.0) {
            throw new IllegalArgumentException("Probability <0%.");
        }
        if (p > 1.0) {
            throw new IllegalArgumentException("Probability >100%.");
        }
    }

    /**
     * Checks that a probability is between 0 % and 100 % and throws an
     * IllegalArgumentException otherwise. 50 represents a probability of 50%.
     *
     * @param p the probability
     *
     * @throws IllegalArgumentException if p is lower than 0 or higher than 100
     */
    public static void checkProbabilityRangeInPercent(double p) {
        if (p < 0.0) {
            throw new IllegalArgumentException("Probability <0%.");
        }
        if (p > 100.0) {
            throw new IllegalArgumentException("Probability >100%.");
        }
    }

    /**
     * Returns an integer randomly chosen between min and max included. Every
     * integer of the interval, limits included, is equally likely. If min is
     * not lower than max, max is returned.
     *
     * @param min the lower limit
     * @param max the higher limit
     *
     * @return a random integer
     */
    public static int getRandomInteger(int min, int max) {
        if (min >= max) {
            return max;
        }
        // Number of candidate values, as a long since max - min can exceed
        // the capacity of an int.
        long span = (long) max - min + 1;
        // Truncating a uniform double of [0, span) gives every offset the
        // same probability, whereas rounding to the closest integer would
        // make both limits half as likely as the other values.
        long offset = (long) (Math.random() * span);
        if (offset >= span) {
            // Defensive guard against floating point rounding.
            offset = span - 1;
        }
        return (int) (min + offset);
    }

    /**
     * Returns a list of n random indexes between min and max included. The list
     * is not sorted and can contain duplicates.
     *
     * @param n the number of indexes to return
     * @param min the lower limit
     * @param max the higher limit
     *
     * @return a list of n random indexes between min and max included
     */
    public static ArrayList<Integer> getRandomIndexes(int n, int min, int max) {
        ArrayList<Integer> result = new ArrayList<Integer>(n);
        for (int i = 0; i < n; i++) {
            result.add(getRandomInteger(min, max));
        }
        return result;
    }
}