/* * Copyright 2015 S. Webber * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.oakgp.function.math; import static org.oakgp.node.NodeType.areFunctions; import static org.oakgp.node.NodeType.isConstant; import static org.oakgp.node.NodeType.isFunction; import static org.oakgp.util.NodeComparator.NODE_COMPARATOR; import org.oakgp.Arguments; import org.oakgp.Assignments; import org.oakgp.function.Function; import org.oakgp.node.ConstantNode; import org.oakgp.node.FunctionNode; import org.oakgp.node.Node; final class ArithmeticExpressionSimplifier { private static final boolean SANITY_CHECK = true; private final NumberUtils<?> numberUtils; ArithmeticExpressionSimplifier(NumberUtils<?> numberUtils) { this.numberUtils = numberUtils; } /** @return {@code null} if it was not possible to simplify the expression. */ Node simplify(Function function, Node firstArg, Node secondArg) { sanityCheck(() -> { assertAddOrSubtract(function); assertArgumentsOrdered(function, firstArg, secondArg); }); return getSimplifiedVersion(function, firstArg, secondArg); } private Node getSimplifiedVersion(Function function, Node firstArg, Node secondArg) { boolean isPos = numberUtils.isAdd(function); if (areFunctions(firstArg, secondArg)) { NodePair p = removeFromChildNodes(firstArg, secondArg, isPos); if (p != null) { return new FunctionNode(function, p.nodeThatHasBeenReduced, p.nodeThatHasBeenExpanded); } p = removeFromChildNodes(secondArg, firstArg, isPos); if (p != null) { return new FunctionNode(function, p.nodeThatHasBeenExpanded, p.nodeThatHasBeenReduced); } } else if (isFunction(firstArg)) { return combineWithChildNodes(firstArg, secondArg, isPos); } else if (isFunction(secondArg)) { // 3, (+ (* 12 v2) 30) -> (+ (* 12 v2) 33) Node tmp = combineWithChildNodes(secondArg, firstArg, isPos); if (tmp != null && numberUtils.isSubtract(function)) { // 3, (- (* 12 v2) 30) -> (- (* 12 v2) 33) -> (0 - (- (* 12 v2) 33)) return new FunctionNode(function, numberUtils.zero(), tmp); } else { return tmp; } } return null; } /** * Returns the result of removing the second argument from the first argument. * * @param nodeToWalk * tree structure to walk and remove the node from * @param nodeToRemove * the node to remove from {@code nodeToWalk} * @param isPos * {@code true} to indicate that {@code nodeToRemove} should be removed from {@code nodeToWalk}, else {@code false} to indicate that * {@code nodeToAdd} should be added to {@code nodeToWalk} * @return {@code null} if it was not possible to remove (@code nodeToRemove} from {@code nodeToWalk} */ private NodePair removeFromChildNodes(final Node nodeToWalk, final Node nodeToRemove, final boolean isPos) { if (numberUtils.isArithmeticExpression(nodeToWalk)) { FunctionNode fn = (FunctionNode) nodeToWalk; Function f = fn.getFunction(); Node firstArg = fn.getArguments().firstArg(); Node secondArg = fn.getArguments().secondArg(); if (numberUtils.isMultiply(f) && isFunction(nodeToRemove)) { FunctionNode x = (FunctionNode) nodeToRemove; Arguments a = x.getArguments(); if (numberUtils.isMultiply(x) && isConstant(firstArg) && isConstant(a.firstArg()) && secondArg.equals(a.secondArg())) { ConstantNode result; if (isPos) { result = numberUtils.add(a.firstArg(), firstArg); } else { result = numberUtils.subtract(a.firstArg(), firstArg); } Node tmp = new FunctionNode(f, result, secondArg); return new NodePair(numberUtils.zero(), tmp); } Node tmp = combineWithChildNodes(nodeToRemove, nodeToWalk, isPos); if (tmp != null) { return new NodePair(numberUtils.zero(), tmp); } } boolean isSubtract = numberUtils.isSubtract(f); if (numberUtils.isAdd(f) || isSubtract) { NodePair p = removeFromChildNodes(firstArg, nodeToRemove, isPos); if (p != null) { NodePair p2 = removeFromChildNodes(secondArg, p.nodeThatHasBeenExpanded, isSubtract ? !isPos : isPos); if (p2 == null) { return new NodePair(new FunctionNode(f, p.nodeThatHasBeenReduced, secondArg), p.nodeThatHasBeenExpanded); } else { return new NodePair(new FunctionNode(f, p.nodeThatHasBeenReduced, p2.nodeThatHasBeenReduced), p2.nodeThatHasBeenExpanded); } } p = removeFromChildNodes(secondArg, nodeToRemove, isSubtract ? !isPos : isPos); if (p != null) { return new NodePair(new FunctionNode(f, firstArg, p.nodeThatHasBeenReduced), p.nodeThatHasBeenExpanded); } } } else if (!numberUtils.isZero(nodeToWalk)) { Node tmp = combineWithChildNodes(nodeToRemove, nodeToWalk, isPos); if (tmp != null) { return new NodePair(numberUtils.zero(), tmp); } } return null; } /** * Returns the result of merging the second argument into the first argument. * * @param nodeToWalk * tree structure to walk and remove the node to * @param nodeToAdd * the node to remove from {@code nodeToWalk} * @param isPos * {@code true} to indicate that {@code nodeToAdd} should be added to {@code nodeToWalk}, else {@code false} to indicate that {@code nodeToAdd} * should be subtracted from {@code nodeToWalk} * @return {@code null} if it was not possible to merge (@code nodeToAdd} into {@code nodeToWalk} */ Node combineWithChildNodes(final Node nodeToWalk, final Node nodeToAdd, final boolean isPos) { if (isSuitableForCombining(nodeToWalk, nodeToAdd)) { return combine(nodeToWalk, nodeToAdd, isPos); } if (!numberUtils.isArithmeticExpression(nodeToWalk)) { return null; } FunctionNode currentFunctionNode = (FunctionNode) nodeToWalk; Node firstArg = currentFunctionNode.getArguments().firstArg(); Node secondArg = currentFunctionNode.getArguments().secondArg(); Function currentFunction = currentFunctionNode.getFunction(); boolean isAdd = numberUtils.isAdd(currentFunction); boolean isSubtract = numberUtils.isSubtract(currentFunction); if (isAdd || isSubtract) { boolean recursiveIsPos = isPos; if (isSubtract) { recursiveIsPos = !isPos; } if (isSuitableForCombining(firstArg, nodeToAdd)) { return new FunctionNode(currentFunction, combine(firstArg, nodeToAdd, isPos), secondArg); } else if (isSuitableForCombining(secondArg, nodeToAdd)) { return new FunctionNode(currentFunction, firstArg, combine(secondArg, nodeToAdd, recursiveIsPos)); } Node tmp = combineWithChildNodes(firstArg, nodeToAdd, isPos); if (tmp != null) { return new FunctionNode(currentFunction, tmp, secondArg); } tmp = combineWithChildNodes(secondArg, nodeToAdd, recursiveIsPos); if (tmp != null) { return new FunctionNode(currentFunction, firstArg, tmp); } } else if (numberUtils.isMultiply(currentFunction) && isConstant(firstArg) && secondArg.equals(nodeToAdd)) { ConstantNode multiplier; if (isPos) { multiplier = numberUtils.increment(firstArg); } else { multiplier = numberUtils.decrement(firstArg); } return new FunctionNode(currentFunction, multiplier, nodeToAdd); } else if (isMultiplyingTheSameValue(nodeToWalk, nodeToAdd)) { return combineMultipliers(nodeToWalk, nodeToAdd, isPos); } return null; } /** * Returns {@code true} if the specified nodes can be combined into a single node. * <p> * Two constants (even if they have different values) can be combined. e.g. {@code 9} and {@code 12} can be combined to form {@code 21} * </p> * <p> * Any two node that are {@code equal} can be combined. e.g. {@code v0} and {@code v0} can be combined to form {@code (* 2 v0)}, {@code (- 8 v0)} and * {@code (- 8 v0)} can be combined to form {@code (* 2 (- 8 v0))} * </p> */ private static boolean isSuitableForCombining(Node currentNode, Node nodeToReplace) { if (isConstant(nodeToReplace)) { return isConstant(currentNode); } else { return nodeToReplace.equals(currentNode); } } /** * Returns a node that is the result of combining the two specified nodes. * <p> * e.g. {@code 9} and {@code 12} can be combined to form {@code 21}, {@code v0} and {@code v0} can be combined to form {@code (* 2 v0)} * </p> * * @param isPos * {@code true} to indicate that {@code second} should be added to {@code first}, else {@code false} to indicate that {@code second} should be * subtracted from {@code first} */ private Node combine(Node first, Node second, boolean isPos) { sanityCheck(() -> assertSameClass(first, second)); if (isConstant(second)) { if (isPos) { return numberUtils.add(first, second); } else { return numberUtils.subtract(first, second); } } else { if (isPos) { return numberUtils.multiplyByTwo(second); } else { return numberUtils.zero(); } } } /** * Returns {@code true} if both of the specified nodes represent multiplication of the same value by a constant. * <p> * Examples of arguments that would return true: {@code (* 3 v0), (* 7 v0)} or {@code (* 1 v0), (* -8 v0)} * </p> * <p> * Examples of arguments that would return false: {@code (* 3 v0), (+ 7 v0)} or {@code (* 1 v0), (* -8 v1)} * </p> */ private boolean isMultiplyingTheSameValue(Node n1, Node n2) { if (areFunctions(n1, n2)) { FunctionNode f1 = (FunctionNode) n1; FunctionNode f2 = (FunctionNode) n2; if (numberUtils.isMultiply(f1) && numberUtils.isMultiply(f2) && isConstant(f1.getArguments().firstArg()) && isConstant(f2.getArguments().firstArg()) && f1.getArguments().secondArg().equals(f2.getArguments().secondArg())) { return true; } } return false; } /** e.g. arguments: {@code (* 3 v0), (* 7 v0)} would produce: {@code (* 10 v0)} */ private Node combineMultipliers(Node n1, Node n2, boolean isPos) { FunctionNode f1 = (FunctionNode) n1; FunctionNode f2 = (FunctionNode) n2; ConstantNode result; if (isPos) { result = numberUtils.add(f1.getArguments().firstArg(), f2.getArguments().firstArg()); } else { result = numberUtils.subtract(f1.getArguments().firstArg(), f2.getArguments().firstArg()); } return new FunctionNode(f1.getFunction(), result, f1.getArguments().secondArg()); } private static void sanityCheck(Runnable r) { // TODO remove this method - only here to sanity check input during development if (SANITY_CHECK) { r.run(); } } private void assertAddOrSubtract(Function f) { if (!numberUtils.isAddOrSubtract(f)) { throw new IllegalArgumentException(f.getClass().getName()); } } private void assertArgumentsOrdered(Function f, Node firstArg, Node secondArg) { if (!numberUtils.isSubtract(f) && NODE_COMPARATOR.compare(firstArg, secondArg) > 0) { throw new IllegalArgumentException("arg1 " + firstArg + " arg2 " + secondArg); } } private static void assertSameClass(Node currentNode, Node nodeToReplace) { if (nodeToReplace.getClass() != currentNode.getClass()) { throw new IllegalArgumentException(nodeToReplace.getClass().getName() + " " + currentNode.getClass().getName()); } } /** * Asserts that the specified nodes evaluate to the same results. * * @param first * the node to compare to {@code second} * @param second * the node to compare to {@code first} * @throws IllegalArgumentException * if the specified nodes evaluate to different results */ static void assertEvaluateToSameResult(Node first, Node second) { Object[] assignedValues = { 2, 14, 4, 9, 7 }; Assignments assignments = Assignments.createAssignments(assignedValues); Object firstResult = first.evaluate(assignments); Object secondResult = second.evaluate(assignments); if (!firstResult.equals(secondResult)) { throw new IllegalArgumentException(first + " = " + firstResult + " " + second + " = " + secondResult); } } private static class NodePair { private final Node nodeThatHasBeenReduced; private final Node nodeThatHasBeenExpanded; NodePair(Node nodeThatHasBeenReduced, Node nodeThatHasBeenExpanded) { this.nodeThatHasBeenReduced = nodeThatHasBeenReduced; this.nodeThatHasBeenExpanded = nodeThatHasBeenExpanded; } } }