/* Copyright 2009-2016 David Hadka
*
* This file is part of the MOEA Framework.
*
* The MOEA Framework is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or (at your
* option) any later version.
*
* The MOEA Framework is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
* License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with the MOEA Framework. If not, see <http://www.gnu.org/licenses/>.
*/
package org.moeaframework.analysis.sensitivity;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.junit.Assert;
import org.junit.Test;
import org.moeaframework.TestThresholds;
import org.moeaframework.TestUtils;
import org.moeaframework.util.sequence.Saltelli;
/**
 * Tests the {@link SobolAnalysis} class.  Each test evaluates a simple
 * analytical function of three variables on a Saltelli sample, runs the
 * Sobol analysis command line utility on the resulting model outputs, and
 * checks the sensitivities reported in the output file.
 */
public class SobolAnalysisTest {

    /**
     * The number of input variables used by every test function.
     */
    private static final int NUMBER_OF_VARIABLES = 3;

    /**
     * The number of samples requested from the Saltelli sequence.
     * NOTE(review): the factor 8 presumably equals {@code 2*P + 2} for
     * {@code P = 3} variables; confirm against the Saltelli sampling
     * requirements before changing {@link #NUMBER_OF_VARIABLES}.
     */
    private static final int NUMBER_OF_SAMPLES = 1000*8;

    /**
     * Occurrence index selecting the first matching line for a single
     * variable, i.e., its first-order effect.
     */
    private static final int FIRST_ORDER = 0;

    /**
     * Occurrence index selecting the second matching line for a single
     * variable, i.e., its total-order effect.
     */
    private static final int TOTAL_ORDER = 1;

    /**
     * Occurrence index selecting the only matching line for a pair of
     * variables, i.e., their second-order effect.
     */
    private static final int SECOND_ORDER = 0;

    /**
     * Regular expression matching the name of the first variable.
     */
    private static final String VARIABLE1 = "Variable1";

    /**
     * Regular expression matching the name of the second variable.
     */
    private static final String VARIABLE2 = "Variable2";

    /**
     * Regular expression matching the name of the third variable.
     */
    private static final String VARIABLE3 = "Variable3";

    /**
     * Regular expression matching the interaction between the first and
     * second variables.
     */
    private static final String VARIABLE1_2 = "Variable1 \\* Variable2";

    /**
     * Regular expression matching the interaction between the first and
     * third variables.
     */
    private static final String VARIABLE1_3 = "Variable1 \\* Variable3";

    /**
     * Regular expression matching the interaction between the second and
     * third variables.
     */
    private static final String VARIABLE2_3 = "Variable2 \\* Variable3";

    /**
     * Tests {@code f(x) = x1}: the first variable is expected to receive a
     * first-order and total-order sensitivity of one, every other entry
     * zero.
     */
    @Test
    public void testNoInteraction1() throws Exception {
        File outputFile = test(new MultivariateFunction() {

            @Override
            public double value(double[] variables) {
                return variables[0];
            }

        });

        assertEntryEquals(outputFile, VARIABLE1, FIRST_ORDER, 1.0);
        assertEntryEquals(outputFile, VARIABLE2, FIRST_ORDER, 0.0);
        assertEntryEquals(outputFile, VARIABLE3, FIRST_ORDER, 0.0);

        assertEntryEquals(outputFile, VARIABLE1, TOTAL_ORDER, 1.0);
        assertEntryEquals(outputFile, VARIABLE2, TOTAL_ORDER, 0.0);
        assertEntryEquals(outputFile, VARIABLE3, TOTAL_ORDER, 0.0);

        assertNoSecondOrderEffects(outputFile);
    }

    /**
     * Tests {@code f(x) = x1 + x2}: the first two variables are expected to
     * split the sensitivity evenly with no second-order effects.
     */
    @Test
    public void testNoInteraction12() throws Exception {
        File outputFile = test(new MultivariateFunction() {

            @Override
            public double value(double[] variables) {
                return variables[0] + variables[1];
            }

        });

        assertEntryEquals(outputFile, VARIABLE1, FIRST_ORDER, 0.5);
        assertEntryEquals(outputFile, VARIABLE2, FIRST_ORDER, 0.5);
        assertEntryEquals(outputFile, VARIABLE3, FIRST_ORDER, 0.0);

        assertEntryEquals(outputFile, VARIABLE1, TOTAL_ORDER, 0.5);
        assertEntryEquals(outputFile, VARIABLE2, TOTAL_ORDER, 0.5);
        assertEntryEquals(outputFile, VARIABLE3, TOTAL_ORDER, 0.0);

        assertNoSecondOrderEffects(outputFile);
    }

