package datapath.graph;
import datapath.graph.modlib.Module;
import datapath.graph.modlib.Wire;
import datapath.graph.modlib.WireAnd;
import datapath.graph.modlib.WireConcat;
import datapath.graph.modlib.WireIO;
import datapath.graph.modlib.WireNot;
import datapath.graph.modlib.WireOR;
import datapath.graph.modlib.parameter.*;
import datapath.graph.operations.*;
import datapath.graph.type.*;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Backend to generate the Verilog output for each {@link Graph}.
 *
 * <p>Every graph is emitted as its own Verilog module named {@code graph<id>}. The inner-loop
 * graphs of a graph are written before the graph itself (see
 * {@link #write(Graph, BufferedWriter)}). A fresh writer instance is created per graph. Every
 * visited hardware {@link Operation} is translated into a module-library instance ({@link Module})
 * whose ports are connected through {@link Wire}s.
 *
 * <p>Wires and module instances are kept in insertion order, so the generated Verilog is
 * deterministic for a given graph.
 *
 * @author jh
 */
public class ModlibWriter implements OperationVisitor {

    /** Use {@code mul_pipe} for fixed-point multiplications with operands up to 64 and results up to 128 bits. */
    public static boolean mul_pipe = true;
    /** Also generate the required {@code mul_pipe} variants through {@link MulPipeCreator}. */
    public static boolean mul_pipe_create = true;

    private final Graph graph;
    private final BufferedWriter writer;
    private final Wire CE = new Wire("CE");
    private final Wire CLK = new Wire("CLK");
    private final Wire RESET = new Wire("RESET");
    private final Wire INIT = new Wire("INIT");
    private final Wire END = new Wire("END");
    /** Module instances of the current graph, in creation order. */
    private final HashSet<Module> modules = new LinkedHashSet<Module>();

    private ModlibWriter(Graph graph, BufferedWriter writer) {
        this.graph = graph;
        this.writer = writer;
    }

    /**
     * Returns a port connection binding the result wire of {@code source} to {@code port}.
     */
    private WireIO getWireIO(Operation source, String port) {
        return new WireIO(getWire(source), port);
    }

    /** Result wire of each operation, in the order the wires were first requested. */
    HashMap<Operation, Wire> wireSources = new LinkedHashMap<Operation, Wire>();

    /**
     * Returns the result wire of {@code source}, creating and caching it on first request.
     * The wire is named by {@link #getDataWire(Operation)} and sized to the operation's output
     * bit size.
     */
    private Wire getWire(Operation source) {
        Wire w = wireSources.get(source);
        if (w == null) {
            w = new Wire(getDataWire(source));
            w.setSize(source.getOutputBitsize());
            wireSources.put(source, w);
        }
        return w;
    }

    /** Sum of the depths of all operations handed to {@link #getDepth(Operation)}. */
    int combinedDepth;
    /** Largest depth handed out by {@link #getDepth(Operation)}, ignoring {@link LoopInit}. */
    int highestDepth;

    /**
     * Computes the DEPTH parameter of {@code op} as the {@link Graph#getDistance} between
     * {@code op} and its uses.
     *
     * <p>All uses are asserted to be at the same distance. An operation without uses gets
     * depth 0. As a side effect, the depth is added to {@link #combinedDepth}. For every
     * operation except {@link LoopInit} it may also raise {@link #highestDepth}.
     */
    private DEPTH getDepth(Operation op) {
        int depth = -1;
        for (Operation use : op.getUse()) {
            if (depth == -1) {
                depth = Graph.getDistance(op, use);
            }
            assert Graph.getDistance(op, use) == depth
                    : "different depths -> possible scheduling error " + op;
        }
        depth = Math.max(depth, 0); // for operations which have no use
        combinedDepth += depth;
        if (!(op instanceof LoopInit)) {
            highestDepth = Math.max(highestDepth, depth);
        }
        return new DEPTH(depth);
    }

    /**
     * Returns the {@link Graph#getDistance} between {@code op} and its first use, or 0 when
     * {@code op} has no use (same fallback as {@link #getDepth(Operation)}).
     * Unlike getDepth, this does not update the depth statistics.
     */
    private static int distanceToFirstUse(Operation op) {
        Iterator<? extends Operation> uses = op.getUse().iterator();
        return uses.hasNext() ? Graph.getDistance(op, uses.next()) : 0;
    }

    /**
     * Returns the string representation of result wire for the given Operation.
     * @param source the operation for which the result wire is desired
     * @return the string representation of result wire for the given Operation
     */
    public static String getDataWire(Operation source) {
        return String.format("dw_op%d", source.getNumber());
    }

    /** Returns the 1-bit control wires of the given predicates, in iteration order. */
    private Wire[] getControlWires(Iterable<? extends Predicate> predicates) {
        ArrayList<Wire> wires = new ArrayList<Wire>();
        for (Predicate p : predicates) {
            wires.add(getControlWire(p));
        }
        return wires.toArray(new Wire[0]);
    }

    /** Returns the AND of the control wires of all given predicates. */
    private Wire getControlWireAnd(Set<Predicate> predicates) {
        return new WireAnd(getControlWires(predicates));
    }

    /** Returns the OR of the control wires of all given predicates. */
    private Wire getControlWireOR(Set<Predicate> predicates) {
        return new WireOR(getControlWires(predicates));
    }

    /**
     * Returns the 1-bit control wire of a predicate. For INIT and TRUE predicates this is the
     * predicate's data wire; for FALSE predicates it is the negation of that wire.
     *
     * @throws RuntimeException for any other predication type
     */
    private Wire getControlWire(Predicate pred) {
        assert pred.getData().getOutputBitsize() == 1;
        switch (pred.getPredicationType()) {
            case INIT:
            case TRUE:
                return getWire(pred.getData());
            case FALSE:
                return new WireNot(getWire(pred.getData()));
            default:
                throw new RuntimeException("don't know the predicate type");
        }
    }

    /** Returns the AND of the control wires of all predicates of {@code pred}. */
    private Wire getPredicationWire(Predication pred) {
        return new WireAnd(getControlWires(pred.getPredicates()));
    }

    /**
     * Writes all inner-loop graphs of {@code graph} (recursively), then {@code graph} itself.
     * Each graph is preceded by a {@code `timescale} directive.
     */
    private static void writeRecursive(Graph graph, BufferedWriter writer) throws IOException {
        for (Graph g : graph.getInnerLoops()) {
            writeRecursive(g, writer);
        }
        ModlibWriter w = new ModlibWriter(graph, writer);
        w.write("`timescale 1ns / 1ns\n");
        w.write(graph);
    }

    /**
     * Emits the Verilog module for a single graph, without its inner loops.
     *
     * <p>Steps: visit every hardware operation to build module instances and wires; write
     * statistics comments and the port list; declare the remaining internal wires; write the
     * module instances. If enabled, the collected {@code mul_pipe} variants are then passed to
     * {@link MulPipeCreator}.
     */
    private void write(Graph graph) throws IOException {
        combinedDepth = 0;
        highestDepth = 0;
        mulpipes = new LinkedHashSet<MulPipeCreator.Options>();
        // generate all modules and wires
        for (Operation op : graph.getOperations()) {
            if (op.isHardwareOperation()) {
                op.visit(this);
            }
        }
        writeStatistics(graph);
        writeModuleHeader(graph);
        // Port wires were removed from wireSources by writeModuleHeader.
        // Everything left is an internal wire.
        for (Wire w : wireSources.values()) {
            writeWire(w);
            write("\n");
        }
        addCommonControl();
        for (Module m : modules) {
            write(Modlib.module(m));
            write("\n");
        }
        write("endmodule\n");
        if (mul_pipe_create && mul_pipe) {
            System.out.println("number of different mul_pipe:" + mulpipes.size());
            System.out.println(mulpipes);
            MulPipeCreator.c(mulpipes);
        }
    }

    /** Writes the schedule/depth statistics and per-type operation counts as Verilog comments. */
    private void writeStatistics(Graph graph) throws IOException {
        write("// new loop\n");
        write(String.format("// latest schedule is %d\n",
                graph.getLatestSchedule()));
        write(String.format("// combined depth: %d\n", combinedDepth));
        write(String.format("// highest depth: %d\n", highestDepth));
        int numAdd = graph.numOfOperation(true, datapath.graph.operations.Add.class);
        int numSub = graph.numOfOperation(true, datapath.graph.operations.Subtraction.class);
        int numMul = graph.numOfOperation(true, datapath.graph.operations.Multiplication.class);
        int numDiv = graph.numOfOperation(true, datapath.graph.operations.Divide.class);
        int numCos = graph.numOfOperation(true, datapath.graph.operations.Cos.class);
        int numSin = graph.numOfOperation(true, datapath.graph.operations.Sin.class);
        int numSqrt = graph.numOfOperation(true, datapath.graph.operations.SquareRoot.class);
        int numTotal = numAdd + numSub + numMul + numDiv + numCos + numSin + numSqrt;
        write("// number of operations\n");
        write(String.format("// add: %d\n", numAdd));
        write(String.format("// sub: %d\n", numSub));
        write(String.format("// div: %d\n", numDiv));
        write(String.format("// mul: %d\n", numMul));
        write(String.format("// cos: %d\n", numCos));
        write(String.format("// sin: %d\n", numSin));
        write(String.format("// sqrt: %d\n", numSqrt));
        write(String.format("// total: %d\n", numTotal));
    }

    /**
     * Writes the {@code module graph<id>(...);} header with control and data ports.
     * Data-port wires are removed from {@link #wireSources} so they are not declared again as
     * internal wires.
     */
    private void writeModuleHeader(Graph graph) throws IOException {
        write("module graph" + graph.getId() + "\n");
        write("(\n");
        // Control Wires
        write("input wire RESULT_ACCEPT,\n");
        write("input wire CANCEL,\n");
        write("input wire CANCEL_STATE_RESET,\n");
        write("input wire START,\n");
        write("input wire START_CTRL,\n");
        write("input wire " + CLK.withSize() + ",\n");
        write("input wire " + RESET.withSize() + ",\n");
        write("input wire " + INIT.withSize() + ",\n");
        write("output wire " + END.withSize() + ",\n");
        write("input wire " + CE.withSize());
        // data input of the graph
        for (ParentInput pin : graph.getInput()) {
            write(",\n");
            write("input wire " + getWire(pin.getSource()).withSize());
            wireSources.remove(pin.getSource());
        }
        // data output of the graph
        for (ParentOutput pout : graph.getOutput()) {
            write(",\n");
            write("output wire " + getWire(pout).withSize());
            wireSources.remove(pout);
        }
        write("\n);\n");
    }

    /**
     * Writes the Verilog output of given graph and all its subgraphs to the
     * given writer. I/O errors are logged, not thrown; writing stops at the first error.
     * @param g the graph to write
     * @param writer the writer with which the graph is printed
     */
    public static void write(Graph g, BufferedWriter writer) {
        try {
            writeRecursive(g, writer);
            writer.flush();
        } catch (IOException ex) {
            Logger.getLogger(ModlibWriter.class.getName()).log(Level.SEVERE,
                    "Failed to write Verilog output", ex);
        }
    }

    /** Fallback for operations without a dedicated overload: reports them on stderr. */
    @Override
    public void visit(Operation op) {
        System.err.println(op.getClass() + " not supported");
    }

    /** Intentionally emits nothing. */
    @Override
    public void visit(BinaryOperation op) {
    }

    /**
     * Emits a {@code mux} instance. Every operand is cast to a {@link Predication}.
     * Its data wire is concatenated into input A. The AND of its predicates is concatenated
     * into input B. The OR of those predicate wires drives START.
     */
    @Override
    public void visit(Mux op) {
        ArrayList<Wire> dataWires = new ArrayList<Wire>();
        ArrayList<Wire> predicateWires = new ArrayList<Wire>();
        for (Operation operand : op.getOperands()) {
            Predication pred = (Predication) operand;
            dataWires.add(getWire(pred.getData()));
            predicateWires.add(getPredicationWire(pred));
        }
        Module m = new Module("mux", op.getNumber());
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(new WA(op.getOutputBitsize()));
        m.addParameter(new SIGN(op.isSigned()));
        m.addParameter(getDepth(op));
        m.addParameter(new NIN(dataWires.size()));
        m.addIO(getWireIO(op, "R"));
        m.addIO(new WireIO(new WireConcat(dataWires.toArray(new Wire[0])), "A"));
        m.addIO(new WireIO(new WireConcat(predicateWires.toArray(new Wire[0])), "B"));
        m.addIO(new WireIO(new WireOR(predicateWires.toArray(new Wire[0])), "START"));
        modules.add(m);
    }

    /** Emits a {@code const_op} instance whose VALUE is the constant as a sized hex literal. */
    @Override
    public void visit(ConstantOperation op) {
        Module m = new Module("const_op", op.getNumber());
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(new VALUE(String.format("%d'h%s", op.getOutputBitsize(),
                op.toHex())));
        // The constant is not delayed. getDepth is still evaluated so the depth statistics
        // (combinedDepth / highestDepth) include this operation.
        getDepth(op);
        m.addIO(getWireIO(op, "R"));
        modules.add(m);
    }

    /** Emits {@code add} for integer/fixed-point types and {@code addfloat} for floats. */
    @Override
    public void visit(Add op) {
        Type type = op.getType();
        if (type instanceof datapath.graph.type.Integer
                || type instanceof datapath.graph.type.FixedPoint) {
            binaryOp("add", op);
        } else if (type instanceof datapath.graph.type.Float) {
            binaryOp("addfloat", op);
        } else {
            throw new UnsupportedOperationException("Not supported yet.");
        }
    }

    /** Writes a Verilog declaration {@code wire <sized name>;} for {@code w}. */
    private void writeWire(Wire w) throws IOException {
        write("wire " + w.withSize() + ";");
    }

    /**
     * Writes {@code s} to the output.
     *
     * @throws IOException if the underlying writer fails. The exception is propagated so that
     *     output stops at the first failure instead of logging every later write.
     */
    private void write(String s) throws IOException {
        writer.write(s);
    }

    /**
     * Connects CLK, CE and RESET to every module and ties the remaining handshake inputs to
     * constants. {@code const_op} modules get QDEPTH 1; all others get QDEPTH 0.
     *
     * <p>NOTE(review): modules that already bound START (mux, memwrite, ToOuterLoop, HWOutput)
     * get a second START binding here. Which one wins depends on {@code Module.addIO}; TODO
     * confirm.
     */
    private void addCommonControl() {
        for (Module m : modules) {
            if (m.getType().equals("const_op")) {
                m.addParameter(new QDEPTH(1));
            } else {
                m.addParameter(new QDEPTH(0));
            }
            m.addIO(new WireIO(CLK, "CLK"));
            m.addIO(new WireIO(CE, "CE"));
            m.addIO(new WireIO(new Wire("1'b1"), "START"));
            m.addIO(new WireIO(new Wire("1'b1"), "RESULT_ACCEPT"));
            m.addIO(new WireIO(new Wire("1'b0"), "CANCEL"));
            m.addIO(new WireIO(new Wire("1'b1"), "START_CTRL"));
            m.addIO(new WireIO(new Wire("1'b1"), "CANCEL_STATE_RESET"));
            m.addIO(new WireIO(new Wire("1'b1"), "CANCEL_STATE_CTRL_RESET"));
            m.addIO(new WireIO(RESET, "RESET"));
        }
    }

    /** Emits a {@code cmplt} instance. */
    @Override
    public void visit(Less op) {
        binaryOp("cmplt", op);
    }

    /**
     * Emits a two-input module {@code name}: WA/WB/WR widths, SIGN and DEPTH parameters, and
     * ports A (lhs), B (rhs) and R (result).
     */
    private void binaryOp(String name, BinaryOperation op) {
        Module m = new Module(name, op.getNumber());
        m.addParameter(new WA(op.getLhs().getOutputBitsize()));
        m.addParameter(new WB(op.getRhs().getOutputBitsize()));
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(new SIGN(op.isSigned()));
        m.addParameter(getDepth(op));
        m.addIO(getWireIO(op.getLhs(), "A"));
        m.addIO(getWireIO(op.getRhs(), "B"));
        m.addIO(getWireIO(op, "R"));
        modules.add(m);
    }

    /**
     * Emits a one-input module {@code name}: WA/WR widths, SIGN and DEPTH parameters, and
     * ports A (operand) and R (result).
     */
    private void unaryOp(String name, UnaryOperation op) {
        Module m = new Module(name, op.getNumber());
        m.addParameter(new WA(op.getData().getOutputBitsize()));
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(new SIGN(op.isSigned()));
        m.addParameter(getDepth(op));
        m.addIO(getWireIO(op.getData(), "A"));
        m.addIO(getWireIO(op, "R"));
        modules.add(m);
    }

    /**
     * Emits a {@code memwrite} instance: address on A, data on B, START driven by the AND of
     * the operation's predicates.
     */
    @Override
    public void visit(MemWrite op) {
        Module m = new Module("memwrite", op.getNumber());
        m.addParameter(new ADDR_WIDTH(op.getAddress().getOutputBitsize()));
        m.addParameter(new DATA_WIDTH(op.getData().getOutputBitsize()));
        m.addIO(getWireIO(op.getAddress(), "A"));
        m.addIO(getWireIO(op.getData(), "B"));
        m.addIO(new WireIO(getControlWireAnd(op.getPredicates()), "START"));
        modules.add(m);
    }

    /** Emits a zero-depth {@code nop} that forwards the outer-loop source value. */
    @Override
    public void visit(FromOuterLoop op) {
        Module m = new Module("nop", "FromOuterLoop", op.getNumber());
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(new WA(op.getSource().getOutputBitsize()));
        m.addParameter(new DEPTH(0));
        m.addIO(getWireIO(op, "R"));
        m.addIO(getWireIO(op.getSource(), "A"));
        modules.add(m);
    }

    /** Emits a {@code nop} that forwards data into an inner loop. */
    @Override
    public void visit(ToInnerLoop op) {
        Module m = new Module("nop", op.getNumber());
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addIO(getWireIO(op.getData(), "A"));
        m.addIO(getWireIO(op, "R"));
        modules.add(m);
    }

    /**
     * Instantiates the inner loop's {@code graph<id>} module. Each of the inner graph's inputs
     * and outputs is connected to the port named after its data wire.
     */
    @Override
    public void visit(Loop op) {
        Module m = new Module("graph" + op.getGraph().getId(),
                op.getGraph().getId());
        for (ParentInput pin : op.getGraph().getInput()) {
            m.addIO(getWireIO(pin.getSource(), getDataWire(pin.getSource())));
        }
        for (ParentOutput pout : op.getGraph().getOutput()) {
            m.addIO(getWireIO(pout, getDataWire(pout)));
        }
        modules.add(m);
    }

    /**
     * Builds an {@code inreg} module but intentionally does not add it to the output. The
     * {@code getWireIO} call still registers the operation's result wire as a side effect.
     */
    @Override
    public void visit(HWInput op) {
        Module m = new Module("inreg", op.getNumber());
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addIO(getWireIO(op, "R"));
    }

    /** Emits a variable left shift ({@code vsl}); other shift modes are unsupported. */
    @Override
    public void visit(VariableShift op) {
        String type;
        switch (op.getMode()) {
            case Left:
                type = "vsl";
                break;
            default:
                throw new UnsupportedOperationException("Not supported yet.");
        }
        Module m = new Module(type, op.getNumber());
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(new WA(op.getLhs().getOutputBitsize()));
        m.addParameter(new WB(op.getRhs().getOutputBitsize()));
        m.addParameter(getDepth(op));
        m.addIO(getWireIO(op.getLhs(), "A"));
        m.addIO(getWireIO(op.getRhs(), "B"));
        m.addIO(getWireIO(op, "R"));
        modules.add(m);
    }

    /**
     * Emits a constant shift instance. Zero-shift modes are first rewritten on {@code op}
     * itself to a plain Left/Right shift by 0. Right shifts become signed ({@code vsr} with a
     * constant 8-bit B input) or unsigned ({@code lsr}) depending on {@code op.isSigned()}.
     */
    @Override
    public void visit(ConstantShift op) {
        ShiftMode mode = op.getMode();
        if (mode == ShiftMode.ZeroShiftLeft) {
            op.setShiftAmount(0);
            op.setMode(ShiftMode.Left);
        }
        if (mode == ShiftMode.ZeroShiftRight) {
            op.setShiftAmount(0);
            op.setMode(ShiftMode.Right);
        }
        mode = op.getMode();
        if (mode == ShiftMode.Right) {
            mode = op.isSigned() ? ShiftMode.SignedRight : ShiftMode.UnsignedRight;
        }
        Module m;
        switch (mode) {
            case Left:
                m = new Module("lsl", op.getNumber());
                m.addParameter(new WB(op.getShiftAmount()));
                break;
            case SignedRight:
                assert op.getShiftAmount() >= 0;
                m = new Module("vsr", op.getNumber());
                m.addParameter(new WB(8));
                m.addParameter(new SIGN(true));
                m.addIO(new WireIO(new Wire("8'd" + op.getShiftAmount()), "B"));
                break;
            case UnsignedRight:
                m = new Module("lsr", op.getNumber());
                m.addParameter(new WB(op.getShiftAmount()));
                break;
            default:
                throw new UnsupportedOperationException("Not supported yet.");
        }
        m.addParameter(new WA(op.getData().getOutputBitsize()));
        // WB is set above
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(getDepth(op));
        m.addIO(getWireIO(op.getData(), "A"));
        m.addIO(getWireIO(op, "R"));
        modules.add(m);
    }

    /**
     * Emits a delaying {@code nop}. DEPTH is the distance to its first use, or 0 when it has
     * no use.
     */
    @Override
    public void visit(Nop op) {
        Module m = new Module("nop", op.getNumber());
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(new WA(op.getData().getOutputBitsize()));
        m.addParameter(new DEPTH(distanceToFirstUse(op)));
        m.addIO(getWireIO(op.getData(), "A"));
        m.addIO(getWireIO(op, "R"));
        modules.add(m);
    }

    /** Emits a depth-1 {@code nop} toward the outer loop, started by the OR of its predicates. */
    @Override
    public void visit(ToOuterLoop op) {
        Module m = new Module("nop", "ToOuterLoop", op.getNumber());
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(new WA(op.getData().getOutputBitsize()));
        m.addParameter(new DEPTH(1));
        m.addIO(getWireIO(op, "R"));
        m.addIO(getWireIO(op.getData(), "A"));
        m.addIO(new WireIO(getControlWireOR(op.getPredicates()), "START"));
        modules.add(m);
    }

    /** Emits a 1-bit, depth-1 {@code nop} that drives END from the AND of the predicates. */
    @Override
    public void visit(LoopEnd op) {
        Module m = new Module("nop", "LoopEnd", op.getNumber());
        m.addParameter(new WA(1)); // control wire has only one bit
        m.addParameter(new WR(1));
        m.addParameter(new DEPTH(1));
        m.addIO(new WireIO(getControlWireAnd(op.getPredicates()), "A"));
        m.addIO(new WireIO(END, "R"));
        modules.add(m);
    }

    /** Emits a {@code nop} that forwards the module's INIT input to the operation's wire. */
    @Override
    public void visit(LoopInit op) {
        Module m = new Module("nop", "LoopInit", op.getNumber());
        m.addParameter(new WA(1));
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(getDepth(op));
        m.addIO(getWireIO(op, "R"));
        m.addIO(new WireIO(INIT, "A"));
        modules.add(m);
    }

    /** Emits a depth-1 output {@code nop}, started by the OR of its predicates. */
    @Override
    public void visit(HWOutput op) {
        Module m = new Module("nop", "HWOutput", op.getNumber());
        m.addParameter(new WA(op.getData().getOutputBitsize()));
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(new DEPTH(1));
        m.addIO(getWireIO(op, "R"));
        m.addIO(getWireIO(op.getData(), "A"));
        m.addIO(new WireIO(getControlWireOR(op.getPredicates()), "START"));
        modules.add(m);
    }

    /** Emits a {@code nop} that forwards a top-level input from its source. */
    @Override
    public void visit(TopLevelInput op) {
        Module m = new Module("nop", "TopLevelInput", op.getNumber());
        m.addParameter(new WA(op.getSource().getOutputBitsize()));
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(getDepth(op));
        m.addIO(getWireIO(op, "R"));
        m.addIO(getWireIO(op.getSource(), "A"));
        modules.add(m);
    }

    /**
     * Emits negation as a subtraction {@code 0 - x}: {@code sub} for integer/fixed-point,
     * {@code subfloat} for float. DEPTH is the distance to the first use (0 without a use when
     * assertions are disabled).
     */
    @Override
    public void visit(Negation op) {
        Type type = op.getType();
        Module m;
        if (type instanceof datapath.graph.type.Integer
                || type instanceof datapath.graph.type.FixedPoint) {
            m = new Module("sub", op.getNumber());
        } else if (type instanceof datapath.graph.type.Float) {
            m = new Module("subfloat", op.getNumber());
        } else {
            throw new UnsupportedOperationException("Not supported yet.");
        }
        m.addParameter(new WA(op.getData().getOutputBitsize()));
        m.addParameter(new WB(op.getData().getOutputBitsize()));
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(new SIGN(op.isSigned()));
        assert op.getUse().size() > 0 : op + " has no use";
        m.addParameter(new DEPTH(distanceToFirstUse(op)));
        m.addIO(getWireIO(op.getData(), "B"));
        m.addIO(new WireIO(new Wire(String.format("%d'd0", op.getData().getOutputBitsize())), "A"));
        m.addIO(getWireIO(op, "R"));
        modules.add(m);
    }

    /** Distinct {@code mul_pipe} configurations requested while writing the current graph. */
    HashSet<MulPipeCreator.Options> mulpipes;

    /**
     * Emits a multiplication. Operands are first swapped on {@code op} so that the wider one
     * is the left-hand side.
     *
     * <p>Fixed-point products use {@code mul_pipe} when {@link #mul_pipe} is set and the
     * operands are at most 64 bits and the result at most 128 bits. In that case the
     * configuration is also recorded for {@link MulPipeCreator} when {@link #mul_pipe_create}
     * is set. Otherwise fixed-point products use {@code mul}.
     */
    @Override
    public void visit(Multiplication op) {
        if (op.getLhs().getOutputBitsize() < op.getRhs().getOutputBitsize()) {
            Operation lhsOp = op.getLhs();
            Operation rhsOp = op.getRhs();
            op.removeLHS();
            op.removeRHS();
            op.setLHS(rhsOp);
            op.setRHS(lhsOp);
        }
        int lhsBits = op.getLhs().getOutputBitsize();
        int rhsBits = op.getRhs().getOutputBitsize();
        int resultBits = op.getOutputBitsize();
        Type type = op.getType();
        if (type instanceof datapath.graph.type.Integer) {
            binaryOp("mul", op);
        } else if (type instanceof datapath.graph.type.FixedPoint) {
            if (mul_pipe && lhsBits <= 64 && rhsBits <= 64 && resultBits <= 128) {
                if (mul_pipe_create) {
                    MulPipeCreator.Options options = new MulPipeCreator.Options();
                    options.WA = lhsBits;
                    options.WB = rhsBits;
                    options.WR = resultBits;
                    options.signed = op.isSigned();
                    options.stages = op.getDelay();
                    mulpipes.add(options);
                }
                binaryOp("mul_pipe", op);
            } else {
                binaryOp("mul", op);
            }
        } else if (type instanceof datapath.graph.type.Float) {
            binaryOp("mulfloat", op);
        } else {
            throw new UnsupportedOperationException("Not supported yet.");
        }
    }

    /** Emits {@code sub} for integer/fixed-point types and {@code subfloat} for floats. */
    @Override
    public void visit(Subtraction op) {
        Type type = op.getType();
        if (type instanceof datapath.graph.type.Integer
                || type instanceof datapath.graph.type.FixedPoint) {
            binaryOp("sub", op);
        } else if (type instanceof datapath.graph.type.Float) {
            binaryOp("subfloat", op);
        } else {
            throw new UnsupportedOperationException("Not supported yet.");
        }
    }

    /**
     * Emits {@code divint} for integers, {@code bigdiv} for 64-bit fixed-point results,
     * {@code divint} for other fixed-point results, and {@code divfloat} for floats.
     */
    @Override
    public void visit(Divide op) {
        Type type = op.getType();
        if (type instanceof datapath.graph.type.Integer) {
            binaryOp("divint", op);
        } else if (type instanceof datapath.graph.type.FixedPoint) {
            binaryOp(op.getOutputBitsize() == 64 ? "bigdiv" : "divint", op);
        } else if (type instanceof datapath.graph.type.Float) {
            binaryOp("divfloat", op);
        } else {
            throw new UnsupportedOperationException("Not supported yet.");
        }
    }

    /** Emits {@code abs} for integer/fixed-point types; SIGN follows the operand's signedness. */
    @Override
    public void visit(Absolut op) {
        Type type = op.getType();
        if (!(type instanceof datapath.graph.type.Integer
                || type instanceof datapath.graph.type.FixedPoint)) {
            throw new UnsupportedOperationException("Not supported yet.");
        }
        Module m = new Module("abs", op.getNumber());
        m.addParameter(new WA(op.getData().getOutputBitsize()));
        m.addParameter(new WR(op.getOutputBitsize()));
        m.addParameter(new SIGN(op.getData().isSigned()));
        m.addParameter(getDepth(op));
        m.addIO(getWireIO(op.getData(), "A"));
        m.addIO(getWireIO(op, "R"));
        modules.add(m);
    }

    @Override
    public void visit(Sin op) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    public void visit(Cos op) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    @Override
    public void visit(ArcCos op) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Emits {@code sqrtint} for integer/fixed-point types; float square roots are unsupported. */
    @Override
    public void visit(SquareRoot op) {
        Type type = op.getType();
        if (type instanceof datapath.graph.type.Integer
                || type instanceof datapath.graph.type.FixedPoint) {
            unaryOp("sqrtint", op);
        } else {
            throw new UnsupportedOperationException("Not supported yet.");
        }
    }

    /** Emits a {@code bitsel} instance that changes the operand's bit width. */
    @Override
    public void visit(BitwidthTransmogrify op) {
        unaryOp("bitsel", op);
    }

    /** Predicates produce no module; they are consumed through their control wires. */
    @Override
    public void visit(Predicate op) {
        // do nothing
    }

    @Override
    public void visit(TypeConversion op) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /** Emits a generic {@code mul} instance. */
    @Override
    public void visit(ConstantMultiplication op) {
        binaryOp("mul", op);
    }
}