/* * Copyright (C) 2012 Jason Gedge <http://www.gedge.ca> * * This file is part of the OpGraph project. * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ package ca.gedge.opgraph.nodes.general; import static org.junit.Assert.*; import org.junit.Test; import ca.gedge.opgraph.InputField; import ca.gedge.opgraph.OpContext; import ca.gedge.opgraph.OpGraph; import ca.gedge.opgraph.OpNode; import ca.gedge.opgraph.OutputField; import ca.gedge.opgraph.Processor; import ca.gedge.opgraph.exceptions.ProcessingException; import ca.gedge.opgraph.nodes.general.ConstantValueNode; import ca.gedge.opgraph.nodes.general.MacroNode; import ca.gedge.opgraph.nodes.general.PassThroughNode; import ca.gedge.opgraph.nodes.logic.LogicalNotNode; /** * Tests {@link MacroNode}. */ public class TestMacroNode { static class AddNode extends OpNode { public final static InputField X_FIELD = new InputField("x", "", false, true, Double.class); public final static InputField Y_FIELD = new InputField("y", "", false, true, Double.class); public final static OutputField RESULT_FIELD = new OutputField("result", "", true, Double.class); public AddNode() { super("Add", "Computes x + y"); putField(X_FIELD); putField(Y_FIELD); putField(RESULT_FIELD); } @Override public void operate(OpContext context) { double x = (Double)context.get(X_FIELD); double y = (Double)context.get(Y_FIELD); context.put(RESULT_FIELD, x + y); } } static class MultiplyNode extends OpNode { public final static InputField X_FIELD = new InputField("x", "", false, true, Double.class); public final static InputField Y_FIELD = new InputField("y", "", true, true, Double.class); public final static OutputField RESULT_FIELD = new OutputField("result", "", true, Double.class); public MultiplyNode() { super("Multiply", "Computes x*y"); putField(X_FIELD); putField(Y_FIELD); putField(RESULT_FIELD); } @Override public void operate(OpContext context) { double x = (Double)context.get(X_FIELD); double y = 1.0; if(context.containsKey(Y_FIELD)) y = (Double)context.get(Y_FIELD); context.put(RESULT_FIELD, x * y); } } static class LessThanNode extends OpNode { public final static InputField X_FIELD = new InputField("x", "", false, true, Double.class); public final static InputField Y_FIELD = new InputField("y", "", false, true, Double.class); public final static OutputField RESULT_FIELD = new OutputField("result", "", true, Boolean.class); public LessThanNode() { super("Less Than", "Computes x < y"); putField(X_FIELD); putField(Y_FIELD); putField(RESULT_FIELD); } @Override public void operate(OpContext context) { double x = (Double)context.get(X_FIELD); double y = (Double)context.get(Y_FIELD); context.put(RESULT_FIELD, x < y); } } /** * Processes an operable graph. * * @param graph the graph to process * @param context the operating context, or <code>null</code> to use a default one * * @return the operating context used for processing * * @throws ProcessingException if any errors occurred during processing */ public static OpContext process(OpGraph graph, OpContext context) throws ProcessingException { final Processor processor = new Processor(graph); processor.reset(context); processor.stepAll(); if(processor.getError() != null) throw processor.getError(); return processor.getContext(); } /** * Gets a result from the execution of a graph. * * @param cls the type of result * @param graph the graph to execute * @param context the operating context, or <code>null</code> to use a default one * @param resultNode the node containing the result * @param resultField the field containing the result * * @return the result * * @throws ProcessingException if any errors occurred during processing */ public static <T> T getResult(Class<T> cls, OpGraph graph, OpContext context, OpNode resultNode, OutputField resultField) throws ProcessingException { return cls.cast(process(graph, context).findChildContext(resultNode).get(resultField)); } /** * Constructs an operable graph that computes the minimum of two values. * * @param inputs an array for returning the two nodes for inputs * @param outputs an array for returning the single node containing the output * * @return the graph */ private static OpGraph createMinDAG(PassThroughNode[] inputs, PassThroughNode[] outputs) { // // Constructs a dag that computes the minimum of two values, making // use of the ENABLED_FIELD feature of OpNode // final OpGraph minDAG = new OpGraph(); minDAG.setId("min"); final PassThroughNode pt1_1 = new PassThroughNode(); final PassThroughNode pt2_1 = new PassThroughNode(); final PassThroughNode pt1_2 = new PassThroughNode(); final PassThroughNode pt2_2 = new PassThroughNode(); final PassThroughNode ov1 = new PassThroughNode(); final LessThanNode lv1 = new LessThanNode(); final LogicalNotNode nv1 = new LogicalNotNode(); // Return values inputs[0] = pt1_1; inputs[1] = pt2_1; outputs[0] = ov1; // Add nodes minDAG.add(pt1_1); minDAG.add(pt1_2); minDAG.add(pt2_1); minDAG.add(pt2_2); minDAG.add(lv1); minDAG.add(nv1); minDAG.add(ov1); // Add link assertNotNull(minDAG.connect(pt1_1, PassThroughNode.OUTPUT, lv1, LessThanNode.X_FIELD)); assertNotNull(minDAG.connect(pt2_1, PassThroughNode.OUTPUT, lv1, LessThanNode.Y_FIELD)); assertNotNull(minDAG.connect(pt1_1, PassThroughNode.OUTPUT, pt1_2, PassThroughNode.INPUT)); assertNotNull(minDAG.connect(pt2_1, PassThroughNode.OUTPUT, pt2_2, PassThroughNode.INPUT)); assertNotNull(minDAG.connect(lv1, LessThanNode.RESULT_FIELD, pt1_2, OpNode.ENABLED_FIELD)); assertNotNull(minDAG.connect(lv1, LessThanNode.RESULT_FIELD, nv1, LogicalNotNode.X_INPUT_FIELD)); assertNotNull(minDAG.connect(nv1, LogicalNotNode.RESULT_OUTPUT_FIELD, pt2_2, OpNode.ENABLED_FIELD)); assertNotNull(minDAG.connect(pt1_2, PassThroughNode.OUTPUT, ov1, PassThroughNode.INPUT)); assertNotNull(minDAG.connect(pt2_2, PassThroughNode.OUTPUT, ov1, PassThroughNode.INPUT)); return minDAG; } /** Tests the correctness of a macro */ @Test public void testMacro() { PassThroughNode [] inputs1 = new PassThroughNode[2]; PassThroughNode [] outputs1 = new PassThroughNode[1]; PassThroughNode [] inputs2 = new PassThroughNode[2]; PassThroughNode [] outputs2 = new PassThroughNode[1]; OpGraph dag = new OpGraph(); OpGraph minDAG1 = createMinDAG(inputs1, outputs1); OpGraph minDAG2 = createMinDAG(inputs2, outputs2); ConstantValueNode cv1 = new ConstantValueNode(1.0); ConstantValueNode cv2 = new ConstantValueNode(2.0); ConstantValueNode cv3 = new ConstantValueNode(3.0); MacroNode min1 = new MacroNode(minDAG1); MacroNode min2 = new MacroNode(minDAG2); dag.add(cv1); dag.add(cv2); dag.add(cv3); dag.add(min1); dag.add(min2); // Publish inputs/outputs from macros InputField min1_in1 = min1.publish("x", inputs1[0], PassThroughNode.INPUT); InputField min1_in2 = min1.publish("y", inputs1[1], PassThroughNode.INPUT); OutputField min1_out1 = min1.publish("result", outputs1[0], PassThroughNode.OUTPUT); InputField min2_in1 = min2.publish("x", inputs2[0], PassThroughNode.INPUT); InputField min2_in2 = min2.publish("y", inputs2[1], PassThroughNode.INPUT); OutputField min2_out1 = min2.publish("result", outputs2[0], PassThroughNode.OUTPUT); try { assertNotNull(dag.connect(cv1, cv1.VALUE_OUTPUT_FIELD, min1, min1_in1)); assertNotNull(dag.connect(cv2, cv2.VALUE_OUTPUT_FIELD, min1, min1_in2)); assertNotNull(dag.connect(min1, min1_out1, min2, min2_in1)); assertNotNull(dag.connect(cv3, cv3.VALUE_OUTPUT_FIELD, min2, min2_in2)); for(int i = 0; i < 5; ++i) { for(int j = 0; j < 5; ++j) { for(int k = 0; k < 5; ++k) { double minVal = Math.min(Math.min(i, j), k); cv1.setValue(1.0*i); cv2.setValue(1.0*j); cv3.setValue(1.0*k); double result = getResult(Double.class, dag, null, min2, min2_out1); assertEquals(minVal, result, 1e-10); } } } } catch(ProcessingException exc) { if(exc.getCause() != null) exc.getCause().printStackTrace(); else exc.printStackTrace(); fail("Should be no errors when processing"); } } }