    /**
     * Tests {@code f(x) = x1 + x2 + x3}: all three variables are expected to
     * share the sensitivity evenly with no second-order effects.
     */
    @Test
    public void testNoInteraction123() throws Exception {
        File outputFile = test(new MultivariateFunction() {

            @Override
            public double value(double[] variables) {
                return variables[0] + variables[1] + variables[2];
            }

        });

        assertEntryEquals(outputFile, VARIABLE1, FIRST_ORDER, 0.333);
        assertEntryEquals(outputFile, VARIABLE2, FIRST_ORDER, 0.333);
        assertEntryEquals(outputFile, VARIABLE3, FIRST_ORDER, 0.333);

        assertEntryEquals(outputFile, VARIABLE1, TOTAL_ORDER, 0.333);
        assertEntryEquals(outputFile, VARIABLE2, TOTAL_ORDER, 0.333);
        assertEntryEquals(outputFile, VARIABLE3, TOTAL_ORDER, 0.333);

        assertNoSecondOrderEffects(outputFile);
    }

    /**
     * Tests {@code f(x) = x1*x2 + x3}: only the first two variables
     * interact, so their total-order effect must equal their first-order
     * effect plus their shared second-order effect, while the third
     * variable's total-order effect must equal its first-order effect.
     */
    @Test
    public void testInteraction12() throws Exception {
        File outputFile = test(new MultivariateFunction() {

            @Override
            public double value(double[] variables) {
                return variables[0]*variables[1] + variables[2];
            }

        });

        assertEntryNotEquals(outputFile, VARIABLE1, FIRST_ORDER, 0.0);
        assertEntryEquals(outputFile, VARIABLE1, FIRST_ORDER,
                getEntryValue(outputFile, VARIABLE2, FIRST_ORDER));
        assertEntryNotEquals(outputFile, VARIABLE3, FIRST_ORDER, 0.0);

        assertEntryEquals(outputFile, VARIABLE1, TOTAL_ORDER,
                getEntryValue(outputFile, VARIABLE1, FIRST_ORDER) +
                getEntryValue(outputFile, VARIABLE1_2, SECOND_ORDER));
        assertEntryEquals(outputFile, VARIABLE2, TOTAL_ORDER,
                getEntryValue(outputFile, VARIABLE2, FIRST_ORDER) +
                getEntryValue(outputFile, VARIABLE1_2, SECOND_ORDER));
        assertEntryEquals(outputFile, VARIABLE3, TOTAL_ORDER,
                getEntryValue(outputFile, VARIABLE3, FIRST_ORDER));

        assertEntryNotEquals(outputFile, VARIABLE1_2, SECOND_ORDER, 0.0);
        assertEntryEquals(outputFile, VARIABLE1_3, SECOND_ORDER, 0.0);
        assertEntryEquals(outputFile, VARIABLE2_3, SECOND_ORDER, 0.0);
    }

