/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core
 * Copyright 2008-2016 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.ml.data.auto;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.encog.EncogError;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;

/**
 * A read-only data set that builds each record as a sliding window over a
 * list of float columns.
 *
 * <p>The first {@code sourceInputCount} columns are used as inputs and the
 * following {@code sourceIdealCount} columns as ideals. Record {@code i}
 * takes {@code inputWindowSize} consecutive rows starting at row {@code i}
 * from every input column, and {@code outputWindowSize} consecutive rows
 * starting at row {@code i + inputWindowSize} from every ideal column.
 *
 * <p>Values are copied raw unless normalization is enabled, in which case
 * each column is asked for its value normalized into
 * [{@code normalizedMin}, {@code normalizedMax}].
 */
public class AutoFloatDataSet implements Serializable, MLDataSet {

	private int sourceInputCount;
	private int sourceIdealCount;
	private int inputWindowSize;
	private int outputWindowSize;
	private List<AutoFloatColumn> columns = new ArrayList<AutoFloatColumn>();
	private float normalizedMax = 1;
	private float normalizedMin = -1;
	private boolean normalizationEnabled = false;

	/**
	 * Iterates over the records of the enclosing data set in index order.
	 */
	public class AutoFloatIterator implements Iterator<MLDataPair> {

		/**
		 * The index that the iterator is currently at.
		 */
		private int currentIndex = 0;

		/**
		 * {@inheritDoc}
		 */
		@Override
		public final boolean hasNext() {
			return this.currentIndex < AutoFloatDataSet.this.size();
		}

		/**
		 * Returns the next record, or {@code null} (rather than throwing)
		 * when the iterator is exhausted.
		 */
		@Override
		public final MLDataPair next() {
			if (!hasNext()) {
				return null;
			}

			return AutoFloatDataSet.this.get(this.currentIndex++);
		}

		/**
		 * Not supported; always throws {@link EncogError}.
		 */
		@Override
		public final void remove() {
			throw new EncogError("Called remove, unsupported operation.");
		}
	}

	/**
	 * Construct an empty data set; columns are added afterwards with
	 * {@link #addColumn(float[])} or
	 * {@link #loadCSV(String, boolean, CSVFormat, int[], int[])}.
	 *
	 * @param theInputCount       number of leading columns used as input
	 * @param theIdealCount       number of following columns used as ideal
	 * @param theInputWindowSize  rows taken from each input column per record
	 * @param theOutputWindowSize rows taken from each ideal column per record
	 */
	public AutoFloatDataSet(int theInputCount, int theIdealCount,
			int theInputWindowSize, int theOutputWindowSize) {
		this.sourceInputCount = theInputCount;
		this.sourceIdealCount = theIdealCount;
		this.inputWindowSize = theInputWindowSize;
		this.outputWindowSize = theOutputWindowSize;
	}

	/**
	 * @return a new iterator positioned at the first record
	 */
	@Override
	public Iterator<MLDataPair> iterator() {
		return new AutoFloatIterator();
	}

	/**
	 * @return the ideal vector length: ideal columns times output window size
	 */
	@Override
	public int getIdealSize() {
		return this.sourceIdealCount * this.outputWindowSize;
	}

	/**
	 * @return the input vector length: input columns times input window size
	 */
	@Override
	public int getInputSize() {
		return this.sourceInputCount * this.inputWindowSize;
	}

	/**
	 * @return true when the ideal size is greater than zero
	 */
	@Override
	public boolean isSupervised() {
		return getIdealSize() > 0;
	}

	/**
	 * Number of complete windows available, based on the length of the
	 * first column.
	 *
	 * @return 0 when there are no columns or when the first column holds
	 *         fewer rows than one full window (input + output); otherwise
	 *         {@code rows - window + 1}
	 */
	@Override
	public long getRecordCount() {
		if (this.columns.isEmpty()) {
			return 0;
		}

		int totalRows = this.columns.get(0).getData().length;
		int windowSize = this.inputWindowSize + this.outputWindowSize;
		// Clamp: too few rows for a single window means no records,
		// never a negative count.
		return Math.max(0, (totalRows - windowSize) + 1);
	}

	/**
	 * Fill {@code pair} with the record that starts at row {@code index}.
	 *
	 * @param index the starting row of the window
	 * @param pair  the pair whose input and ideal arrays are overwritten
	 */
	@Override
	public void getRecord(long index, MLDataPair pair) {
		// Columns are consumed in order: inputs first, then ideals.
		int columnID = 0;

		// copy the input
		int inputIndex = 0;
		for (int i = 0; i < this.sourceInputCount; i++) {
			AutoFloatColumn column = this.columns.get(columnID++);
			for (int j = 0; j < this.inputWindowSize; j++) {
				int row = (int) index + j;
				if (this.normalizationEnabled) {
					pair.getInputArray()[inputIndex++] = column.getNormalized(
							row, this.normalizedMin, this.normalizedMax);
				} else {
					pair.getInputArray()[inputIndex++] = column.getData()[row];
				}
			}
		}

		// copy the output; the ideal window starts right after the input window
		int idealIndex = 0;
		for (int i = 0; i < this.sourceIdealCount; i++) {
			AutoFloatColumn column = this.columns.get(columnID++);
			for (int j = 0; j < this.outputWindowSize; j++) {
				int row = (int) (this.inputWindowSize + index + j);
				if (this.normalizationEnabled) {
					pair.getIdealArray()[idealIndex++] = column.getNormalized(
							row, this.normalizedMin, this.normalizedMax);
				} else {
					pair.getIdealArray()[idealIndex++] = column.getData()[row];
				}
			}
		}
	}

