/*
* Copyright 2015 S. Webber
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.oakgp.node.walk;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.oakgp.TestUtils.assertNodeEquals;
import static org.oakgp.TestUtils.createVariable;
import static org.oakgp.TestUtils.integerConstant;
import static org.oakgp.TestUtils.readNode;
import static org.oakgp.function.math.IntegerUtils.INTEGER_UTILS;
import static org.oakgp.node.NodeType.isConstant;
import static org.oakgp.node.NodeType.isFunction;
import static org.oakgp.node.NodeType.isVariable;
import java.util.function.Function;
import java.util.function.Predicate;
import org.junit.Test;
import org.oakgp.node.ConstantNode;
import org.oakgp.node.FunctionNode;
import org.oakgp.node.Node;
import org.oakgp.node.VariableNode;
/**
 * Tests for {@link NodeWalk}'s {@code getAt}, {@code replaceAt} and {@code replaceAll} operations
 * applied to variable, constant and function nodes.
 */
public class NodeWalkTest {
    @Test
    public void testReplaceAt_VariableNode() {
        final VariableNode v = createVariable(0);
        final ConstantNode c = integerConstant(Integer.MAX_VALUE);
        // a leaf node only has a single position (index 0) - the node itself
        assertSame(v, NodeWalk.replaceAt(v, 0, t -> t));
        assertSame(c, NodeWalk.replaceAt(v, 0, t -> c));
    }

    @Test
    public void testReplaceAll_VariableNode() {
        final VariableNode v = createVariable(0);
        final ConstantNode c = integerConstant(Integer.MAX_VALUE);
        final Function<Node, Node> replacement = n -> c;
        assertSame(c, NodeWalk.replaceAll(v, n -> n == v, replacement));
        // criteria never matches, so the original instance should be returned untouched
        assertSame(v, NodeWalk.replaceAll(v, n -> n == c, replacement));
    }

    @Test
    public void testGet_VariableNode() {
        final VariableNode v = createVariable(0);
        assertSame(v, NodeWalk.getAt(v, 0));
    }

    @Test
    public void testReplaceAt_ConstantNode() {
        final ConstantNode n1 = integerConstant(9);
        final ConstantNode n2 = integerConstant(5);
        // identity checks, consistent with testReplaceAt_VariableNode
        assertSame(n1, NodeWalk.replaceAt(n1, 0, t -> t));
        assertSame(n2, NodeWalk.replaceAt(n1, 0, t -> n2));
    }

    @Test
    public void testReplaceAll_ConstantNode() {
        final ConstantNode n1 = integerConstant(9);
        final ConstantNode n2 = integerConstant(5);
        final Function<Node, Node> replacement = n -> n2;
        assertSame(n2, NodeWalk.replaceAll(n1, n -> n == n1, replacement));
        assertSame(n1, NodeWalk.replaceAll(n1, n -> n == n2, replacement));
    }

    @Test
    public void testGet_ConstantNode() {
        final ConstantNode c = integerConstant(9);
        assertSame(c, NodeWalk.getAt(c, 0));
    }

    @Test
    public void testReplaceAt_FunctionNode() {
        final FunctionNode n = createFunctionNode();
        final Function<Node, Node> replacement = t -> integerConstant(9);
        // indexes follow the post-order numbering documented on createFunctionNode
        assertEquals("(+ (* 9 v1) (+ v2 1))", NodeWalk.replaceAt(n, 0, replacement).toString());
        assertEquals("(+ (* v0 9) (+ v2 1))", NodeWalk.replaceAt(n, 1, replacement).toString());
        assertEquals("(+ 9 (+ v2 1))", NodeWalk.replaceAt(n, 2, replacement).toString());
        assertEquals("(+ (* v0 v1) (+ 9 1))", NodeWalk.replaceAt(n, 3, replacement).toString());
        assertEquals("(+ (* v0 v1) (+ v2 9))", NodeWalk.replaceAt(n, 4, replacement).toString());
        assertEquals("(+ (* v0 v1) 9)", NodeWalk.replaceAt(n, 5, replacement).toString());
        assertEquals("9", NodeWalk.replaceAt(n, 6, replacement).toString());
    }

    @Test
    public void testReplaceAll_FunctionNode() {
        final Node input = readNode("(- (- (* -1 v3) 0) (- 13 v1))");
        // named so it does not shadow the statically imported integerConstant(...) factory
        final ConstantNode replacementConstant = integerConstant(42);
        final Function<Node, Node> replacement = n -> replacementConstant;

        // nothing matches: the original tree instance is returned
        assertSame(input, NodeWalk.replaceAll(input, n -> false, replacement));
        // everything matches: the root itself is replaced
        assertSame(replacementConstant, NodeWalk.replaceAll(input, n -> true, replacement));
        assertNodeEquals("(- (- (* -1 42) 0) (- 13 42))", NodeWalk.replaceAll(input, n -> isVariable(n), replacement));
        assertNodeEquals("(- (- (* 42 v3) 42) (- 42 v1))", NodeWalk.replaceAll(input, n -> isConstant(n), replacement));

        // swap every subtraction for an addition, keeping the original arguments
        final Predicate<Node> criteria = n -> isFunction(n) && ((FunctionNode) n).getFunction() == INTEGER_UTILS.getSubtract();
        final Function<Node, Node> subtractToAdd = n -> new FunctionNode(INTEGER_UTILS.getAdd(), ((FunctionNode) n).getArguments());
        assertNodeEquals("(+ (+ (* -1 v3) 0) (+ 13 v1))", NodeWalk.replaceAll(input, criteria, subtractToAdd));
    }

    @Test
    public void testGetAt_FunctionNode() {
        final FunctionNode n = createFunctionNode();
        assertEquals("v0", NodeWalk.getAt(n, 0).toString());
        assertEquals("v1", NodeWalk.getAt(n, 1).toString());
        assertEquals("(* v0 v1)", NodeWalk.getAt(n, 2).toString());
        assertEquals("v2", NodeWalk.getAt(n, 3).toString());
        assertEquals("1", NodeWalk.getAt(n, 4).toString());
        assertEquals("(+ v2 1)", NodeWalk.getAt(n, 5).toString());
        assertEquals("(+ (* v0 v1) (+ v2 1))", NodeWalk.getAt(n, 6).toString());
    }

    /**
     * Returns a new tree representing {@code (+ (* v0 v1) (+ v2 1))}, i.e. {@code (v0 * v1) + (v2 + 1)}.
     * <p>
     * As asserted by {@link #testGetAt_FunctionNode()}, its nodes are indexed in post-order (children
     * left to right, then the parent): {@code 0=v0, 1=v1, 2=(* v0 v1), 3=v2, 4=1, 5=(+ v2 1), 6=root}.
     */
    private FunctionNode createFunctionNode() {
        final FunctionNode product = new FunctionNode(INTEGER_UTILS.getMultiply(), createVariable(0), createVariable(1));
        final FunctionNode sum = new FunctionNode(INTEGER_UTILS.getAdd(), createVariable(2), integerConstant(1));
        return new FunctionNode(INTEGER_UTILS.getAdd(), product, sum);
    }
}