/*
* Copyright 2015 S. Webber
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.oakgp.function.math;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.oakgp.TestUtils.assertNodeEquals;
import static org.oakgp.TestUtils.readFunctionNode;
import static org.oakgp.TestUtils.readNode;
import java.util.Optional;
import org.junit.Test;
import org.oakgp.Arguments;
import org.oakgp.Assignments;
import org.oakgp.node.FunctionNode;
import org.oakgp.node.Node;
/**
 * Tests {@link ArithmeticExpressionSimplifier} using the instance returned by
 * {@code IntegerUtils.INTEGER_UTILS.getSimplifier()}.
 * <p>
 * Expressions are written in prefix notation (e.g. {@code (+ 1 v0)}) and parsed with the
 * {@code TestUtils} helpers.
 */
public class ArithmeticExpressionSimplifierTest {
    private static final ArithmeticExpressionSimplifier SIMPLIFIER = IntegerUtils.INTEGER_UTILS.getSimplifier();

    /**
     * Exercises {@code combineWithChildNodes} for pairs of nodes that can be merged, followed by pairs
     * for which the simplifier is expected to return {@code null}.
     */
    @Test
    public void testCombineWithChildNodes() {
        // constants
        assertCombineWithChildNodes("3", "7", true, "10");
        assertCombineWithChildNodes("3", "7", false, "-4");

        // adding constant to a function
        assertCombineWithChildNodes("(+ 1 v0)", "7", true, "(+ 8 v0)");
        assertCombineWithChildNodes("(+ 1 v0)", "7", false, "(+ -6 v0)");
        assertCombineWithChildNodes("(+ 1 (- (- v0 9) 8))", "7", true, "(+ 8 (- (- v0 9) 8))");
        assertCombineWithChildNodes("(- 1 v0)", "7", true, "(- 8 v0)");
        assertCombineWithChildNodes("(- 1 v0)", "7", false, "(- -6 v0)");
        assertCombineWithChildNodes("(- 1 (- (- v0 9) 8))", "7", true, "(- 8 (- (- v0 9) 8))");

        // adding variable to function
        assertCombineWithChildNodes("(+ 1 (- v0 9))", "v0", true, "(+ 1 (- (* 2 v0) 9))");
        assertCombineWithChildNodes("(+ 1 (- v0 9))", "v0", false, "(+ 1 (- 0 9))");

        // multiplication of variable
        assertCombineWithChildNodes("(* 3 v0)", "v0", true, "(* 4 v0)");
        assertCombineWithChildNodes("(* 3 v0)", "v0", false, "(* 2 v0)");
        assertCombineWithChildNodes("(* -3 v0)", "v0", true, "(* -2 v0)");
        assertCombineWithChildNodes("(* -3 v0)", "v0", false, "(* -4 v0)");

        // combination of multiplication of the same variable
        assertCombineWithChildNodes("(* 3 v0)", "(* 7 v0)", true, "(* 10 v0)");
        assertCombineWithChildNodes("(* 3 v0)", "(* -7 v0)", true, "(* -4 v0)");
        assertCombineWithChildNodes("(* -3 v0)", "(* 7 v0)", true, "(* 4 v0)");
        assertCombineWithChildNodes("(* -3 v0)", "(* -7 v0)", true, "(* -10 v0)");
        assertCombineWithChildNodes("(* 3 v0)", "(* 7 v0)", false, "(* -4 v0)");
        assertCombineWithChildNodes("(* 3 v0)", "(* -7 v0)", false, "(* 10 v0)");
        assertCombineWithChildNodes("(* -3 v0)", "(* 7 v0)", false, "(* -10 v0)");
        assertCombineWithChildNodes("(* -3 v0)", "(* -7 v0)", false, "(* 4 v0)");

        // adding to a sub-node of a function
        assertCombineWithChildNodes("(+ 1 (- v0 9))", "v0", true, "(+ 1 (- (* 2 v0) 9))");
        assertCombineWithChildNodes("(+ 1 (- v0 9))", "v0", false, "(+ 1 (- 0 9))");
        assertCombineWithChildNodes("(+ 1 (* 2 v0))", "v0", true, "(+ 1 (* 3 v0))");
        assertCombineWithChildNodes("(+ 1 (* 2 v0))", "v0", false, "(+ 1 (* 1 v0))");
        assertCombineWithChildNodes("(+ 1 (* 2 v0))", "(* 3 v0)", true, "(+ 1 (* 5 v0))");
        assertCombineWithChildNodes("(+ 1 (* 2 v0))", "(* 3 v0)", false, "(+ 1 (* -1 v0))");
        assertCombineWithChildNodes("(+ 1 (- v0 9))", "(- v0 9)", true, "(+ 1 (* 2 (- v0 9)))");
        assertCombineWithChildNodes("(+ 1 (- 8 (- v0 9)))", "(- v0 9)", true, "(+ 1 (- 8 0))");
        assertCombineWithChildNodes("(+ 1 (- (- v0 9) 8))", "(- v0 9)", true, "(+ 1 (- (* 2 (- v0 9)) 8))");

        // pairs that the simplifier is expected to refuse to combine
        assertCannotCombineWithChildNodes("(- v0 9)", "(+ 1 (- v0 9))");
        assertCannotCombineWithChildNodes("(* 3 v0)", "v1");
        assertCannotCombineWithChildNodes("(* v0 v1)", "7");
    }

