/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core
 * Copyright 2008-2016 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.app.analyst.csv.process;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.encog.app.analyst.AnalystError;
import org.encog.app.analyst.csv.basic.LoadedRow;
import org.encog.ml.prg.ProgramNode;
import org.encog.ml.prg.expvalue.ExpressionValue;
import org.encog.ml.prg.extension.BasicTemplate;
import org.encog.ml.prg.extension.EncogOpcodeRegistry;
import org.encog.ml.prg.extension.FunctionFactory;
import org.encog.ml.prg.extension.NodeType;
import org.encog.ml.prg.extension.ProgramExtensionTemplate;
import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;

/**
 * Supplies the {@code field}, {@code fieldmax} and {@code fieldmaxpip}
 * program functions, which read values out of a sliding window of CSV rows.
 *
 * <p>
 * Rows are kept newest-first: {@link #loadRow(LoadedRow)} inserts each row at
 * index 0 and drops the row at the end once the window holds more than
 * {@link #getTotalWindowSize()} rows. The opcodes translate a relative row
 * offset into a window index by adding {@link #getBackwardWindowSize()}.
 *
 * <p>
 * The opcodes locate their {@code ProcessExtension} through the extra data
 * stored on the node's owner under {@link #EXTENSION_DATA_NAME}.
 */
public class ProcessExtension {

    /**
     * Key under which the owning program stores the {@code ProcessExtension}
     * instance as extra data.
     */
    public final static String EXTENSION_DATA_NAME = "ENCOG-ANALYST-PROCESS";

    /**
     * Maps a column name to its position in a row's data array.
     */
    private final Map<String, Integer> map = new HashMap<String, Integer>();

    private int forwardWindowSize;
    private int backwardWindowSize;
    private int totalWindowSize;

    /**
     * The window of loaded rows, most recently loaded row first.
     */
    private final List<LoadedRow> data = new ArrayList<LoadedRow>();

    /**
     * The format used to parse numeric field values.
     */
    private final CSVFormat format;

    /**
     * {@code field(name, offset)}: returns the string value of the named
     * field in the row at window index {@code offset + backwardWindowSize}.
     */
    public static final ProgramExtensionTemplate OPCODE_FIELD = new BasicTemplate(
            ProgramExtensionTemplate.NO_PREC, "field({s}{i}):{s}",
            NodeType.Function, true, 0) {
        /**
         * The serial id.
         */
        private static final long serialVersionUID = 1L;

        @Override
        public ExpressionValue evaluate(ProgramNode actual) {
            ProcessExtension pe = ProcessExtension.extensionFor(actual);
            String fieldName = ProcessExtension.fieldNameArg(actual);
            // The offset is read as a float value and truncated to an int
            // before the backward window size is added.
            int fieldIndex = (int) actual.getChildNode(1).evaluate()
                    .toFloatValue() + pe.getBackwardWindowSize();
            String value = pe.getField(fieldName, fieldIndex);
            return new ExpressionValue(value);
        }
    };

    /**
     * {@code fieldmax(name, start, stop)}: returns the largest parsed value
     * of the named field over the inclusive offset range
     * {@code [start, stop]}. Yields {@link Double#NEGATIVE_INFINITY} when the
     * range is empty ({@code start > stop}).
     */
    public static final ProgramExtensionTemplate OPCODE_FIELDMAX = new BasicTemplate(
            ProgramExtensionTemplate.NO_PREC, "fieldmax({s}{i}{i}):{f}",
            NodeType.Function, true, 0) {
        /**
         * The serial id.
         */
        private static final long serialVersionUID = 1L;

        @Override
        public ExpressionValue evaluate(ProgramNode actual) {
            ProcessExtension pe = ProcessExtension.extensionFor(actual);
            String fieldName = ProcessExtension.fieldNameArg(actual);
            int startIndex = ProcessExtension.intArg(actual, 1);
            int stopIndex = ProcessExtension.intArg(actual, 2);

            double value = Double.NEGATIVE_INFINITY;
            for (int i = startIndex; i <= stopIndex; i++) {
                String str = pe.getField(fieldName,
                        pe.getBackwardWindowSize() + i);
                double d = pe.getFormat().parse(str);
                value = Math.max(d, value);
            }
            return new ExpressionValue(value);
        }
    };

    /**
     * {@code fieldmaxpip(name, start, stop)}: for each offset in the
     * inclusive range {@code [start, stop]}, takes the difference between the
     * field value at that offset and the value at offset 0, divides it by
     * 0.0001 and rounds it; the largest such result is returned. Yields
     * {@link Integer#MIN_VALUE} when the range is empty.
     */
    public static final ProgramExtensionTemplate OPCODE_FIELDMAXPIP = new BasicTemplate(
            ProgramExtensionTemplate.NO_PREC, "fieldmaxpip({s}{i}{i}):{f}",
            NodeType.Function, true, 0) {
        /**
         * The serial id.
         */
        private static final long serialVersionUID = 1L;

        @Override
        public ExpressionValue evaluate(ProgramNode actual) {
            ProcessExtension pe = ProcessExtension.extensionFor(actual);
            String fieldName = ProcessExtension.fieldNameArg(actual);
            int startIndex = ProcessExtension.intArg(actual, 1);
            int stopIndex = ProcessExtension.intArg(actual, 2);

            int value = Integer.MIN_VALUE;

            // Reference value: the field at relative offset 0.
            String str = pe.getField(fieldName, pe.getBackwardWindowSize());
            double quoteNow = pe.getFormat().parse(str);

            for (int i = startIndex; i <= stopIndex; i++) {
                str = pe.getField(fieldName, pe.getBackwardWindowSize() + i);
                double d = pe.getFormat().parse(str) - quoteNow;
                d /= 0.0001;
                d = Math.round(d);
                value = Math.max((int) d, value);
            }

            return new ExpressionValue(value);
        }
    };

