/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package wordlengthoptimization;
import datapath.graph.Graph;
import datapath.graph.display.dot.DotDisplayFactory;
import datapath.graph.operations.Operation;
import datapath.graph.operations.ParentInput;
import datapath.graph.operations.ParentOutput;
import datapath.graph.type.FixedPoint;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
/**
* Performs one pass of forward propagation type of wordlength optimization.
* The algorithm goes once top down over graph (it has to be acyclic) and
* computes all the wordlength and precision at the outputs from the wordlength
* and precision at the inputs.
*
* @author fs
*/
public class ForwardPropagation implements WordlengthOptimization {
private Options opts;
private Graph graph;
HashMap<Operation, Double> minValues = new HashMap<Operation, Double>();
HashMap<Operation, Double> maxValues = new HashMap<Operation, Double>();
ArrayList<HashMap<ParentInput, Double>> monteCarloInputsFloat = new ArrayList<HashMap<ParentInput, Double>>();
ArrayList<HashMap<ParentInput, BigInteger>> monteCarloInputsFix = new ArrayList<HashMap<ParentInput, BigInteger>>();
ArrayList<HashMap<Operation, Double>> monteCarloResultsFloat = new ArrayList<HashMap<Operation, Double>>();
ArrayList<HashMap<Operation, BigInteger>> monteCarloResultsFix = new ArrayList<HashMap<Operation, BigInteger>>();
/*
* adds to startValues the other Values from monte carlo computation */
private HashMap<Operation, Double> performMonteCarloIter(HashMap<ParentInput, Double> startValues) {
ComputeValueVisitor computeValue = new ComputeValueVisitor(startValues);
Operation.nextVisit();
for (ParentOutput outputNode : graph.getOutput()) {
outputNode.postOrderUpwardVisit(computeValue);
}
return computeValue.getValues();
}
private void initStartVariableRanges() {
/* add the ranges from the pragmas
Important: ALL INPUTS MUST HAVE SPECIFIED RANGES VIA PRAGMAS */
for (ParentInput input : graph.getInput()) {
String strMin = opts.getStartVariableMinValues().get(input.getName());
String strMax = opts.getStartVariableMaxValues().get(input.getName());
minValues.put(input, Double.parseDouble(strMin));
maxValues.put(input, Double.parseDouble(strMax));
}
}
private HashMap<ParentInput, Double> generateTrace(double pwx, double pwy, double pwz) {
HashMap<ParentInput, Double> traceValue = new HashMap<ParentInput, Double>();
for (ParentInput input : graph.getInput()) {
if (input.getName().equals("pwx"))
traceValue.put(input, 0.0);
if (input.getName().equals("pwy"))
traceValue.put(input, 1.4);
if (input.getName().equals("pwz"))
traceValue.put(input, -1.0);
}
return traceValue;
}
private void generateMonteCarloFloatInputs() {
/* first we add a trace value */
monteCarloInputsFloat.add(generateTrace(0.0, 1.4, -1.0));
for (int iter = 0; iter < opts.getMonteCarloIterations(); iter++) {
HashMap<ParentInput, Double> startValues = new HashMap<ParentInput, Double>();
// init with random values
for (ParentInput input : graph.getInput()) {
double minValue = minValues.get(input);
double maxValue = maxValues.get(input);
double range = maxValue - minValue;
double randomValue = minValue + Math.random() * range;
startValues.put(input, randomValue);
}
monteCarloInputsFloat.add(startValues);
}
}
private void generateMonteCarloFixInputsFromFloats() {
/* the float inputs shold have not higher precision as the fixed inputs => so recreate them */
ArrayList<HashMap<ParentInput, Double>> newFloatValues = new ArrayList<HashMap<ParentInput, Double>>();
for (HashMap<ParentInput, Double> floatValues : monteCarloInputsFloat) {
HashMap<ParentInput, BigInteger> startValues = new HashMap<ParentInput, BigInteger>();
HashMap<ParentInput, Double> newFloatStartValues = new HashMap<ParentInput, Double>();
for (ParentInput input : floatValues.keySet()) {
FixedPoint fpt = (FixedPoint) input.getType();
startValues.put(input, Util.fixedPointFromFloat(floatValues.get(input), fpt.getFractionlength()));
newFloatStartValues.put(input, Util.floatFromfixedPoint(startValues.get(input), fpt.getFractionlength()));
}
monteCarloInputsFix.add(startValues);
newFloatValues.add(newFloatStartValues);
}
monteCarloInputsFloat = newFloatValues;
}
private void generateMonteCarloInputs() {
initStartVariableRanges();
generateMonteCarloFloatInputs();
generateMonteCarloFixInputsFromFloats();
}
private void monteCarloFix() {
int i = 0;
for (HashMap<ParentInput, BigInteger> startValues : monteCarloInputsFix) {
HashMap<Operation, BigInteger> result;
Operation.nextVisit();
ComputeIntegerValueVisitor fixpointValue = new ComputeIntegerValueVisitor(startValues);
try {
for (ParentOutput outputNode : graph.getOutput()) {
outputNode.postOrderUpwardVisit(fixpointValue);
}
monteCarloResultsFix.add(fixpointValue.getValues());
} catch (java.lang.ArithmeticException arithException) {
// catch divide by zero, neg square which can happen due to invalid random input
System.err.println("div by zero dataset: " + startValues);
/* debug for annotatting dot graph with values
for (Operation op : fixpointValue.getValues().keySet()) {
op.setDebugMessage(fixpointValue.getValues().get(op).toString());
monteCarloResultsFloat = new ...
return
} */
//System.err.println(fixpointValue.getValues());
monteCarloResultsFloat.remove(i);
i--;
}
i++;
}
}
private void monteCarloFloat() {
boolean firstRun = true;
generateMonteCarloInputs();
for (HashMap<ParentInput, Double> startValues : monteCarloInputsFloat) {
HashMap<Operation, Double> result = performMonteCarloIter((startValues));
monteCarloResultsFloat.add(result);
/* merge the fixpoint runs for value range analysis */
if (firstRun) {
/* not so easy, get the input min and max back into the first result */
HashMap<Operation, Double> tmp = (HashMap<Operation, Double>)result.clone();
tmp.putAll(maxValues);
maxValues = tmp;
tmp = ((HashMap<Operation, Double>) result.clone());
tmp.putAll(minValues);
minValues = tmp;
firstRun = false;
} else {
Util.merge(maxValues, result, Util.MAXMERGE);
Util.merge(minValues, result, Util.MINMERGE);
}
}
/*
for (Operation op: minValues.keySet()) {
System.out.println("Variable " + op.getDebugMessage() + " " + op.getDisplayName() +
" MC-Value Range: [" + minValues.get(op) + ", " +
maxValues.get(op) + "]");
} */
}
private void monteCarloErrorAnalysis() {
HashMap<Operation, Double> maxAbsError = new HashMap<Operation, Double>();
HashMap<Operation, Double> maxRelError = new HashMap<Operation, Double>();
Iterator<HashMap<Operation, BigInteger>> iter = monteCarloResultsFix.iterator();
for (HashMap<Operation, Double> floatValues : monteCarloResultsFloat) {
HashMap<Operation, BigInteger> fixValues = iter.next();
for (Operation op : floatValues.keySet()) {
FixedPoint fpt = (FixedPoint) op.getType();
double floatValue = floatValues.get(op);
double fixValue = Util.floatFromfixedPoint(fixValues.get(op), fpt.getFractionlength());
double newDifference = Math.abs(floatValue - fixValue);
if (!maxAbsError.containsKey(op) || maxAbsError.get(op) < newDifference) {
maxAbsError.put(op, newDifference);
maxRelError.put(op, newDifference/floatValue);
}
}
}
for (Operation op : maxAbsError.keySet()) {
//op.setDebugMessage(op.getDebugMessage() + " Error: " + maxAbsError.get(op));
if (!op.getDebugMessage().isEmpty()) {
System.out.println("Difference in node " + op.getDebugMessage() +
"=> Abs: " + maxAbsError.get(op) +
" Rel: " + maxRelError.get(op));
}
}
}
private void traceValues() {
HashMap<Operation, Double> floatVals = monteCarloResultsFloat.get(0);
for (Operation op : monteCarloResultsFix.get(0).keySet()) {
BigInteger val = monteCarloResultsFix.get(0).get(op);
double doubleval = (floatVals.get(op) == null) ? 0 : floatVals.get(op);
FixedPoint fp = (FixedPoint) op.getType();
if (!op.getDebugMessage().isEmpty()) {
System.out.println(op.getDebugMessage() + " " +op + " BigIntVal: " + val + " (" +
Util.floatFromfixedPoint(val, fp.getFractionlength()) +
") Float: " + doubleval);
}
op.setDebugMessage(op.getDebugMessage() +
"\\nBigInteger: " + val + " (" + Util.floatFromfixedPoint(val, fp.getFractionlength()) + ")" +
"\\nFloat: " + doubleval +
"\\nError: " + Math.abs(doubleval - Util.floatFromfixedPoint(val, fp.getFractionlength())));
}
}
@Override
public int optimize(Graph graph) {
int changed = 0;
this.graph = graph;
/* monte carlo simulation */
monteCarloFloat();
/* apply forward propagation recursivle to all nodes */
Operation.nextVisit();
ForwardPropagationVisitorNewTypeCast forward = new ForwardPropagationVisitorNewTypeCast(minValues, maxValues);
for (ParentOutput outputNode : graph.getOutput()) {
outputNode.postOrderUpwardVisit(forward);
}
/* restrict all nodes to the maximum wordlength */
Operation.nextVisit();
LimitBitwidthNewTypeCast limit = new LimitBitwidthNewTypeCast(opts.getMaxWordlength(), opts.getMinFractionlength());
for (ParentOutput outputNode : graph.getOutput()) {
outputNode.postOrderUpwardVisit(limit);
}
System.out.println(limit.getStats());
/* insert shift if necessary between nodes */
Operation.nextVisit();
ShiftInserterNewTypeCast shifts = new ShiftInserterNewTypeCast(graph);
for (ParentOutput outputNode : graph.getOutput()) {
outputNode.postOrderUpwardVisit(shifts);
}
/* perform fixed point simulation */
monteCarloFix();
monteCarloErrorAnalysis();
traceValues();
graph.display(new DotDisplayFactory(),"beforetypecast");
/* translate TypeConversions into appropiate shifts/bitselects */
RemoveTypeConversion rtc = new RemoveTypeConversion();
rtc.removeTypeConversions(graph);
return forward.getChanged();
}
@Override
public String toString() {
return "Forward Propagation";
}
@Override
public void setOptions(Options opts) {
this.opts = opts;
}
}