    /**
     * Asserts that combining {@code second} into {@code first} produces {@code expected}.
     *
     * @param first textual form of the node that is combined into
     * @param second textual form of the node to combine with {@code first}
     * @param isPos flag passed straight through to {@code combineWithChildNodes}; judging by the
     *        expectations in this class, {@code true} combines by addition and {@code false} by subtraction
     * @param expected textual form of the node the simplifier should return
     */
    private void assertCombineWithChildNodes(String first, String second, boolean isPos, String expected) {
        Node result = SIMPLIFIER.combineWithChildNodes(readNode(first), readNode(second), isPos);
        assertNodeEquals(expected, result);
    }

    /**
     * Asserts that {@code combineWithChildNodes} returns {@code null} for the given pair of nodes,
     * both when {@code isPos} is {@code true} and when it is {@code false}.
     *
     * @param first textual form of the first node
     * @param second textual form of the second node
     */
    private void assertCannotCombineWithChildNodes(String first, String second) {
        assertNull(SIMPLIFIER.combineWithChildNodes(readNode(first), readNode(second), true));
        assertNull(SIMPLIFIER.combineWithChildNodes(readNode(first), readNode(second), false));
    }

    /**
     * Exercises {@code simplify} with expressions that should be left unchanged as well as
     * expressions that should be rewritten.
     */
    @Test
    public void testSimplify() {
        assertSimplify("(+ 1 1)", "(+ 1 1)");
        assertSimplify("(- 1 1)", "(- 1 1)");

        assertAdditionSimplification("v0", "(+ 1 v0)", "(+ 1 (* 2 v0))");
        assertAdditionSimplification("v0", "(+ v1 (+ v1 (+ v0 9)))", "(+ v1 (+ v1 (+ (* 2 v0) 9)))");
        assertAdditionSimplification("v1", "(+ v1 (+ v1 (+ v0 9)))", "(+ (* 2 v1) (+ v1 (+ v0 9)))");
        assertAdditionSimplification("v0", "(* 1 v0)", "(* 2 v0)");

        assertSimplify("(- 1 1)", "(- 1 1)");
        assertSimplify("(+ v0 (- 1 v0))", "(- 1 0)");
        assertSimplify("(- v0 (- v1 (- v0 9)))", "(- 0 (- v1 (- (* 2 v0) 9)))");
        assertSimplify("(- v0 (- v1 (- v1 (- v0 9))))", "(- 0 (- v1 (- v1 (- 0 9))))");

        assertAdditionSimplification("9", "(+ v0 3)", "(+ v0 12)");
        assertAdditionSimplification("9", "(- v0 3)", "(- v0 -6)");

        assertSimplify("(- 4 (- v1 (- v0 9)))", "(- 0 (- v1 (- v0 5)))");
        assertSimplify("(- 4 (- v1 (+ v0 9)))", "(- 0 (- v1 (+ v0 13)))");
        assertSimplify("(- (+ 4 v0) 3)", "(+ 1 v0)");
        assertSimplify("(- (- v0 1) v1)", "(- (- v0 1) v1)");
        assertSimplify("(- (- v0 1) (- v0 1))", "(- (- 0 0) (- 1 1))");
        assertSimplify("(- (+ v0 1) (+ v0 1))", "(- (+ 0 0) (+ -1 1))");
        assertSimplify("(+ (- v0 1) (- v0 1))", "(+ (- 0 0) (- (* 2 v0) 2))");
        assertSimplify("(+ (+ v0 1) (+ v0 1))", "(+ (+ 0 0) (+ (* 2 v0) 2))");
        assertSimplify("(- (+ v0 1) (- v0 1))", "(- (+ 0 0) (- -1 1))");
    }