    /**
     * Tests {@code f(x) = x1*x2*x3}: the function is symmetric in its
     * variables, so all first-order effects must agree, all second-order
     * effects must agree, and each total-order effect must equal the sum of
     * the first-order effect and the two second-order effects involving that
     * variable.
     */
    @Test
    public void testInteraction123() throws Exception {
        File outputFile = test(new MultivariateFunction() {

            @Override
            public double value(double[] variables) {
                return variables[0]*variables[1]*variables[2];
            }

        });

        assertEntryNotEquals(outputFile, VARIABLE1, FIRST_ORDER, 0.0);
        assertEntryEquals(outputFile, VARIABLE1, FIRST_ORDER,
                getEntryValue(outputFile, VARIABLE2, FIRST_ORDER));
        assertEntryEquals(outputFile, VARIABLE1, FIRST_ORDER,
                getEntryValue(outputFile, VARIABLE3, FIRST_ORDER));

        assertEntryEquals(outputFile, VARIABLE1, TOTAL_ORDER,
                getEntryValue(outputFile, VARIABLE1, FIRST_ORDER) +
                getEntryValue(outputFile, VARIABLE1_2, SECOND_ORDER) +
                getEntryValue(outputFile, VARIABLE1_3, SECOND_ORDER));
        assertEntryEquals(outputFile, VARIABLE2, TOTAL_ORDER,
                getEntryValue(outputFile, VARIABLE2, FIRST_ORDER) +
                getEntryValue(outputFile, VARIABLE1_2, SECOND_ORDER) +
                getEntryValue(outputFile, VARIABLE2_3, SECOND_ORDER));
        assertEntryEquals(outputFile, VARIABLE3, TOTAL_ORDER,
                getEntryValue(outputFile, VARIABLE3, FIRST_ORDER) +
                getEntryValue(outputFile, VARIABLE1_3, SECOND_ORDER) +
                getEntryValue(outputFile, VARIABLE2_3, SECOND_ORDER));

        assertEntryNotEquals(outputFile, VARIABLE1_2, SECOND_ORDER, 0.0);
        assertEntryEquals(outputFile, VARIABLE1_2, SECOND_ORDER,
                getEntryValue(outputFile, VARIABLE1_3, SECOND_ORDER));
        assertEntryEquals(outputFile, VARIABLE1_2, SECOND_ORDER,
                getEntryValue(outputFile, VARIABLE2_3, SECOND_ORDER));
    }

    /**
     * Asserts that every pairwise second-order effect in the Sobol results
     * is zero.
     *
     * @param file the file containing the output from Sobol analysis
     * @throws IOException if an I/O error occurred
     */
    private void assertNoSecondOrderEffects(File file) throws IOException {
        assertEntryEquals(file, VARIABLE1_2, SECOND_ORDER, 0.0);
        assertEntryEquals(file, VARIABLE1_3, SECOND_ORDER, 0.0);
        assertEntryEquals(file, VARIABLE2_3, SECOND_ORDER, 0.0);
    }

    /**
     * Runs Sobol analysis on the given function.  The function is evaluated
     * on a Saltelli sample, the model outputs and a matching parameter file
     * are written to temporary files, and the {@link SobolAnalysis} command
     * line utility is invoked on them.
     *
     * @param function the function to evaluate
     * @return the file containing the output from Sobol analysis
     * @throws Exception if any step of the analysis failed, including I/O
     *         errors while creating the temporary files
     */
    protected File test(MultivariateFunction function) throws Exception {
        double[][] input = new Saltelli().generate(NUMBER_OF_SAMPLES,
                NUMBER_OF_VARIABLES);
        double[] output = evaluate(function, input);

        File outputFile = TestUtils.createTempFile();
        File parameterFile = TestUtils.createTempFile();
        File inputFile = TestUtils.createTempFile();

        createParameterFile(parameterFile, NUMBER_OF_VARIABLES);
        save(inputFile, output);

        SobolAnalysis.main(new String[] {
                "--parameterFile", parameterFile.getPath(),
                "--input", inputFile.getPath(),
                "--metric", "0",
                "--output", outputFile.getPath()
        });

        return outputFile;
    }

    /**
     * Asserts that an entry in the Sobol results equals an expected value,
     * within {@link TestThresholds#STATISTICS_EPS}.
     *
     * @param file the file containing the output from Sobol analysis
     * @param key the regular expression for identifying the desired entry in
     *        the output file
     * @param skip set to {@code 0} for first-order or second-order effects,
     *        {@code 1} for total-order effects
     * @param expected the expected sensitivity value
     * @throws IOException if an I/O error occurred
     * @throws AssertionError if the entry is missing or differs from the
     *         expected value
     */
    protected void assertEntryEquals(File file, String key, int skip,
            double expected) throws IOException {
        Assert.assertEquals("unexpected value for entry '" + key +
                "' (occurrence " + skip + ")", expected,
                getEntryValue(file, key, skip),
                TestThresholds.STATISTICS_EPS);
    }

