/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.matrix.stats;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
 * Test helper that computes descriptive statistics for a pair of numeric fields
 * using a straightforward multi-pass algorithm (one pass for the means, a second
 * pass for the central moments) and compares them against a
 * {@link MatrixStatsResults} instance.
 *
 * <p>For each field it stores the mean, the sample variance (normalized by
 * {@code count - 1}), skewness and kurtosis; for the pair it stores the sample
 * covariance and the correlation.
 */
class MultiPassStats {

    private final String fieldAKey;
    private final String fieldBKey;

    /** Number of observations seen by the most recent {@link #computeStats} call. */
    private long count;
    private final Map<String, Double> means = new HashMap<>();
    private final Map<String, Double> variances = new HashMap<>();
    private final Map<String, Double> skewness = new HashMap<>();
    private final Map<String, Double> kurtosis = new HashMap<>();
    private final Map<String, Map<String, Double>> covariances = new HashMap<>();
    private final Map<String, Map<String, Double>> correlations = new HashMap<>();

    /**
     * @param fieldAName key under which statistics of the first field are stored and looked up
     * @param fieldBName key under which statistics of the second field are stored and looked up
     */
    MultiPassStats(String fieldAName, String fieldBName) {
        this.fieldAKey = fieldAName;
        this.fieldBKey = fieldBName;
    }

    /**
     * Computes all statistics for the two value lists and stores them in this instance,
     * replacing the results of any earlier call.
     *
     * <p>Variance and covariance divide by {@code count - 1}, so fewer than two
     * observations produce non-finite (NaN or infinite) results rather than an exception.
     *
     * @param fieldA values of the first field
     * @param fieldB values of the second field; must have the same size as {@code fieldA}
     * @throws IllegalArgumentException if the two lists differ in size
     */
    void computeStats(final List<Double> fieldA, final List<Double> fieldB) {
        if (fieldA.size() != fieldB.size()) {
            throw new IllegalArgumentException("fieldA and fieldB must have the same number of values but got ["
                + fieldA.size() + "] and [" + fieldB.size() + "]");
        }
        count = fieldA.size();

        // first pass: means
        double sumA = 0d;
        double sumB = 0d;
        for (int n = 0; n < count; ++n) {
            sumA += fieldA.get(n);
            sumB += fieldB.get(n);
        }
        final double meanA = sumA / count;
        final double meanB = sumB / count;
        means.put(fieldAKey, meanA);
        means.put(fieldBKey, meanB);

        // second pass: sums of powers of the deviations from the mean, plus the cross product
        double sqA = 0d;
        double sqB = 0d;
        double cubeA = 0d;
        double cubeB = 0d;
        double quadA = 0d;
        double quadB = 0d;
        double cross = 0d;
        for (int n = 0; n < count; ++n) {
            final double dA = fieldA.get(n) - meanA;
            sqA += dA * dA;
            cubeA += dA * dA * dA;
            quadA += dA * dA * dA * dA;
            final double dB = fieldB.get(n) - meanB;
            sqB += dB * dB;
            cubeB += dB * dB * dB;
            quadB += dB * dB * dB * dB;
            cross += dA * dB;
        }

        final double varianceA = sqA / (count - 1);
        final double varianceB = sqB / (count - 1);
        final double stdA = Math.sqrt(varianceA);
        final double stdB = Math.sqrt(varianceB);
        variances.put(fieldAKey, varianceA);
        variances.put(fieldBKey, varianceB);

        // third and fourth moments are normalized by (count - 1) as well; the evaluation
        // order of the denominators is kept stable so results do not drift by round-off
        skewness.put(fieldAKey, cubeA / ((count - 1) * varianceA * stdA));
        skewness.put(fieldBKey, cubeB / ((count - 1) * varianceB * stdB));
        kurtosis.put(fieldAKey, quadA / ((count - 1) * varianceA * varianceA));
        kurtosis.put(fieldBKey, quadB / ((count - 1) * varianceB * varianceB));

        // covariance matrix: the diagonal holds each field's covariance with itself, i.e. its variance
        final double covariance = cross / (count - 1);
        covariances.put(fieldAKey, pairRow(varianceA, covariance));
        covariances.put(fieldBKey, pairRow(covariance, varianceB));

        // correlation matrix: every field is perfectly correlated with itself
        final double correlation = covariance / (stdA * stdB);
        correlations.put(fieldAKey, pairRow(1d, correlation));
        correlations.put(fieldBKey, pairRow(correlation, 1d));
    }

