/* * Encog(tm) Java Examples v3.4 * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-examples * * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.examples.neural.predict.sunspot; import java.io.File; import java.net.MalformedURLException; import java.net.URL; import org.encog.Encog; import org.encog.bot.BotUtil; import org.encog.ml.MLMethod; import org.encog.ml.MLRegression; import org.encog.ml.MLResettable; import org.encog.ml.data.MLData; import org.encog.ml.data.MLDataSet; import org.encog.ml.data.temporal.TemporalDataDescription; import org.encog.ml.data.temporal.TemporalMLDataSet; import org.encog.ml.data.temporal.TemporalPoint; import org.encog.ml.factory.MLMethodFactory; import org.encog.ml.factory.MLTrainFactory; import org.encog.ml.train.MLTrain; import org.encog.ml.train.strategy.RequiredImprovementStrategy; import org.encog.neural.networks.training.propagation.manhattan.ManhattanPropagation; import org.encog.util.arrayutil.NormalizationAction; import org.encog.util.arrayutil.NormalizedField; import org.encog.util.csv.ReadCSV; import org.encog.util.simple.EncogUtility; /** * This example is meant to demonstrate how to use the Encog TemporalMLDataSet through the * full cycle of building a training set, training a model, and then performing predictions. * It also demonstrates how to use multiple values in a time-series with normalization and * denormalization. The point of this example is more to show how to structure such a prediction * problem, than how to accurately predict sun spot data. This example does get marginally good * predictive results, however, sun spot data tends to be very noisy. * * This example was presented as an Encog FAQ, for more FAQ's see: * * http://www.heatonresearch.com/faq */ public class MultiSunspot { /** * Set this to whatever you want to use as your home directory. * The example is set to use the current directory. */ public static final File MYDIR = new File("."); /** * This is the amount of data to use to predict. */ public static final int INPUT_WINDOW_SIZE = 12; /** * This is the amount of data to actually predict. */ public static final int PREDICT_WINDOW_SIZE = 1; /** * Used to normalize the SSN (sun spot number) from a range of 0-300 * to 0-1. */ public static NormalizedField normSSN = new NormalizedField( NormalizationAction.Normalize, "ssn", 300, 0, 1, 0); /** * Used to normalize the dev from a range of 0-100 * to 0-1. */ public static NormalizedField normDEV = new NormalizedField( NormalizationAction.Normalize, "dev", 100, 0, 1, 0); public static TemporalMLDataSet initDataSet() { // create a temporal data set TemporalMLDataSet dataSet = new TemporalMLDataSet(INPUT_WINDOW_SIZE, PREDICT_WINDOW_SIZE); // we are dealing with two columns. // The first is the sunspot number. This is both an input (used to // predict) and an output (we want to predict it), so true,true. TemporalDataDescription sunSpotNumberDesc = new TemporalDataDescription(TemporalDataDescription.Type.RAW, true, true); // The second is the standard deviation for the month. This is an // input (used to predict) only, so true,false. TemporalDataDescription standardDevDesc = new TemporalDataDescription(TemporalDataDescription.Type.RAW, true, false); dataSet.addDescription(sunSpotNumberDesc); dataSet.addDescription(standardDevDesc); return dataSet; } /** * Create and train a model. Use Encog factory codes to specify the model type that you want. * @param trainingData The training data to use. * @param methodName The name of the machine learning method (or model). * @param methodArchitecture The type of architecture to use with that model. * @param trainerName The type of training. * @param trainerArgs Training arguments. * @return The trained model. */ public static MLRegression trainModel( MLDataSet trainingData, String methodName, String methodArchitecture, String trainerName, String trainerArgs) { // first, create the machine learning method (the model) MLMethodFactory methodFactory = new MLMethodFactory(); MLMethod method = methodFactory.create(methodName, methodArchitecture, trainingData.getInputSize(), trainingData.getIdealSize()); // second, create the trainer MLTrainFactory trainFactory = new MLTrainFactory(); MLTrain train = trainFactory.create(method,trainingData,trainerName,trainerArgs); // reset if improve is less than 1% over 5 cycles if( method instanceof MLResettable && !(train instanceof ManhattanPropagation) ) { train.addStrategy(new RequiredImprovementStrategy(500)); } // third train the model EncogUtility.trainToError(train, 0.002); return (MLRegression)train.getMethod(); } /** * Download the sun spot data from NASA. * @return The path downloaded to. * @throws MalformedURLException */ public static File downloadSunSpotData() throws MalformedURLException { File rawFile = new File(MYDIR, "sunspots.csv"); // Step 1. Download sunspot data from NASA. if (rawFile.exists()) { System.out.println("Data already downloaded to: " + rawFile.getPath()); } else { System.out.println("Downloading sunspot data to: " + rawFile.getPath()); BotUtil.downloadPage( new URL( "http://solarscience.msfc.nasa.gov/greenwch/spot_num.txt"), rawFile); } return rawFile; } public static TemporalMLDataSet createTraining(File rawFile) { TemporalMLDataSet trainingData = initDataSet(); ReadCSV csv = new ReadCSV(rawFile.toString(), true, ' '); while (csv.next()) { int year = csv.getInt(0); int month = csv.getInt(1); double sunSpotNum = csv.getDouble(2); double dev = csv.getDouble(3); // we need a sequence number to sort the data. Here we just use // year * 100 + month, which produces output like "201301" for // January, 2013. int sequenceNumber = (year * 100) + month; TemporalPoint point = new TemporalPoint(trainingData .getDescriptions().size()); point.setSequence(sequenceNumber); point.setData(0, normSSN.normalize(sunSpotNum) ); point.setData(1, normDEV.normalize(dev) ); trainingData.getPoints().add(point); } csv.close(); // generate the time-boxed data trainingData.generate(); return trainingData; } public static TemporalMLDataSet predict(File rawFile, MLRegression model) { // You can also use the TemporalMLDataSet for prediction. We will not use "generate" // as we do not want to generate an entire training set. Rather we pass it each sun spot // ssn and dev and it will produce the input to the model, once there is enough data. TemporalMLDataSet trainingData = initDataSet(); ReadCSV csv = new ReadCSV(rawFile.toString(), true, ' '); while (csv.next()) { int year = csv.getInt(0); int month = csv.getInt(1); double sunSpotNum = csv.getDouble(2); double dev = csv.getDouble(3); // do we have enough data for a prediction yet? if( trainingData.getPoints().size()>=trainingData.getInputWindowSize() ) { // Make sure to use index 1, because the temporal data set is always one ahead // of the time slice its encoding. So for RAW data we are really encoding 0. MLData modelInput = trainingData.generateInputNeuralData(1); MLData modelOutput = model.compute(modelInput); double ssn = normSSN.deNormalize(modelOutput.getData(0)); System.out.println(year + ":Predicted=" + ssn + ",Actual=" + sunSpotNum ); // Remove the earliest training element. Unlike when we produced training data, // we do not want to build up a large data set. We just add enough data points to produce // input to the model. trainingData.getPoints().remove(0); } // we need a sequence number to sort the data. Here we just use // year * 100 + month, which produces output like "201301" for // January, 2013. int sequenceNumber = (year * 100) + month; TemporalPoint point = new TemporalPoint(trainingData.getDescriptions().size()); point.setSequence(sequenceNumber); point.setData(0, normSSN.normalize(sunSpotNum) ); point.setData(1, normDEV.normalize(dev) ); trainingData.getPoints().add(point); } csv.close(); // generate the time-boxed data trainingData.generate(); return trainingData; } /** * The main method. * @param args The arguments. */ public static void main(String[] args) { try { // Step 1. Download sun spot data from NASA. File rawFile = downloadSunSpotData(); // Step 2. Create training data TemporalMLDataSet trainingData = createTraining(rawFile); // Step 3. Create and train the model. // All sorts of models can be used here, see the XORFactory // example for more info. MLRegression model = trainModel( trainingData, MLMethodFactory.TYPE_FEEDFORWARD, "?:B->SIGMOID->25:B->SIGMOID->?", MLTrainFactory.TYPE_RPROP, ""); // Now predict predict(rawFile,model); Encog.getInstance().shutdown(); } catch (Exception ex) { ex.printStackTrace(); } } }