/* * Copyright (c) 2013, SRI International * All rights reserved. * Licensed under the The BSD 3-Clause License; * you may not use this file except in compliance with the License. * You may obtain a copy of the License at: * * http://opensource.org/licenses/BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the aic-util nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED * OF THE POSSIBILITY OF SUCH DAMAGE. 
*/
package com.sri.ai.util.experiment;

import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import com.google.common.annotations.Beta;
import com.sri.ai.util.Util;
import com.sri.ai.util.gnuplot.DataSeries;
import com.sri.ai.util.gnuplot.Gnuplot;
import com.sri.ai.util.rangeoperation.api.DAEFunction;
import com.sri.ai.util.rangeoperation.api.DependencyAwareEnvironment;
import com.sri.ai.util.rangeoperation.api.Range;
import com.sri.ai.util.rangeoperation.api.RangeOperation;
import com.sri.ai.util.rangeoperation.core.AbstractDAEFunction;
import com.sri.ai.util.rangeoperation.core.RangeOperationsInterpreter;
import com.sri.ai.util.rangeoperation.library.rangeoperations.Dimension;

/**
 * Runs experiments described by range operations and a {@link DAEFunction}
 * (evaluated through {@link RangeOperationsInterpreter}) and plots the results with {@link Gnuplot}.
 */
@Beta
public class Experiment {

	/**
	 * Gnuplot commands issued on every plot, before any experiment-specific commands.
	 * Kept public and non-final for backward compatibility with callers that replace it.
	 */
	public static String[] guaranteedPreCommands = {
		"set xlabel font 'Arial, 10'",
		"set ylabel font 'Arial, 10'",
		"set title font 'Arial, 20'"
	};