	/**
	 * @return this same instance
	 */
	@Override
	public MLDataSet openAdditional() {
		return this;
	}

	/**
	 * Not supported; always throws {@link EncogError}.
	 */
	@Override
	public void add(MLData data1) {
		throw new EncogError("Add's not supported by this dataset.");
	}

	/**
	 * Not supported; always throws {@link EncogError}.
	 */
	@Override
	public void add(MLData inputData, MLData idealData) {
		throw new EncogError("Add's not supported by this dataset.");
	}

	/**
	 * Not supported; always throws {@link EncogError}.
	 */
	@Override
	public void add(MLDataPair inputData) {
		throw new EncogError("Add's not supported by this dataset.");
	}

	/**
	 * Does nothing.
	 */
	@Override
	public void close() {
	}

	/**
	 * @return the record count narrowed to an int
	 */
	@Override
	public int size() {
		return (int) getRecordCount();
	}

	/**
	 * Build a freshly allocated pair for the record at {@code index}.
	 *
	 * @param index the record index
	 * @return the record, or {@code null} when {@code index >= size()}
	 */
	@Override
	public MLDataPair get(int index) {
		if (index >= size()) {
			return null;
		}

		MLDataPair result = BasicMLDataPair.createPair(getInputSize(),
				this.getIdealSize());
		getRecord(index, result);
		return result;
	}

	/**
	 * Append a column wrapping the given array.
	 *
	 * @param data the column values, one per row
	 */
	public void addColumn(float[] data) {
		AutoFloatColumn column = new AutoFloatColumn(data);
		this.columns.add(column);
	}

	/**
	 * Load columns from a CSV file. The file is read twice: once to count
	 * the rows, and once to read the values. The selected input columns are
	 * appended first, followed by the selected ideal columns.
	 *
	 * @param filename the CSV file to read
	 * @param headers  passed to {@link ReadCSV}: whether the file has headers
	 * @param format   the CSV format to parse with
	 * @param input    CSV column indexes to load as input columns
	 * @param ideal    CSV column indexes to load as ideal columns
	 */
	public void loadCSV(String filename, boolean headers, CSVFormat format,
			int[] input, int[] ideal) {
		// first, just size it up
		int lineCount = 0;
		ReadCSV csv = new ReadCSV(filename, headers, format);
		try {
			while (csv.next()) {
				lineCount++;
			}
		} finally {
			// release the reader even if the scan fails
			csv.close();
		}

		// allocate space to hold it
		float[][] data = new float[input.length + ideal.length][lineCount];

		// now read the data in
		csv = new ReadCSV(filename, headers, format);
		try {
			int rowIndex = 0;
			while (csv.next()) {
				int columnIndex = 0;

				for (int i = 0; i < input.length; i++) {
					data[columnIndex++][rowIndex] = (float) csv.getDouble(input[i]);
				}
				for (int i = 0; i < ideal.length; i++) {
					data[columnIndex++][rowIndex] = (float) csv.getDouble(ideal[i]);
				}

				rowIndex++;
			}
		} finally {
			// release the reader even if a value fails to parse
			csv.close();
		}

		// now add the columns
		for (int i = 0; i < data.length; i++) {
			addColumn(data[i]);
		}
	}

	/**
	 * @return the normalizedMax
	 */
	public float getNormalizedMax() {
		return normalizedMax;
	}

	/**
	 * Set the upper bound of the normalized range. As a side effect this
	 * also enables normalization.
	 *
	 * @param normalizedMax the normalizedMax to set
	 */
	public void setNormalizedMax(float normalizedMax) {
		this.normalizedMax = normalizedMax;
		this.normalizationEnabled = true;
	}

	/**
	 * @return the normalizedMin
	 */
	public float getNormalizedMin() {
		return normalizedMin;
	}

	/**
	 * Set the lower bound of the normalized range. As a side effect this
	 * also enables normalization.
	 *
	 * @param normalizedMin the normalizedMin to set
	 */
	public void setNormalizedMin(float normalizedMin) {
		this.normalizedMin = normalizedMin;
		this.normalizationEnabled = true;
	}

	/**
	 * @return the normalizationEnabled
	 */
	public boolean isNormalizationEnabled() {
		return normalizationEnabled;
	}

	/**
	 * @param normalizationEnabled the normalizationEnabled to set
	 */
	public void setNormalizationEnabled(boolean normalizationEnabled) {
		this.normalizationEnabled = normalizationEnabled;
	}
}