    /** Builds one row of a 2x2 matrix keyed by the two field names. */
    private Map<String, Double> pairRow(double valueForA, double valueForB) {
        final Map<String, Double> row = new HashMap<>(2);
        row.put(fieldAKey, valueForA);
        row.put(fieldBKey, valueForB);
        return row;
    }

    /**
     * Asserts that {@code stats} reports the same counts as, and statistics within a
     * relative tolerance of, the values computed by {@link #computeStats}.
     *
     * @param stats the results to verify
     */
    void assertNearlyEqual(MatrixStatsResults stats) {
        assertEquals("doc count", count, stats.getDocCount());
        assertEquals("field count of [" + fieldAKey + "]", count, stats.getFieldCount(fieldAKey));
        assertEquals("field count of [" + fieldBKey + "]", count, stats.getFieldCount(fieldBKey));
        // means
        assertClose("mean of [" + fieldAKey + "]", means.get(fieldAKey), stats.getMean(fieldAKey), 1e-7);
        assertClose("mean of [" + fieldBKey + "]", means.get(fieldBKey), stats.getMean(fieldBKey), 1e-7);
        // variances
        assertClose("variance of [" + fieldAKey + "]", variances.get(fieldAKey), stats.getVariance(fieldAKey), 1e-7);
        assertClose("variance of [" + fieldBKey + "]", variances.get(fieldBKey), stats.getVariance(fieldBKey), 1e-7);
        // skewness (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
        assertClose("skewness of [" + fieldAKey + "]", skewness.get(fieldAKey), stats.getSkewness(fieldAKey), 1e-4);
        assertClose("skewness of [" + fieldBKey + "]", skewness.get(fieldBKey), stats.getSkewness(fieldBKey), 1e-4);
        // kurtosis (multi-pass is more susceptible to round-off error so we need to slightly relax the tolerance)
        assertClose("kurtosis of [" + fieldAKey + "]", kurtosis.get(fieldAKey), stats.getKurtosis(fieldAKey), 1e-4);
        assertClose("kurtosis of [" + fieldBKey + "]", kurtosis.get(fieldBKey), stats.getKurtosis(fieldBKey), 1e-4);
        // covariances
        assertClose("covariance of [" + fieldAKey + "] and [" + fieldBKey + "]",
            covariances.get(fieldAKey).get(fieldBKey), stats.getCovariance(fieldAKey, fieldBKey), 1e-7);
        assertClose("covariance of [" + fieldBKey + "] and [" + fieldAKey + "]",
            covariances.get(fieldBKey).get(fieldAKey), stats.getCovariance(fieldBKey, fieldAKey), 1e-7);
        // correlation
        assertClose("correlation of [" + fieldAKey + "] and [" + fieldBKey + "]",
            correlations.get(fieldAKey).get(fieldBKey), stats.getCorrelation(fieldAKey, fieldBKey), 1e-7);
        assertClose("correlation of [" + fieldBKey + "] and [" + fieldAKey + "]",
            correlations.get(fieldBKey).get(fieldAKey), stats.getCorrelation(fieldBKey, fieldAKey), 1e-7);
    }

    /**
     * Fails with a message naming the statistic and both values when {@code expected}
     * and {@code actual} are not {@link #nearlyEqual nearly equal}.
     */
    private static void assertClose(String what, double expected, double actual, double epsilon) {
        assertTrue(what + ": expected <" + expected + "> but was <" + actual + "> (tolerance " + epsilon + ")",
            nearlyEqual(expected, actual, epsilon));
    }

    /**
     * Compares two doubles using a relative error bound, falling back to an absolute
     * bound scaled by {@link Double#MIN_NORMAL} when either value is zero or the
     * difference is below {@link Double#MIN_NORMAL}.
     *
     * @param a       first value
     * @param b       second value
     * @param epsilon tolerance
     * @return {@code true} if the values are equal or within the tolerance
     */
    private static boolean nearlyEqual(double a, double b, double epsilon) {
        final double absA = Math.abs(a);
        final double absB = Math.abs(b);
        final double diff = Math.abs(a - b);

        if (a == b) { // shortcut, handles infinities
            return true;
        } else if (a == 0 || b == 0 || diff < Double.MIN_NORMAL) {
            // a or b is zero or both are extremely close to it
            // relative error is less meaningful here
            return diff < (epsilon * Double.MIN_NORMAL);
        } else { // use relative error
            return diff / Math.min((absA + absB), Double.MAX_VALUE) < epsilon;
        }
    }
}