	/**
	 * Runs an experiment and generates a gnuplot with its results.
	 * <p>
	 * The method works by generating a data matrix (according to {@link RangeOperationsInterpreter})
	 * with one or two {@link Dimension}s.
	 * The plot may contain one or more data series (shown as lines on the plot).
	 * A {@link DataSeriesSpec} object defines which {@link Dimension} is used to define multiple data series,
	 * as well as data series titles and format.
	 * <p>
	 * There is a commented example at the end of this documentation block.
	 * <p>
	 * Arguments to this method can be of four possible types:
	 * <ul>
	 * <li> String values representing variable names
	 * followed by another argument representing the variable's value.
	 * These variables can either be gnuplot parameters (see reference below),
	 * or fixed parameters stored in a {@link DependencyAwareEnvironment} to be used
	 * by the experiment (via {@link DAEFunction}s -- see below).
	 * <li> {@link RangeOperation}s, which work like for-loops (creating data matrix {@link Dimension}s)
	 * or aggregate operations for a specified variable.
	 * <li> a single occurrence of a {@link DAEFunction}; this is the main argument since
	 * this function is responsible for computing the experiment's reported results.
	 * It can run whatever code one wishes, including calls to
	 * {@link DependencyAwareEnvironment#get(String)},
	 * {@link DependencyAwareEnvironment#getOrUseDefault(String, Object)} and
	 * {@link DependencyAwareEnvironment#getResultOrRecompute(DAEFunction)}
	 * to gain access to the variables defined by other arguments
	 * (both fixed and iterated by range operations).
	 * The interaction between range operations and a DAEFunction
	 * produces a generalized (multidimensional) matrix (see {@link RangeOperationsInterpreter} for details).
	 * In {@link Experiment}, the range operations and DAEFunction must be structured
	 * so that the resulting matrix is a regular rows-and-columns matrix
	 * in which each row is the data for an individual data series to appear in the plot.
	 * <li> a single {@link DataSeriesSpec}, which declares gnuplot directives for each of
	 * the data series (the several graph lines) in the graph ({@link DataSeries}).
	 * We obtain one data series per value of a variable specified by
	 * the {@link DataSeriesSpec}.
	 * If no {@link DataSeriesSpec} is given, the data must be one-dimensional and is plotted
	 * as a single data series without directives.
	 * </ul>
	 * Arguments can be provided in any order, although the order of range operations does matter
	 * (their iterations are nested in the order they are given).
	 * <p>
	 * The variable on the graph's x-axis is the one specified by the first {@link Dimension} argument
	 * that is <i>not</i> the variable labeling the multiple data series.
	 * <p>
	 * The recognized gnuplot parameters are the following:
	 * <ul>
	 * <li> title: the graph's title
	 * <li> xlabel: the label of the x axis (default is variable name in x-{@link Dimension}, or "x" if there is none).
	 * <li> ylabel: the label of the y axis (not set unless given).
	 * <li> filename: the name of a file (without an extension) to which to record the graph;
	 * the extension ".ps" is automatically added.
	 * <li> file: boolean value (the string "true" or {@link Boolean#TRUE}) indicating whether to record the graph,
	 * using "filename" if given, otherwise the title (or "unnamed" if neither is given), as the file name.
	 * <li> print: same as file
	 * </ul>
	 * Single quotes inside title, labels and file name are escaped for gnuplot.
	 * If file recording is disabled, gnuplot persists (keeps open) and shows the graph;
	 * otherwise, it closes as soon as the graph is recorded.
	 * <p>
	 * Consider the example:
	 * <pre>
	 * experiment(
	 *     "file", "false",
	 *     "title", "The more samples, the less variance",
	 *     "xlabel", "Number of samples",
	 *     "ylabel", "Average of Uniform[0,1] + mean - 0.5",
	 *     "some unused variable that could be used if we wanted to", 10,
	 *     new {@link Dimension}("mean", Util.list(2, 3)),
	 *     new {@link Dimension}("numSamples", 1, 1000, 1),
	 *     averageOfNumSamplesOfUniformPlusCurrentMeanMinusZeroPointFive,
	 *     new {@link DataSeriesSpec}("mean", Util.list(
	 *             Util.list("title 'mean 2'", "w linespoints"),
	 *             Util.list("title 'mean 3'", "w linespoints"))));
	 * </pre>
	 * Here, some gnuplot parameters and an (unused) variable are introduced.
	 * Then the range operation {@link Dimension} is used to vary variables "mean" and "numSamples"
	 * across a range of values.
	 * The function averageOfNumSamplesOfUniformPlusCurrentMeanMinusZeroPointFive uses them to compute
	 * the elements of a matrix.
	 * The {@link DataSeriesSpec} specifies that the "mean" dimension is the one determining
	 * the individual data series for the plot
	 * (and as a consequence the "numSamples" dimension is selected for the x-axis),
	 * and specifies their gnuplot labels and styles as well.
	 * Note that the plot's y-axis does not correspond to any dimensions of the matrix, but to its values
	 * (the ones computed by the {@link DAEFunction}).
	 *
	 * @param arguments
	 *        the experiment's arguments.
	 */
	public static void experiment(Object ... arguments) {
		List<Dimension<Object>> dimensions = getDimensions(arguments);
		DataSeriesSpec dataSeriesDimensionSpec = getDataSeriesDimensionSpec(arguments);
		Range<Object> xSeries = getXDimensionRange(dimensions, dataSeriesDimensionSpec);

		List<?> data = (List<?>) RangeOperationsInterpreter.apply(arguments);
		List<DataSeries<Object>> dataSeriesList = getDataSeriesList(data, dimensions, dataSeriesDimensionSpec);

		Map<String, Object> properties = Util.getMapWithStringKeys(arguments);
		if ( ! properties.containsKey("xlabel")) {
			Dimension<Object> xDimension = getXDimension(dimensions, dataSeriesDimensionSpec);
			String name = xDimension == null ? "x" : xDimension.getRange().getName();
			properties.put("xlabel", name);
		}

		LinkedList<String> preCommands = getPreCommands(properties);
		Gnuplot.plot(preCommands, xSeries, dataSeriesList);
	}

	/**
	 * Builds the gnuplot pre-commands: the {@link #guaranteedPreCommands}, then title/label
	 * settings for the properties present, then either postscript output settings
	 * (when writing to file) or the "persist" marker.
	 */
	private static LinkedList<String> getPreCommands(Map<String, Object> properties) {
		LinkedList<String> preCommands = new LinkedList<String>(Arrays.asList(guaranteedPreCommands));
		if (properties.containsKey("title")) {
			preCommands.add("set title " + gnuplotQuoted(properties.get("title")));
		}
		if (properties.containsKey("xlabel")) {
			preCommands.add("set xlabel " + gnuplotQuoted(properties.get("xlabel")));
		}
		if (properties.containsKey("ylabel")) {
			preCommands.add("set ylabel " + gnuplotQuoted(properties.get("ylabel")));
		}
		if (writesToFile(properties)) {
			preCommands.add("set term postscript color");
			preCommands.add("set output " + gnuplotQuoted(filename(properties) + ".ps"));
		}
		else {
			preCommands.add("persist");
		}
		return preCommands;
	}

	/**
	 * Renders a value as a gnuplot single-quoted string literal.
	 * Embedded single quotes are doubled, which is gnuplot's escape inside single quotes,
	 * so titles such as "Bayes' rule" do not terminate the string early.
	 */
	private static String gnuplotQuoted(Object value) {
		return "'" + String.valueOf(value).replace("'", "''") + "'";
	}

