/* * Encog(tm) Core v3.4 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.ml.data.versatile; import java.util.List; import org.encog.EncogError; import org.encog.mathutil.randomize.generate.GenerateRandom; import org.encog.ml.data.versatile.columns.ColumnDefinition; import org.encog.ml.data.versatile.columns.ColumnType; import org.encog.ml.data.versatile.division.DataDivision; import org.encog.ml.data.versatile.division.PerformDataDivision; import org.encog.ml.data.versatile.normalizers.strategies.NormalizationStrategy; import org.encog.ml.data.versatile.sources.VersatileDataSource; /** * The versatile dataset supports several advanced features. 1. it can directly * read and normalize from a CSV file. 2. It supports virtual time-boxing for * time series data (the data is NOT expanded in memory). 3. It can easily be * segmented into smaller datasets. */ public class VersatileMLDataSet extends MatrixMLDataSet { /** * The source that data is being pulled from. */ private VersatileDataSource source; /** * The normalization helper. */ private NormalizationHelper helper = new NormalizationHelper(); /** * The number of rows that were analyzed. */ private int analyzedRows; /** * Construct the data source. * @param theSource The data source. */ public VersatileMLDataSet(VersatileDataSource theSource) { this.source = theSource; } /** * Find the index of a column. * @param colDef The column. * @return The column index. */ private int findIndex(ColumnDefinition colDef) { if (colDef.getIndex() != -1) { return colDef.getIndex(); } int index = this.source.columnIndex(colDef.getName()); colDef.setIndex(index); if (index == -1) { throw new EncogError("Can't find column"); } return index; } /** * Analyze the input and determine max, min, mean, etc. */ public void analyze() { String[] line; // Collect initial stats: sums (for means), highs, lows. this.source.rewind(); int c = 0; while ((line = this.source.readLine()) != null) { c++; for (int i = 0; i < this.helper.getSourceColumns().size(); i++) { ColumnDefinition colDef = this.helper.getSourceColumns().get(i); int index = findIndex(colDef); String value = line[index]; colDef.analyze(value); } } this.analyzedRows = c; // Calculate the means, and reset for sd calc. for (ColumnDefinition colDef : this.helper.getSourceColumns()) { // Only calculate mean/sd for continuous columns. if (colDef.getDataType() == ColumnType.continuous) { colDef.setMean(colDef.getMean() / colDef.getCount()); colDef.setSd(0); } } // Sum the standard deviation this.source.rewind(); while ((line = this.source.readLine()) != null) { for (int i = 0; i < this.helper.getSourceColumns().size(); i++) { ColumnDefinition colDef = this.helper.getSourceColumns().get(i); String value = line[colDef.getIndex()]; if (colDef.getDataType() == ColumnType.continuous) { double d = this.helper.parseDouble(value); d = colDef.getMean() - d; d = d * d; colDef.setSd(colDef.getSd() + d); } } } // Calculate the standard deviations. for (ColumnDefinition colDef : this.helper.getSourceColumns()) { // Only calculate sd for continuous columns. if (colDef.getDataType() == ColumnType.continuous) { colDef.setSd(Math.sqrt(colDef.getSd() / colDef.getCount())); } } } /** * Normalize the data set, and allocate memory to hold it. */ public void normalize() { NormalizationStrategy strat = this.helper.getNormStrategy(); if (strat == null) { throw new EncogError( "Please choose a model type first, with selectMethod."); } int normalizedInputColumns = this.helper .calculateNormalizedInputCount(); int normalizedOutputColumns = this.helper .calculateNormalizedOutputCount(); int normalizedColumns = normalizedInputColumns + normalizedOutputColumns; setCalculatedIdealSize(normalizedOutputColumns); setCalculatedInputSize(normalizedInputColumns); setData(new double[this.analyzedRows][normalizedColumns]); this.source.rewind(); String[] line; int row = 0; while ((line = this.source.readLine()) != null) { int column = 0; for (ColumnDefinition colDef : this.helper.getInputColumns()) { int index = findIndex(colDef); String value = line[index]; column = this.helper.normalizeToVector(colDef, column, getData()[row], true, value); } for (ColumnDefinition colDef : this.helper.getOutputColumns()) { int index = findIndex(colDef); String value = line[index]; column = this.helper.normalizeToVector(colDef, column, getData()[row], false, value); } row++; } } /** * Define a source column. Used when the file does not contain headings. * @param name The name of the column. * @param index The index of the column. * @param colType The column type. * @return The column definition. */ public ColumnDefinition defineSourceColumn(String name, int index, ColumnType colType) { return this.helper.defineSourceColumn(name, index, colType); } /** * @return the helper */ public NormalizationHelper getNormHelper() { return helper; } /** * @param helper * the helper to set */ public void setNormHelper(NormalizationHelper helper) { this.helper = helper; } /** * Divide, and optionally shuffle, the dataset. * @param dataDivisionList The desired divisions. * @param shuffle True, if we should shuffle. * @param rnd Random number generator, often with a specific seed. */ public void divide(List<DataDivision> dataDivisionList, boolean shuffle, GenerateRandom rnd) { if (getData() == null) { throw new EncogError( "Can't divide, data has not yet been generated/normalized."); } PerformDataDivision divide = new PerformDataDivision(shuffle, rnd); divide.perform(dataDivisionList, this, getCalculatedInputSize(), getCalculatedIdealSize()); } /** * Define an output column. * @param col The output column. */ public void defineOutput(ColumnDefinition col) { this.helper.getOutputColumns().add(col); } /** * Define an input column. * @param col The input column. */ public void defineInput(ColumnDefinition col) { this.helper.getInputColumns().add(col); } /** * Define a single column as an output column, all others as inputs. * @param outputColumn The output column. */ public void defineSingleOutputOthersInput(ColumnDefinition outputColumn) { this.helper.clearInputOutput(); for (ColumnDefinition colDef : this.helper.getSourceColumns()) { if (colDef == outputColumn) { defineOutput(colDef); } else if (colDef.getDataType() != ColumnType.ignore) { defineInput(colDef); } } } /** * Define a source column. * @param name The name of the source column. * @param colType The column type. * @return The column definition. */ public ColumnDefinition defineSourceColumn(String name, ColumnType colType) { return this.helper.defineSourceColumn(name, -1, colType); } /** * Define multiple output columns, all others as inputs. * @param outputColumns The output columns. */ public void defineMultipleOutputsOthersInput(ColumnDefinition[] outputColumns) { this.helper.clearInputOutput(); for (ColumnDefinition colDef : this.helper.getSourceColumns()) { boolean isOutput = false; for(ColumnDefinition col : outputColumns) { if( col==colDef) { isOutput = true; } } if ( isOutput) { defineOutput(colDef); } else if (colDef.getDataType() != ColumnType.ignore) { defineInput(colDef); } } } }