    /**
     * Builds the expression {@code (+ firstArg secondArg)} and delegates to
     * {@link #assertSimplify(String, String)}.
     *
     * @param firstArg textual form of the first argument of the addition
     * @param secondArg textual form of the second argument of the addition
     * @param expectedOutput textual form of the expected simplified expression
     */
    private void assertAdditionSimplification(String firstArg, String secondArg, String expectedOutput) {
        assertSimplify("(+ " + firstArg + " " + secondArg + ")", expectedOutput);
    }

    /**
     * Asserts that simplifying {@code input} yields {@code expectedOutput}. When the simplifier returns
     * {@code null} the input itself is used as the result, so an expression that is expected to be left
     * alone is written with identical input and expected output.
     * <p>
     * If the result differs from the input, both are additionally evaluated against a fixed set of
     * value pairs (presumably the values of {@code v0} and {@code v1}) to confirm that the rewrite did
     * not change the value of the expression.
     *
     * @param input textual form of the function node to simplify
     * @param expectedOutput textual form of the expected result
     */
    private void assertSimplify(String input, String expectedOutput) {
        FunctionNode in = readFunctionNode(input);
        Arguments args = in.getArguments();
        Node simplifiedVersion = simplify(in, args).orElse(in);
        assertNodeEquals(expectedOutput, simplifiedVersion);

        if (!simplifiedVersion.equals(in)) {
            int[][] assignedValues = { { 0, 0 }, { 1, 21 }, { 2, 14 }, { 3, -6 }, { 7, 3 }, { -1, 9 }, { -7, 0 } };
            for (int[] assignedValue : assignedValues) {
                Assignments assignments = Assignments.createAssignments(assignedValue[0], assignedValue[1]);
                // Held as Object so that the (String, Object, Object) overload of assertEquals is the
                // one selected, independent of the declared return type of evaluate.
                Object expected = in.evaluate(assignments);
                Object actual = simplifiedVersion.evaluate(assignments);
                // Report a mismatch as a test failure that names both expressions and the inputs used.
                assertEquals(input + " and " + expectedOutput + " evaluate differently for assignments ["
                        + assignedValue[0] + ", " + assignedValue[1] + "]", expected, actual);
            }
        }
    }

    /**
     * Invokes the simplifier with the function and first two arguments of {@code in}.
     *
     * @param in the function node being simplified
     * @param args the arguments of {@code in}
     * @return the simplified node, or an empty {@code Optional} if the simplifier returned {@code null}
     */
    private Optional<Node> simplify(FunctionNode in, Arguments args) {
        return Optional.ofNullable(SIMPLIFIER.simplify(in.getFunction(), args.firstArg(), args.secondArg()));
    }

    /** Two expressions that both evaluate to 21 must pass the check without an exception. */
    @Test
    public void testEvaluateToSameResultSuccess() {
        Node a = readNode("(* 7 (+ 1 2))");
        Node b = readNode("(+ 9 12)");
        ArithmeticExpressionSimplifier.assertEvaluateToSameResult(a, b);
    }

    /**
     * Two expressions that evaluate to different values must cause an
     * {@code IllegalArgumentException} whose message lists each expression with its result.
     */
    @Test
    public void testEvaluateToSameResultFailure() {
        Node a = readNode("(* 7 (- 1 2))");
        Node b = readNode("(+ 9 12)");
        try {
            ArithmeticExpressionSimplifier.assertEvaluateToSameResult(a, b);
            fail("Expected IllegalArgumentException as the expressions evaluate to different results");
        } catch (IllegalArgumentException e) {
            assertEquals("(* 7 (- 1 2)) = -7 (+ 9 12) = 21", e.getMessage());
        }
    }
}