	/**
	 * Indicates whether the graph should be recorded to a file: true when "filename" is given
	 * or when "print" or "file" holds a true value.
	 */
	private static boolean writesToFile(Map<String, Object> properties) {
		return properties.containsKey("filename")
				|| isTrue(properties.get("print"))
				|| isTrue(properties.get("file"));
	}

	/** Accepts both {@link Boolean#TRUE} and the string "true" (case-insensitive); null is false. */
	private static boolean isTrue(Object value) {
		return Boolean.parseBoolean(String.valueOf(value));
	}

	/**
	 * Returns the file name (without extension) to record the graph to:
	 * the "filename" property if present, otherwise the "title", otherwise "unnamed".
	 */
	private static String filename(Map<String, Object> properties) {
		Object filename = properties.get("filename");
		if (filename == null) {
			filename = properties.get("title");
		}
		return filename == null ? "unnamed" : filename.toString();
	}

	/** Collects, in argument order, all {@link Dimension} arguments. */
	private static <T> List<Dimension<T>> getDimensions(Object ... arguments) {
		List<Dimension<T>> result = new LinkedList<Dimension<T>>();
		for (Object object : arguments) {
			if (object instanceof Dimension<?>) {
				@SuppressWarnings("unchecked")
				Dimension<T> dimension = (Dimension<T>) object;
				result.add(dimension);
			}
		}
		return result;
	}

	/**
	 * A class indicating the variable corresponding to a data series in a graph,
	 * as well as its directives (see {@link Gnuplot}).
	 * The i-th directives list applies to the data series for the i-th value of the variable.
	 */
	public static class DataSeriesSpec {
		private final String variable;
		private final List<List<String>> directivesList;

		public DataSeriesSpec(String variable, List<List<String>> directivesList) {
			this.variable = variable;
			this.directivesList = directivesList;
		}

		/** Returns the name of the variable whose values define the data series. */
		public String getName() {
			return variable;
		}

		@Override
		public String toString() {
			return "DataSeriesSpec(" + variable + ", " + directivesList + ")";
		}
	}

	/** Returns the {@link DataSeriesSpec} among the arguments, or null if there is none. */
	private static DataSeriesSpec getDataSeriesDimensionSpec(Object... arguments) {
		return (DataSeriesSpec) Util.getObjectOfClass(DataSeriesSpec.class, arguments);
	}

	/** Returns the range of the x-axis dimension, or null if there is no such dimension. */
	private static <T> Range<T> getXDimensionRange(List<Dimension<T>> dimensions, DataSeriesSpec dataSeriesDimensionSpec) {
		Dimension<T> xDimension = getXDimension(dimensions, dataSeriesDimensionSpec);
		return xDimension != null ? xDimension.getRange() : null;
	}

	/**
	 * Returns the first dimension that is not the data-series dimension
	 * (the first dimension overall if no spec was given), or null if there is none.
	 */
	private static <T> Dimension<T> getXDimension(List<Dimension<T>> dimensions, DataSeriesSpec dataSeriesDimensionSpec) {
		String dataSeriesVariable = dataSeriesDimensionSpec == null ? null : dataSeriesDimensionSpec.variable;
		for (Dimension<T> dimension : dimensions) {
			// equals(null) is false, so with no spec the first dimension is chosen
			if ( ! dimension.getRange().getName().equals(dataSeriesVariable)) {
				return dimension;
			}
		}
		return null;
	}

	/** Returns the dimension whose variable is the spec's variable, or null if none (or no spec). */
	private static <T> Dimension<T> getDataSeriesDimension(List<Dimension<T>> dimensions, DataSeriesSpec dataSeriesSpec) {
		if (dataSeriesSpec == null) {
			return null;
		}
		for (Dimension<T> dimension : dimensions) {
			if (dimension.getRange().getName().equals(dataSeriesSpec.variable)) {
				return dimension;
			}
		}
		return null;
	}

