/* * Encog(tm) Core v3.4 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.app.generate.generators.js; import java.io.File; import org.encog.Encog; import org.encog.EncogError; import org.encog.app.generate.AnalystCodeGenerationError; import org.encog.app.generate.generators.AbstractGenerator; import org.encog.app.generate.program.EncogGenProgram; import org.encog.app.generate.program.EncogProgramNode; import org.encog.app.generate.program.EncogTreeNode; import org.encog.engine.network.activation.ActivationElliott; import org.encog.engine.network.activation.ActivationFunction; import org.encog.engine.network.activation.ActivationLinear; import org.encog.engine.network.activation.ActivationSigmoid; import org.encog.engine.network.activation.ActivationTANH; import org.encog.ml.MLFactory; import org.encog.ml.MLMethod; import org.encog.ml.data.MLData; import org.encog.ml.data.MLDataPair; import org.encog.ml.data.MLDataSet; import org.encog.neural.flat.FlatNetwork; import org.encog.neural.networks.ContainsFlat; import org.encog.persist.EncogDirectoryPersistence; import org.encog.util.csv.CSVFormat; import org.encog.util.csv.NumberList; import org.encog.util.simple.EncogUtility; public class GenerateEncogJavaScript extends AbstractGenerator { private boolean embed; private void embedNetwork(final EncogProgramNode node) { addBreak(); final File methodFile = (File) node.getArgs().get(0).getValue(); final MLMethod method = (MLMethod) EncogDirectoryPersistence .loadObject(methodFile); if (!(method instanceof MLFactory)) { throw new EncogError("Code generation not yet supported for: " + method.getClass().getName()); } final FlatNetwork flat = ((ContainsFlat) method).getFlat(); // header final StringBuilder line = new StringBuilder(); line.append("public static MLMethod "); line.append(node.getName()); line.append("() {"); indentLine(line.toString()); // create factory line.setLength(0); addLine("var network = ENCOG.BasicNetwork.create( null );"); addLine("network.inputCount = " + flat.getInputCount() + ";"); addLine("network.outputCount = " + flat.getOutputCount() + ";"); addLine("network.layerCounts = " + toSingleLineArray(flat.getLayerCounts()) + ";"); addLine("network.layerContextCount = " + toSingleLineArray(flat.getLayerContextCount()) + ";"); addLine("network.weightIndex = " + toSingleLineArray(flat.getWeightIndex()) + ";"); addLine("network.layerIndex = " + toSingleLineArray(flat.getLayerIndex()) + ";"); addLine("network.activationFunctions = " + toSingleLineArray(flat.getActivationFunctions()) + ";"); addLine("network.layerFeedCounts = " + toSingleLineArray(flat.getLayerFeedCounts()) + ";"); addLine("network.contextTargetOffset = " + toSingleLineArray(flat.getContextTargetOffset()) + ";"); addLine("network.contextTargetSize = " + toSingleLineArray(flat.getContextTargetSize()) + ";"); addLine("network.biasActivation = " + toSingleLineArray(flat.getBiasActivation()) + ";"); addLine("network.beginTraining = " + flat.getBeginTraining() + ";"); addLine("network.endTraining=" + flat.getEndTraining() + ";"); addLine("network.weights = WEIGHTS;"); addLine("network.layerOutput = " + toSingleLineArray(flat.getLayerOutput()) + ";"); addLine("network.layerSums = " + toSingleLineArray(flat.getLayerSums()) + ";"); // return addLine("return network;"); unIndentLine("}"); } private void embedTraining(final EncogProgramNode node) { final File dataFile = (File) node.getArgs().get(0).getValue(); final MLDataSet data = EncogUtility.loadEGB2Memory(dataFile); // generate the input data indentLine("var INPUT_DATA = ["); for (final MLDataPair pair : data) { final MLData item = pair.getInput(); final StringBuilder line = new StringBuilder(); NumberList.toList(CSVFormat.EG_FORMAT, line, item.getData()); line.insert(0, "[ "); line.append(" ],"); addLine(line.toString()); } unIndentLine("];"); addBreak(); // generate the ideal data indentLine("var IDEAL_DATA = ["); for (final MLDataPair pair : data) { final MLData item = pair.getIdeal(); final StringBuilder line = new StringBuilder(); NumberList.toList(CSVFormat.EG_FORMAT, line, item.getData()); line.insert(0, "[ "); line.append(" ],"); addLine(line.toString()); } unIndentLine("];"); } @Override public void generate(final EncogGenProgram program, final boolean shouldEmbed) { if (!shouldEmbed) { throw new AnalystCodeGenerationError( "Must embed when generating Javascript"); } this.embed = shouldEmbed; generateForChildren(program); } private void generateArrayInit(final EncogProgramNode node) { final StringBuilder line = new StringBuilder(); line.append("var "); line.append(node.getName()); line.append(" = ["); indentLine(line.toString()); final double[] a = (double[]) node.getArgs().get(0).getValue(); line.setLength(0); int lineCount = 0; for (int i = 0; i < a.length; i++) { line.append(CSVFormat.EG_FORMAT.format(a[i], Encog.DEFAULT_PRECISION)); if (i < (a.length - 1)) { line.append(","); } lineCount++; if (lineCount >= 10) { addLine(line.toString()); line.setLength(0); lineCount = 0; } } if (line.length() > 0) { addLine(line.toString()); line.setLength(0); } unIndentLine("];"); } private void generateClass(final EncogProgramNode node) { addBreak(); addLine("<!DOCTYPE html>"); addLine("<html>"); addLine("<head>"); addLine("<title>Encog Generated Javascript</title>"); addLine("</head>"); addLine("<body>"); addLine("<script src=\"../encog.js\"></script>"); addLine("<script src=\"../encog-widget.js\"></script>"); addLine("<pre>"); addLine("<script type=\"text/javascript\">"); generateForChildren(node); addLine("</script>"); addLine("<noscript>Your browser does not support JavaScript! Note: if you are trying to view this in Encog Workbench, right-click file and choose \"Open as Text\".</noscript>"); addLine("</pre>"); addLine("</body>"); addLine("</html>"); } private void generateComment(final EncogProgramNode commentNode) { addLine("// " + commentNode.getName()); } private void generateConst(final EncogProgramNode node) { final StringBuilder line = new StringBuilder(); line.append("var "); line.append(node.getName()); line.append(" = \""); line.append(node.getArgs().get(0).getValue()); line.append("\";"); addLine(line.toString()); } private void generateForChildren(final EncogTreeNode parent) { for (final EncogProgramNode node : parent.getChildren()) { generateNode(node); } } private void generateFunction(final EncogProgramNode node) { addBreak(); final StringBuilder line = new StringBuilder(); line.append("function "); line.append(node.getName()); line.append("() {"); indentLine(line.toString()); generateForChildren(node); unIndentLine("}"); } private void generateFunctionCall(final EncogProgramNode node) { addBreak(); final StringBuilder line = new StringBuilder(); if (node.getArgs().get(0).getValue().toString().length() > 0) { line.append("var "); line.append(node.getArgs().get(1).getValue().toString()); line.append(" = "); } line.append(node.getName()); line.append("();"); addLine(line.toString()); } private void generateMainFunction(final EncogProgramNode node) { addBreak(); generateForChildren(node); } private void generateNode(final EncogProgramNode node) { switch (node.getType()) { case Comment: generateComment(node); break; case Class: generateClass(node); break; case MainFunction: generateMainFunction(node); break; case Const: generateConst(node); break; case StaticFunction: generateFunction(node); break; case FunctionCall: generateFunctionCall(node); break; case CreateNetwork: embedNetwork(node); break; case InitArray: generateArrayInit(node); break; case EmbedTraining: embedTraining(node); break; } } private String toSingleLineArray( final ActivationFunction[] activationFunctions) { final StringBuilder result = new StringBuilder(); result.append('['); for (int i = 0; i < activationFunctions.length; i++) { if (i > 0) { result.append(','); } final ActivationFunction af = activationFunctions[i]; if (af instanceof ActivationSigmoid) { result.append("ENCOG.ActivationSigmoid.create()"); } else if (af instanceof ActivationTANH) { result.append("ENCOG.ActivationTANH.create()"); } else if (af instanceof ActivationLinear) { result.append("ENCOG.ActivationLinear.create()"); } else if (af instanceof ActivationElliott) { result.append("ENCOG.ActivationElliott.create()"); } else { throw new AnalystCodeGenerationError( "Unsupported activatoin function for code generation: " + af.getClass().getSimpleName()); } } result.append(']'); return result.toString(); } private String toSingleLineArray(final double[] d) { final StringBuilder line = new StringBuilder(); line.append("["); for (int i = 0; i < d.length; i++) { line.append(CSVFormat.EG_FORMAT.format(d[i], Encog.DEFAULT_PRECISION)); if (i < (d.length - 1)) { line.append(","); } } line.append("]"); return line.toString(); } private String toSingleLineArray(final int[] d) { final StringBuilder line = new StringBuilder(); line.append("["); for (int i = 0; i < d.length; i++) { line.append(d[i]); if (i < (d.length - 1)) { line.append(","); } } line.append("]"); return line.toString(); } }