/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.prg.train.rewrite;
import org.encog.Encog;
import org.encog.EncogError;
import org.encog.ml.ea.genome.Genome;
import org.encog.ml.ea.rules.RewriteRule;
import org.encog.ml.prg.EncogProgram;
import org.encog.ml.prg.ProgramNode;
import org.encog.ml.prg.expvalue.ExpressionValue;
import org.encog.ml.prg.extension.StandardExtensions;
/**
* This class is used to rewrite algebraic expressions into more simple forms.
* This is by no means a complete set of rewrite rules, and will likely be
* extended in the future.
*/
public class RewriteAlgebraic implements RewriteRule {
/**
* Has the expression been rewritten.
*/
private boolean rewritten;
/**
* Has this rewrite rule been validated? Are the required operators present?
*/
private boolean validated;
/**
* Validate that the required operators are present for this rewrite rule.
* @param g A genome.
*/
private void validate(final Genome g) {
if( !(g instanceof EncogProgram) ) {
throw new EncogError("RewriteAlgebraic can only be used with EncogProgram genomes.");
}
final EncogProgram prg = (EncogProgram) g;
if( !prg.getContext().getFunctions().isDefined("^",2) ) {
throw new EncogError("Must have power(^) operator to use RewriteAlgebraic");
}
if( !prg.getContext().getFunctions().isDefined("+",2) ) {
throw new EncogError("Must have addition(+) operator to use RewriteAlgebraic");
}
if( !prg.getContext().getFunctions().isDefined("-",1) ) {
throw new EncogError("Must have negative(-) operator to use RewriteAlgebraic");
}
if( !prg.getContext().getFunctions().isDefined("-",2) ) {
throw new EncogError("Must have subtraction(-) operator to use RewriteAlgebraic");
}
if( !prg.getContext().getFunctions().isDefined("*",2) ) {
throw new EncogError("Must have multiplication(*) operator to use RewriteAlgebraic");
}
if( !prg.getContext().getFunctions().isDefined("/",2) && !prg.getContext().getFunctions().isDefined("%",2) ) {
throw new EncogError("Must have division(/ or %) operator to use RewriteAlgebraic");
}
if( !prg.getContext().getFunctions().isDefined("#var",0) ) {
throw new EncogError("Must have variables(#var) operator to use RewriteAlgebraic");
}
}
/**
* Create an floating point numeric constant.
* @param prg The program to create the constant for.
* @param v The value that the constant represents.
* @return The newly created node.
*/
private ProgramNode createNumericConst(final EncogProgram prg,
final double v) {
final ProgramNode result = prg.getFunctions().factorProgramNode("#const",
prg, new ProgramNode[] {});
result.getData()[0] = new ExpressionValue(v);
return result;
}
/**
* Create an integer numeric constant.
* @param prg The program to create the constant for.
* @param v The value that the constant represents.
* @return The newly created node.
*/
private ProgramNode createNumericConst(final EncogProgram prg, final int v) {
final ProgramNode result = prg.getFunctions().factorProgramNode("#const",
prg, new ProgramNode[] {});
result.getData()[0] = new ExpressionValue(v);
return result;
}
/**
* Attempt to rewrite the specified node.
* @param parent The parent node to start from.
* @return The rewritten node, or the same node if no rewrite occurs.
*/
private ProgramNode internalRewrite(final ProgramNode parent) {
ProgramNode rewrittenParent = parent;
rewrittenParent = tryDoubleNegative(rewrittenParent);
rewrittenParent = tryMinusMinus(rewrittenParent);
rewrittenParent = tryPlusNeg(rewrittenParent);
rewrittenParent = tryVarOpVar(rewrittenParent);
rewrittenParent = tryPowerZero(rewrittenParent);
rewrittenParent = tryOnePower(rewrittenParent);
rewrittenParent = tryZeroPlus(rewrittenParent);
rewrittenParent = tryZeroDiv(rewrittenParent);
rewrittenParent = tryZeroMul(rewrittenParent);
rewrittenParent = tryMinusZero(rewrittenParent);
// try children
for (int i = 0; i < rewrittenParent.getChildNodes().size(); i++) {
final ProgramNode childNode = (ProgramNode) rewrittenParent
.getChildNodes().get(i);
final ProgramNode rewriteChild = internalRewrite(childNode);
if (childNode != rewriteChild) {
rewrittenParent.getChildNodes().remove(i);
rewrittenParent.getChildNodes().add(i, rewriteChild);
this.rewritten = true;
}
}
return rewrittenParent;
}
/**
* Determine if the specified node is constant.
* @param node The node to check.
* @param v The constant to compare against.
* @return True if the specified node matches the specified constant.
*/
private boolean isConstValue(final ProgramNode node, final double v) {
if (node.getTemplate() == StandardExtensions.EXTENSION_CONST_SUPPORT) {
if (Math.abs(node.getData()[0].toFloatValue() - v) < Encog.DEFAULT_DOUBLE_EQUAL) {
return true;
}
}
return false;
}
/**
* {@inheritDoc}
*/
@Override
public boolean rewrite(final Genome g) {
// Validate that the program has needed operators.
if( !this.validated ) {
validate(g);
this.validated = true;
}
// Attempt the rule.
this.rewritten = false;
final EncogProgram program = (EncogProgram) g;
final ProgramNode node = program.getRootNode();
final ProgramNode rewrittenRoot = internalRewrite(node);
if (rewrittenRoot != null) {
program.setRootNode(rewrittenRoot);
}
return this.rewritten;
}
/**
* Try to rewrite --x.
* @param parent The parent node to attempt to rewrite.
* @return The rewritten node, if it was rewritten.
*/
private ProgramNode tryDoubleNegative(final ProgramNode parent) {
if (parent.getName().equals("-")) {
final ProgramNode child = parent.getChildNode(0);
if (child.getName().equals("-")) {
final ProgramNode grandChild = child.getChildNode(0);
this.rewritten = true;
return grandChild;
}
}
return parent;
}
/**
* Try to rewrite --x.
* @param parent The parent node to attempt to rewrite.
* @return The rewritten node, if it was rewritten.
*/
private ProgramNode tryMinusMinus(ProgramNode parent) {
if (parent.getName().equals("-") && parent.getChildNodes().size() == 2) {
final ProgramNode child1 = parent.getChildNode(0);
final ProgramNode child2 = parent.getChildNode(1);
if (child2.getName().equals("#const")) {
final ExpressionValue v = child2.getData()[0];
if (v.isFloat()) {
final double v2 = v.toFloatValue();
if (v2 < 0) {
child2.getData()[0] = new ExpressionValue(-v2);
parent = parent
.getOwner()
.getContext()
.getFunctions()
.factorProgramNode("+", parent.getOwner(),
new ProgramNode[] { child1, child2 });
}
} else if (v.isInt()) {
final long v2 = v.toIntValue();
if (v2 < 0) {
child2.getData()[0] = new ExpressionValue(-v2);
parent = parent
.getOwner()
.getContext()
.getFunctions()
.factorProgramNode("+", parent.getOwner(),
new ProgramNode[] { child1, child2 });
}
}
}
}
return parent;
}
/**
* Try to rewrite x-0.
* @param parent The parent node to attempt to rewrite.
* @return The rewritten node, if it was rewritten.
*/
private ProgramNode tryMinusZero(final ProgramNode parent) {
if (parent.getTemplate() == StandardExtensions.EXTENSION_SUB) {
final ProgramNode child2 = parent.getChildNode(1);
if (isConstValue(child2, 0)) {
final ProgramNode child1 = parent.getChildNode(0);
return child1;
}
}
return parent;
}
/**
* Try to rewrite x^1.
* @param parent The parent node to attempt to rewrite.
* @return The rewritten node, if it was rewritten.
*/
private ProgramNode tryOnePower(final ProgramNode parent) {
if (parent.getTemplate() == StandardExtensions.EXTENSION_POWER
|| parent.getTemplate() == StandardExtensions.EXTENSION_POWFN) {
final ProgramNode child = parent.getChildNode(0);
if (child.getTemplate() == StandardExtensions.EXTENSION_CONST_SUPPORT) {
if (Math.abs(child.getData()[0].toFloatValue() - 1) < Encog.DEFAULT_DOUBLE_EQUAL) {
this.rewritten = true;
return createNumericConst(parent.getOwner(), 1);
}
}
}
return parent;
}
/**
* Try to rewrite x+-c.
* @param parent The parent node to attempt to rewrite.
* @return The rewritten node, if it was rewritten.
*/
private ProgramNode tryPlusNeg(ProgramNode parent) {
if (parent.getName().equals("+") && parent.getChildNodes().size() == 2) {
final ProgramNode child1 = parent.getChildNode(0);
final ProgramNode child2 = parent.getChildNode(1);
if (child2.getName().equals("-")
&& child2.getChildNodes().size() == 1) {
parent = parent
.getOwner()
.getContext()
.getFunctions()
.factorProgramNode(
"-",
parent.getOwner(),
new ProgramNode[] { child1,
child2.getChildNode(0) });
} else if (child2.getName().equals("#const")) {
final ExpressionValue v = child2.getData()[0];
if (v.isFloat()) {
final double v2 = v.toFloatValue();
if (v2 < 0) {
child2.getData()[0] = new ExpressionValue(-v2);
parent = parent
.getOwner()
.getContext()
.getFunctions()
.factorProgramNode("-", parent.getOwner(),
new ProgramNode[] { child1, child2 });
}
} else if (v.isInt()) {
final long v2 = v.toIntValue();
if (v2 < 0) {
child2.getData()[0] = new ExpressionValue(-v2);
parent = parent
.getOwner()
.getContext()
.getFunctions()
.factorProgramNode("-", parent.getOwner(),
new ProgramNode[] { child1, child2 });
}
}
}
}
return parent;
}
/**
* Try to rewrite x^0.
* @param parent The parent node to attempt to rewrite.
* @return The rewritten node, if it was rewritten.
*/
private ProgramNode tryPowerZero(final ProgramNode parent) {
if (parent.getTemplate() == StandardExtensions.EXTENSION_POWER
|| parent.getTemplate() == StandardExtensions.EXTENSION_POWFN) {
final ProgramNode child0 = parent.getChildNode(0);
final ProgramNode child1 = parent.getChildNode(1);
if (isConstValue(child1, 0)) {
return createNumericConst(parent.getOwner(), 1);
}
if (isConstValue(child0, 0)) {
return createNumericConst(parent.getOwner(), 0);
}
}
return parent;
}
/**
* Try to rewrite x+x, x-x, x*x, x/x.
* @param parent The parent node to attempt to rewrite.
* @return The rewritten node, if it was rewritten.
*/
private ProgramNode tryVarOpVar(ProgramNode parent) {
if (parent.getChildNodes().size() == 2
&& parent.getName().length() == 1
&& "+-*/".indexOf(parent.getName().charAt(0)) != -1) {
final ProgramNode child1 = parent.getChildNode(0);
final ProgramNode child2 = parent.getChildNode(1);
if (child1.getName().equals("#var")
&& child2.getName().equals("#var")) {
if (child1.getData()[0].toIntValue() == child2.getData()[0]
.toIntValue()) {
switch (parent.getName().charAt(0)) {
case '-':
parent = createNumericConst(parent.getOwner(), 0);
break;
case '+':
parent = parent
.getOwner()
.getFunctions()
.factorProgramNode(
"*",
parent.getOwner(),
new ProgramNode[] {
createNumericConst(
parent.getOwner(), 2),
child1 });
break;
case '*':
parent = parent
.getOwner()
.getFunctions()
.factorProgramNode(
"^",
parent.getOwner(),
new ProgramNode[] {
child1,
createNumericConst(
parent.getOwner(), 2) });
break;
case '/':
parent = createNumericConst(parent.getOwner(), 1);
break;
}
}
}
}
return parent;
}
/**
* Try to rewrite 0/x.
* @param parent The parent node to attempt to rewrite.
* @return The rewritten node, if it was rewritten.
*/
private ProgramNode tryZeroDiv(final ProgramNode parent) {
if (parent.getTemplate() == StandardExtensions.EXTENSION_DIV) {
final ProgramNode child1 = parent.getChildNode(0);
final ProgramNode child2 = parent.getChildNode(1);
if (!isConstValue(child2, 0)) {
if (isConstValue(child1, 0)) {
this.rewritten = true;
return this.createNumericConst(parent.getOwner(), 0);
}
}
}
return parent;
}
/**
* Try to rewrite 0*x.
* @param parent The parent node to attempt to rewrite.
* @return The rewritten node, if it was rewritten.
*/
private ProgramNode tryZeroMul(final ProgramNode parent) {
if (parent.getTemplate() == StandardExtensions.EXTENSION_MUL) {
final ProgramNode child1 = parent.getChildNode(0);
final ProgramNode child2 = parent.getChildNode(1);
if (isConstValue(child1, 0) || isConstValue(child2, 0)) {
this.rewritten = true;
return this.createNumericConst(parent.getOwner(), 0);
}
}
return parent;
}
/**
* Try to rewrite 0+x.
* @param parent The parent node to attempt to rewrite.
* @return The rewritten node, if it was rewritten.
*/
private ProgramNode tryZeroPlus(final ProgramNode parent) {
if (parent.getTemplate() == StandardExtensions.EXTENSION_ADD) {
final ProgramNode child1 = parent.getChildNode(0);
final ProgramNode child2 = parent.getChildNode(1);
if (isConstValue(child1, 0)) {
this.rewritten = true;
return child2;
}
if (isConstValue(child2, 0)) {
this.rewritten = true;
return child1;
}
}
return parent;
}
}