/*
* Encog(tm) Java Examples v3.4
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-examples
*
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.examples.guide.timeseries;
import java.io.File;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.Arrays;
import org.encog.ConsoleStatusReportable;
import org.encog.Encog;
import org.encog.bot.BotUtil;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.error.ErrorCalculationMode;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.versatile.NormalizationHelper;
import org.encog.ml.data.versatile.VersatileMLDataSet;
import org.encog.ml.data.versatile.columns.ColumnDefinition;
import org.encog.ml.data.versatile.columns.ColumnType;
import org.encog.ml.data.versatile.sources.CSVDataSource;
import org.encog.ml.data.versatile.sources.VersatileDataSource;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.model.EncogModel;
import org.encog.util.arrayutil.VectorWindow;
import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;
public class SunSpotTimeseries {
public static String DATA_URL = "http://solarscience.msfc.nasa.gov/greenwch/spot_num.txt";
private String tempPath;
public static final int WINDOW_SIZE = 3;
public File downloadData(String[] args) throws MalformedURLException {
if (args.length != 0) {
tempPath = args[0];
} else {
tempPath = System.getProperty("java.io.tmpdir");
}
File filename = new File(tempPath, "auto-mpg.data");
BotUtil.downloadPage(new URL(SunSpotTimeseries.DATA_URL), filename);
System.out.println("Downloading sunspot dataset to: " + filename);
return filename;
}
public void run(String[] args) {
try {
ErrorCalculation.setMode(ErrorCalculationMode.RMS);
// Download the data that we will attempt to model.
File filename = downloadData(args);
// Define the format of the data file.
// This area will change, depending on the columns and
// format of the file that you are trying to model.
CSVFormat format = new CSVFormat('.', ' '); // decimal point and
// space separated
VersatileDataSource source = new CSVDataSource(filename, true,
format);
VersatileMLDataSet data = new VersatileMLDataSet(source);
data.getNormHelper().setFormat(format);
ColumnDefinition columnSSN = data.defineSourceColumn("SSN",
ColumnType.continuous);
ColumnDefinition columnDEV = data.defineSourceColumn("DEV",
ColumnType.continuous);
// Analyze the data, determine the min/max/mean/sd of every column.
data.analyze();
// Use SSN & DEV to predict SSN. For time-series it is okay to have
// SSN both as
// an input and an output.
data.defineInput(columnSSN);
data.defineInput(columnDEV);
data.defineOutput(columnSSN);
// Create feedforward neural network as the model type.
// MLMethodFactory.TYPE_FEEDFORWARD.
// You could also other model types, such as:
// MLMethodFactory.SVM: Support Vector Machine (SVM)
// MLMethodFactory.TYPE_RBFNETWORK: RBF Neural Network
// MLMethodFactor.TYPE_NEAT: NEAT Neural Network
// MLMethodFactor.TYPE_PNN: Probabilistic Neural Network
EncogModel model = new EncogModel(data);
model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD);
// Send any output to the console.
model.setReport(new ConsoleStatusReportable());
// Now normalize the data. Encog will automatically determine the
// correct normalization
// type based on the model you chose in the last step.
data.normalize();
// Set time series.
data.setLeadWindowSize(1);
data.setLagWindowSize(WINDOW_SIZE);
// Hold back some data for a final validation.
// Do not shuffle the data into a random ordering. (never shuffle
// time series)
// Use a seed of 1001 so that we always use the same holdback and
// will get more consistent results.
model.holdBackValidation(0.3, false, 1001);
// Choose whatever is the default training type for this model.
model.selectTrainingType(data);
// Use a 5-fold cross-validated train. Return the best method found.
// (never shuffle time series)
MLRegression bestMethod = (MLRegression) model.crossvalidate(5,
false);
// Display the training and validation errors.
System.out.println("Training error: "
+ model.calculateError(bestMethod,
model.getTrainingDataset()));
System.out.println("Validation error: "
+ model.calculateError(bestMethod,
model.getValidationDataset()));
// Display our normalization parameters.
NormalizationHelper helper = data.getNormHelper();
System.out.println(helper.toString());
// Display the final model.
System.out.println("Final model: " + bestMethod);
// Loop over the entire, original, dataset and feed it through the
// model. This also shows how you would process new data, that was
// not part of your training set. You do not need to retrain, simply
// use the NormalizationHelper class. After you train, you can save
// the NormalizationHelper to later normalize and denormalize your
// data.
ReadCSV csv = new ReadCSV(filename, true, format);
String[] line = new String[2];
// Create a vector to hold each time-slice, as we build them.
// These will be grouped together into windows.
double[] slice = new double[2];
VectorWindow window = new VectorWindow(WINDOW_SIZE + 1);
MLData input = helper.allocateInputVector(WINDOW_SIZE + 1);
// Only display the first 100
int stopAfter = 100;
while (csv.next() && stopAfter > 0) {
StringBuilder result = new StringBuilder();
line[0] = csv.get(2);// ssn
line[1] = csv.get(3);// dev
helper.normalizeInputVector(line, slice, false);
// enough data to build a full window?
if (window.isReady()) {
window.copyWindow(input.getData(), 0);
String correct = csv.get(2); // trying to predict SSN.
MLData output = bestMethod.compute(input);
String predicted = helper
.denormalizeOutputVectorToString(output)[0];
result.append(Arrays.toString(line));
result.append(" -> predicted: ");
result.append(predicted);
result.append("(correct: ");
result.append(correct);
result.append(")");
System.out.println(result.toString());
}
// Add the normalized slice to the window. We do this just after
// the after checking to see if the window is ready so that the
// window is always one behind the current row. This is because
// we are trying to predict next row.
window.add(slice);
stopAfter--;
}
// Delete data file and shut down.
filename.delete();
Encog.getInstance().shutdown();
} catch (Exception ex) {
ex.printStackTrace();
}
}
public static void main(String[] args) {
SunSpotTimeseries prg = new SunSpotTimeseries();
prg.run(args);
}
}