package org.geogebra.common.gui.view.data;
import java.util.ArrayList;
import org.apache.commons.math3.distribution.TDistribution;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.commons.math3.stat.inference.TTest;
import org.geogebra.common.kernel.arithmetic.ExpressionNodeConstants;
import org.geogebra.common.kernel.geos.GeoList;
import org.geogebra.common.main.App;
import org.geogebra.common.main.Localization;
public class TwoVarInferenceModel {
public interface TwoVarInferenceListener {
void setStatTable(int row, String[] rowNames, int length,
String[] columnNames);
void setFormattedValueAt(double value, int row, int col);
GeoList getDataSelected();
int getSelectedDataIndex(int idx);
double[] getValueArray(GeoList list);
void addAltHypItem(String name, String tail, double value);
void selectAltHyp(int idx);
}
public interface UpdatePanel {
void updatePanel();
}
private int selectedInference = StatisticsModel.INFER_TINT_2MEANS;
// test type (tail)
private static final String tail_left = "<";
private static final String tail_right = ">";
private static final String tail_two = ExpressionNodeConstants.strNOT_EQUAL;
private String tail = tail_two;
// input fields
private double confLevel = .95, hypMean = 0;
// statistics
double t, P, df, lower, upper, se, me, n1, n2, diffMeans, mean1,
mean2;
private TTest tTestImpl;
private TDistribution tDist;
private boolean pooled = false;
private double meanDifference;
private TwoVarInferenceListener listener;
private Localization loc;
/**
* Construct a TwoVarInference panel
*/
public TwoVarInferenceModel(App app, TwoVarInferenceListener listener) {
this.loc = app.getLocalization();
this.listener = listener;
}
public boolean isPairedData() {
return (selectedInference == StatisticsModel.INFER_TINT_PAIRED
|| selectedInference == StatisticsModel.INFER_TTEST_PAIRED);
}
public String getNullHypName() {
if (selectedInference == StatisticsModel.INFER_TTEST_2MEANS) {
return loc.getMenu("DifferenceOfMeans.short");
} else if (selectedInference == StatisticsModel.INFER_TTEST_PAIRED) {
return loc.getMenu("MeanDifference");
} else {
return "";
}
}
public boolean isTest() {
return (selectedInference == StatisticsModel.INFER_TTEST_2MEANS
|| selectedInference == StatisticsModel.INFER_TTEST_PAIRED);
}
public void setResults() {
ArrayList<String> list = new ArrayList<String>();
switch (selectedInference) {
default:
// do nothing
break;
case StatisticsModel.INFER_TTEST_2MEANS:
case StatisticsModel.INFER_TTEST_PAIRED:
if (selectedInference == StatisticsModel.INFER_TTEST_PAIRED) {
list.add(loc.getMenu("MeanDifference"));
} else {
list.add(loc.getMenu("fncInspector.Difference"));
}
list.add(loc.getMenu("PValue"));
list.add(loc.getMenu("TStatistic"));
list.add(loc.getMenu("StandardError.short"));
list.add(loc.getMenu("DegreesOfFreedom.short"));
break;
case StatisticsModel.INFER_TINT_2MEANS:
case StatisticsModel.INFER_TINT_PAIRED:
if (selectedInference == StatisticsModel.INFER_TINT_PAIRED) {
list.add(loc.getMenu("MeanDifference"));
} else {
list.add(loc.getMenu("fncInspector.Difference"));
}
list.add(loc.getMenu("MarginOfError.short"));
list.add(loc.getMenu("LowerLimit"));
list.add(loc.getMenu("UpperLimit"));
list.add(loc.getMenu("StandardError.short"));
list.add(loc.getMenu("DegreesOfFreedom.short"));
break;
}
String[] columnNames = new String[list.size()];
list.toArray(columnNames);
listener.setStatTable(1, null, columnNames.length, columnNames);
}
public void updateResults() {
boolean ok = evaluate();
if (!ok) {
return;
}
switch (selectedInference) {
default:
// do nothing
break;
case StatisticsModel.INFER_TTEST_2MEANS:
case StatisticsModel.INFER_TTEST_PAIRED:
if (selectedInference == StatisticsModel.INFER_TTEST_PAIRED) {
listener.setFormattedValueAt(meanDifference, 0, 0);
} else {
listener.setFormattedValueAt(diffMeans, 0, 0);
}
listener.setFormattedValueAt(P, 0, 1);
listener.setFormattedValueAt(t, 0, 2);
listener.setFormattedValueAt(se, 0, 3);
listener.setFormattedValueAt(df, 0, 4);
break;
case StatisticsModel.INFER_TINT_2MEANS:
case StatisticsModel.INFER_TINT_PAIRED:
if (selectedInference == StatisticsModel.INFER_TINT_PAIRED) {
listener.setFormattedValueAt(meanDifference, 0, 0);
} else {
listener.setFormattedValueAt(diffMeans, 0, 0);
}
listener.setFormattedValueAt(me, 0, 1);
listener.setFormattedValueAt(lower, 0, 2);
listener.setFormattedValueAt(upper, 0, 3);
listener.setFormattedValueAt(se, 0, 4);
listener.setFormattedValueAt(df, 0, 5);
break;
}
}
// ============================================================
// Evaluate
// ============================================================
public boolean evaluate() {
// get the sample data
GeoList dataCollection = listener.getDataSelected();
GeoList dataList1 = (GeoList) dataCollection
.get(listener.getSelectedDataIndex(0));
double[] sample1 = listener.getValueArray(dataList1);
SummaryStatistics stats1 = new SummaryStatistics();
for (int i = 0; i < sample1.length; i++) {
stats1.addValue(sample1[i]);
}
GeoList dataList2 = (GeoList) dataCollection
.get(listener.getSelectedDataIndex(1));
double[] sample2 = listener.getValueArray(dataList2);
SummaryStatistics stats2 = new SummaryStatistics();
for (int i = 0; i < sample2.length; i++) {
stats2.addValue(sample2[i]);
}
// exit if paired data is expected and sample sizes are unequal
if (isPairedData() && stats1.getN() != stats2.getN()) {
return false;
}
if (tTestImpl == null) {
tTestImpl = new TTest();
}
double tCritical;
try {
switch (selectedInference) {
default:
// do nothing
break;
case StatisticsModel.INFER_TTEST_2MEANS:
case StatisticsModel.INFER_TINT_2MEANS:
// get statistics
mean1 = StatUtils.mean(sample1);
mean2 = StatUtils.mean(sample2);
diffMeans = mean1 - mean2;
n1 = stats1.getN();
n2 = stats2.getN();
double v1 = stats1.getVariance();
double v2 = stats2.getVariance();
df = getDegreeOfFreedom(v1, v2, n1, n2, isPooled());
if (isPooled()) {
double pooledVariance = ((n1 - 1) * v1 + (n2 - 1) * v2)
/ (n1 + n2 - 2);
se = Math.sqrt(pooledVariance * (1d / n1 + 1d / n2));
} else {
se = Math.sqrt((v1 / n1) + (v2 / n2));
}
// get confidence interval
tDist = new TDistribution(df);
tCritical = tDist.inverseCumulativeProbability(
(getConfLevel() + 1d) / 2);
me = tCritical * se;
upper = diffMeans + me;
lower = diffMeans - me;
// get test results
if (isPooled()) {
t = tTestImpl.homoscedasticT(sample1, sample2);
P = tTestImpl.homoscedasticTTest(sample1, sample2);
} else {
t = tTestImpl.t(sample1, sample2);
P = tTestImpl.tTest(sample1, sample2);
}
P = adjustedPValue(P, t, tail);
break;
case StatisticsModel.INFER_TTEST_PAIRED:
case StatisticsModel.INFER_TINT_PAIRED:
// get statistics
n1 = sample1.length;
meanDifference = StatUtils.meanDifference(sample1, sample2);
se = Math.sqrt(StatUtils.varianceDifference(sample1, sample2,
meanDifference) / n1);
df = n1 - 1;
tDist = new TDistribution(df);
tCritical = tDist.inverseCumulativeProbability(
(getConfLevel() + 1d) / 2);
me = tCritical * se;
upper = meanDifference + me;
lower = meanDifference - me;
// get test results
t = meanDifference / se;
P = 2.0 * tDist.cumulativeProbability(-Math.abs(t));
P = adjustedPValue(P, t, tail);
break;
}
} catch (RuntimeException e) {
// catches ArithmeticException, IllegalStateException and
// ArithmeticException
e.printStackTrace();
return false;
}
return true;
}
// TODO: Validate !!!!!!!!!!!
public double adjustedPValue(double p, double testStatistic, String tail) {
// two sided test
if (tail.equals(tail_two)) {
return p;
} else if ((tail.equals(tail_right) && testStatistic > 0)
|| (tail.equals(tail_left) && testStatistic < 0)) {
return p / 2;
} else {
return 1 - p / 2;
}
}
/**
* Computes approximate degrees of freedom for 2-sample t-estimate. (code
* from Apache commons, TTest class)
*
* @param v1
* first sample variance
* @param v2
* second sample variance
* @param n1
* first sample n
* @param n2
* second sample n
* @return approximate degrees of freedom
*/
public double getDegreeOfFreedom(double v1, double v2, double n1, double n2,
boolean pooled) {
if (pooled) {
return n1 + n2 - 2;
}
return (((v1 / n1) + (v2 / n2)) * ((v1 / n1) + (v2 / n2)))
/ ((v1 * v1) / (n1 * n1 * (n1 - 1d))
+ (v2 * v2) / (n2 * n2 * (n2 - 1d)));
}
/**
* Computes margin of error for 2-sample t-estimate; this is the half-width
* of the confidence interval
*
* @param v1
* first sample variance
* @param v2
* second sample variance
* @param n1
* first sample n
* @param n2
* second sample n
* @param confLevel
* confidence level
* @return margin of error for 2 mean interval estimate
* @throws ArithmeticException
*/
public double getMarginOfError(double v1, double n1, double v2, double n2,
double confLevel, boolean pooled) throws ArithmeticException {
if (pooled) {
double pooledVariance = ((n1 - 1) * v1 + (n2 - 1) * v2)
/ (n1 + n2 - 2);
double se1 = Math.sqrt(pooledVariance * (1d / n1 + 1d / n2));
tDist = new TDistribution(
getDegreeOfFreedom(v1, v2, n1, n2, pooled));
double a = tDist.inverseCumulativeProbability((confLevel + 1d) / 2);
return a * se1;
}
double se = Math.sqrt((v1 / n1) + (v2 / n2));
tDist = new TDistribution(
getDegreeOfFreedom(v1, v2, n1, n2, pooled));
double a = tDist.inverseCumulativeProbability((confLevel + 1d) / 2);
return a * se;
}
public void setSelectedInference(int value) {
selectedInference = value;
}
public double getHypMean() {
return hypMean;
}
public void setHypMean(double hypMean) {
this.hypMean = hypMean;
}
public double getConfLevel() {
return confLevel;
}
public void setConfLevel(double confLevel) {
this.confLevel = confLevel;
}
public boolean isPooled() {
return pooled;
}
public void applyTail(int idx) {
if (idx == 0) {
tail = tail_right;
} else if (idx == 1) {
tail = tail_left;
} else {
tail = tail_two;
}
updateResults();
}
public void setPooled(boolean pooled) {
this.pooled = pooled;
updateResults();
}
public void fillAlternateHyp() {
String nullHypName = getNullHypName();
listener.addAltHypItem(nullHypName, tail_right, hypMean);
listener.addAltHypItem(nullHypName, tail_left, hypMean);
listener.addAltHypItem(nullHypName, tail_two, hypMean);
if (tail == tail_right) {
listener.selectAltHyp(0);
} else if (tail == tail_left) {
listener.selectAltHyp(0);
} else {
listener.selectAltHyp(2);
}
}
}