	/**
	 * Splits the result matrix into one {@link DataSeries} per value of the data-series variable,
	 * pairing each slice with the next directives list from the spec.
	 * If no dimension matches the spec's variable (or no spec is given), the data must be
	 * one-dimensional and becomes a single data series.
	 *
	 * @throws Error if the spec has fewer directives lists than the data-series variable has values.
	 */
	private static <T> List<DataSeries<T>> getDataSeriesList(List<?> data, List<Dimension<T>> dimensions, DataSeriesSpec dataSeriesSpec) {
		List<DataSeries<T>> dataSeriesList = new LinkedList<DataSeries<T>>();
		Dimension<T> dataSeriesDimension = getDataSeriesDimension(dimensions, dataSeriesSpec);
		if (dataSeriesDimension != null) {
			// a non-null dataSeriesDimension implies a non-null dataSeriesSpec
			int dimensionIndex = dimensions.indexOf(dataSeriesDimension);
			Iterator<?> rangeIterator = dataSeriesDimension.getRange().apply();
			Iterator<List<String>> directiveIterator = dataSeriesSpec.directivesList.iterator();
			int sliceIndex = 0;
			while (rangeIterator.hasNext()) {
				rangeIterator.next(); // only the count of values matters: one slice per value
				if ( ! directiveIterator.hasNext()) {
					throw new Error("DataSeriesSpec on '" + dataSeriesSpec.getName() + "' does not have enough directives (it needs one per value of '" + dataSeriesSpec.getName() + "')");
				}
				List<String> directives = directiveIterator.next();
				@SuppressWarnings("unchecked")
				List<T> dataSeriesData = Util.matrixSlice((List<List<T>>) data, dimensionIndex, sliceIndex);
				dataSeriesList.add(new DataSeries<T>(directives, dataSeriesData));
				sliceIndex++;
			}
		}
		else {
			if (dimensions.size() > 1) {
				Util.fatalError("DataSeriesSpec " + dataSeriesSpec + " does not refer to any present dimension and data is multidimensional.");
			}
			List<String> directives;
			if (dataSeriesSpec == null) {
				directives = new LinkedList<String>();
			}
			else {
				directives = Util.getFirst(dataSeriesSpec.directivesList);
			}
			@SuppressWarnings("unchecked")
			List<T> dataList = (List<T>) data;
			dataSeriesList.add(new DataSeries<T>(directives, dataList));
		}
		return dataSeriesList;
	}

	/**
	 * An extension of {@link List} for keeping pre-commands for a {@link Gnuplot} graph.
	 */
	@SuppressWarnings("serial")
	private static class PreCommands extends LinkedList<String> {
		public PreCommands(String ... preCommands) {
			addAll(Arrays.asList(preCommands));
		}
	}

	/** Creates a {@link PreCommands} holding the given gnuplot commands. */
	public static PreCommands preCommands(String ... preCommands) {
		return new PreCommands(preCommands);
	}

	/** Returns the {@link PreCommands} among the arguments, or an empty one if there is none. */
	public static PreCommands getPreCommands(Object ... arguments) {
		PreCommands preCommands = (PreCommands) Util.getObjectOfClass(PreCommands.class, arguments);
		if (preCommands == null) {
			preCommands = new PreCommands();
		}
		return preCommands;
	}

	/** A wrapper marking a string argument as a title. */
	private static class Title {
		private final String value;

		public Title(String value) {
			// String.valueOf keeps the former StringBuffer behavior of rendering null as "null"
			this.value = String.valueOf(value);
		}

		@Override
		public String toString() {
			return value;
		}
	}

	/** Creates a {@link Title} argument. */
	public static Title Title(String value) {
		return new Title(value);
	}

	/** Returns the text of the {@link Title} among the arguments, or null if there is none. */
	public static String getTitle(Object ... args) {
		Object title = Util.getObjectOfClass(Title.class, args);
		return title == null ? null : title.toString();
	}

	/**
	 * Example experiment function: averages "numSamples" draws of
	 * Uniform[0,1] + mean - 0.5, reading "numSamples" and "mean" from the environment.
	 */
	private static final DAEFunction averageOfNumSamplesOfUniformPlusCurrentMeanMinusZeroPointFive = new AbstractDAEFunction() {
		@Override
		public Object apply(DependencyAwareEnvironment environment) {
			int numberOfSamples = environment.getInt("numSamples");
			int mean = environment.getInt("mean");
			double sum = 0;
			for (int i = 0; i != numberOfSamples; i++) {
				sum += (mean - 0.5) + Math.random();
			}
			double result = sum / numberOfSamples;
			return result;
		}
	};

	/** Runs the example experiment described in {@link #experiment(Object...)}. */
	@SuppressWarnings({ "unchecked", "rawtypes" })
	public static void main(String[] args) {
		experiment(
				"file", "false",
				"title", "The more samples, the less variance",
				"xlabel", "Number of samples",
				"ylabel", "Average of Uniform[0,1] + mean - 0.5",
				"some unused variable that could be used if we wanted to", 10,
				new Dimension("mean", Util.list(2, 3)),
				new Dimension("numSamples", 1, 200, 1),
				averageOfNumSamplesOfUniformPlusCurrentMeanMinusZeroPointFive,
				new DataSeriesSpec("mean", Util.list(
						Util.list("title 'mean 2'", "w linespoints"),
						Util.list("title 'mean 3'", "w linespoints"))));
	}
}