    /**
     * Asserts that an entry in the Sobol results is not equal to the given
     * value, i.e., differs from it by more than
     * {@link TestThresholds#STATISTICS_EPS}.
     *
     * @param file the file containing the output from Sobol analysis
     * @param key the regular expression for identifying the desired entry in
     *        the output file
     * @param skip set to {@code 0} for first-order or second-order effects,
     *        {@code 1} for total-order effects
     * @param expected the value the entry must differ from
     * @throws IOException if an I/O error occurred
     * @throws AssertionError if the entry is missing or is within the
     *         threshold of the given value
     */
    protected void assertEntryNotEquals(File file, String key, int skip,
            double expected) throws IOException {
        double actual = getEntryValue(file, key, skip);

        Assert.assertTrue("entry '" + key + "' (occurrence " + skip +
                ") should differ from " + expected + " but was " + actual,
                Math.abs(expected - actual) > TestThresholds.STATISTICS_EPS);
    }

    /**
     * Parses the specified Sobol output file, returning the sensitivity for
     * the given key.  A line is considered a match when it consists of the
     * key followed by a numeric value and a second numeric value enclosed in
     * square brackets; the value of the first capture group is returned.
     * <p>
     * A missing entry fails the test instead of yielding a sentinel value.
     * A sentinel such as positive infinity would let "not equal" checks and
     * comparisons between two missing entries pass without testing
     * anything.
     *
     * @param file the file containing the output from Sobol analysis
     * @param key the regular expression for identifying the desired entry in
     *        the output file
     * @param skip the number of matching lines to skip before returning a
     *        value; set to {@code 0} for first-order or second-order
     *        effects, {@code 1} for total-order effects
     * @return the sensitivity for the given key
     * @throws IOException if an I/O error occurred
     * @throws AssertionError if the file does not contain the requested
     *         entry
     */
    protected double getEntryValue(File file, String key, int skip) throws
    IOException {
        Pattern pattern = Pattern.compile("^\\s*" + key + "\\s+" +
                TestUtils.getSpaceSeparatedNumericPattern(1) + "\\s+\\[" +
                TestUtils.getSpaceSeparatedNumericPattern(1) + "\\]\\s*$");

        // count down on a local copy so the parameter still holds the
        // requested occurrence when reporting a missing entry
        int remaining = skip;
        BufferedReader reader = null;

        try {
            reader = new BufferedReader(new FileReader(file));
            String line = null;

            while ((line = reader.readLine()) != null) {
                Matcher matcher = pattern.matcher(line);

                if (!matcher.matches()) {
                    continue;
                }

                if (remaining == 0) {
                    return Double.parseDouble(matcher.group(1));
                }

                remaining--;
            }
        } finally {
            if (reader != null) {
                reader.close();
            }
        }

        throw new AssertionError("no entry matching '" + key +
                "' (occurrence " + skip + ") found in " + file);
    }

    /**
     * Evaluates the given function on each input entry, returning the array
     * of outputs.
     *
     * @param function the function
     * @param input the array of inputs, where the outer index references
     *        each set of input to the function
     * @return the array of outputs, in the same order as the inputs
     */
    protected double[] evaluate(MultivariateFunction function,
            double[][] input) {
        double[] output = new double[input.length];

        for (int i = 0; i < input.length; i++) {
            output[i] = function.value(input[i]);
        }

        return output;
    }

    /**
     * Saves the test data to the input file compatible with the input to the
     * Sobol analysis utility.  Each model output is written on its own line.
     *
     * @param file the file to create
     * @param data the array of model outputs
     * @throws IOException if an I/O error occurred
     */
    protected void save(File file, double[] data) throws IOException {
        PrintWriter writer = null;

        try {
            writer = new PrintWriter(file);

            for (double value : data) {
                writer.println(value);
            }
        } finally {
            if (writer != null) {
                writer.close();
            }
        }
    }

    /**
     * Creates the parameter file with entries of the form
     * <pre>
     *   VariableN 0.0 1.0
     * </pre>
     * where {@code N} ranges from {@code 1} to {@code dimension}.
     *
     * @param file the parameter file to create
     * @param dimension the number of variables
     * @throws IOException if an I/O error occurred
     */
    protected void createParameterFile(File file, int dimension) throws
    IOException {
        PrintWriter writer = null;

        try {
            writer = new PrintWriter(file);

            for (int i = 1; i <= dimension; i++) {
                writer.println("Variable" + i + ' ' + 0.0 + ' ' + 1.0);
            }
        } finally {
            if (writer != null) {
                writer.close();
            }
        }
    }

}