    /**
     * Add opcodes to the Encog resource registry.
     */
    static {
        EncogOpcodeRegistry.INSTANCE.add(OPCODE_FIELD);
        EncogOpcodeRegistry.INSTANCE.add(OPCODE_FIELDMAX);
        EncogOpcodeRegistry.INSTANCE.add(OPCODE_FIELDMAXPIP);
    }

    /**
     * Construct the extension.
     *
     * @param theFormat
     *            The format used to parse numeric field values.
     */
    public ProcessExtension(CSVFormat theFormat) {
        this.format = theFormat;
    }

    /**
     * Fetches the extension stored as extra data on the node's owner.
     */
    private static ProcessExtension extensionFor(ProgramNode node) {
        return (ProcessExtension) node.getOwner().getExtraData(
                EXTENSION_DATA_NAME);
    }

    /**
     * Evaluates the first child of the node as the field name.
     */
    private static String fieldNameArg(ProgramNode node) {
        return node.getChildNode(0).evaluate().toStringValue();
    }

    /**
     * Evaluates the child at the given position as an int value.
     */
    private static int intArg(ProgramNode node, int childIndex) {
        return (int) node.getChildNode(childIndex).evaluate().toIntValue();
    }

    /**
     * Look up the value of a field in one of the rows of the window.
     *
     * @param fieldName
     *            The column name, as registered by
     *            {@link #init(ReadCSV, int, int)}.
     * @param fieldIndex
     *            The absolute index into the window; index 0 is the most
     *            recently loaded row.
     * @return The raw string value of the field.
     * @throws AnalystError
     *             If the field name is unknown, or the index falls outside
     *             the rows currently held.
     */
    public String getField(String fieldName, int fieldIndex) {
        Integer column = map.get(fieldName);
        if (column == null) {
            throw new AnalystError("Unknown input field: " + fieldName);
        }

        // Callers compute fieldIndex as (offset + backwardWindowSize), so a
        // negative index means the offset reaches past the backward window,
        // while an index past the end reaches past the forward window.
        if (fieldIndex < 0) {
            throw new AnalystError(
                    "The specified temporal index "
                            + fieldIndex
                            + " is out of bounds. You should probably increase the backward window size.");
        }
        if (fieldIndex >= this.data.size()) {
            throw new AnalystError(
                    "The specified temporal index "
                            + fieldIndex
                            + " is out of bounds. You should probably increase the forward window size.");
        }

        return this.data.get(fieldIndex).getData()[column];
    }

    /**
     * Push a row onto the front of the window, discarding the row at the end
     * if the window would otherwise exceed the total window size.
     *
     * @param row
     *            The row to add.
     */
    public void loadRow(LoadedRow row) {
        data.add(0, row);
        if (data.size() > this.totalWindowSize) {
            data.remove(data.size() - 1);
        }
    }

    /**
     * Set the window sizes and record the column positions of the CSV file.
     * The total window size becomes forward + backward + 1. Rows already held
     * and previously registered column names are not cleared.
     *
     * @param csv
     *            The CSV reader whose column names are registered.
     * @param theBackwardWindowSize
     *            The backward window size.
     * @param theForwardWindowSize
     *            The forward window size.
     */
    public void init(ReadCSV csv, int theBackwardWindowSize,
            int theForwardWindowSize) {
        this.forwardWindowSize = theForwardWindowSize;
        this.backwardWindowSize = theBackwardWindowSize;
        this.totalWindowSize = this.forwardWindowSize
                + this.backwardWindowSize + 1;

        int i = 0;
        for (String name : csv.getColumnNames()) {
            map.put(name, i++);
        }
    }

    /**
     * @return True if the window holds at least the total window size of
     *         rows.
     */
    public boolean isDataReady() {
        return this.data.size() >= this.totalWindowSize;
    }

    /**
     * @return The forward window size.
     */
    public int getForwardWindowSize() {
        return forwardWindowSize;
    }

    /**
     * @return The backward window size.
     */
    public int getBackwardWindowSize() {
        return backwardWindowSize;
    }

    /**
     * @return The total window size (forward + backward + 1 once
     *         initialized).
     */
    public int getTotalWindowSize() {
        return totalWindowSize;
    }

    /**
     * @return The format used to parse numeric field values.
     */
    public CSVFormat getFormat() {
        return format;
    }

    /**
     * Add the three field opcodes to the supplied function factory.
     *
     * @param functions
     *            The factory to extend.
     */
    public void register(FunctionFactory functions) {
        functions.addExtension(OPCODE_FIELD);
        functions.addExtension(OPCODE_FIELDMAX);
        functions.addExtension(OPCODE_FIELDMAXPIP);
    }
}