/*
 * Encog(tm) Java Examples v3.4
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-examples
 *
 * Copyright 2008-2016 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.examples.guide.regression;

import java.io.File;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.Arrays;

import org.encog.ConsoleStatusReportable;
import org.encog.Encog;
import org.encog.bot.BotUtil;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.versatile.NormalizationHelper;
import org.encog.ml.data.versatile.VersatileMLDataSet;
import org.encog.ml.data.versatile.columns.ColumnDefinition;
import org.encog.ml.data.versatile.columns.ColumnType;
import org.encog.ml.data.versatile.missing.MeanMissingHandler;
import org.encog.ml.data.versatile.sources.CSVDataSource;
import org.encog.ml.data.versatile.sources.VersatileDataSource;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.model.EncogModel;
import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;

/**
 * Example that downloads the UCI auto-mpg data set, trains a regression
 * model to predict the "mpg" column from the remaining columns, and then
 * prints the model's prediction for every row of the original file.
 */
public class AutoMPGRegression {

    /** Location of the auto-mpg data file that this example downloads. */
    public static String DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data";

    /** Number of input columns per row (source columns 1..7; column 0 is mpg). */
    private static final int INPUT_COLUMN_COUNT = 7;

    /** Directory that the data file is downloaded into. */
    private String tempPath;

    /**
     * Downloads the data set into a file named "auto-mpg.data".
     *
     * @param args command line arguments; if any are given, the first one is
     *             used as the target directory, otherwise the directory named
     *             by the "java.io.tmpdir" system property is used
     * @return the file that the data set was downloaded to
     * @throws MalformedURLException if {@link #DATA_URL} is not a valid URL
     */
    public File downloadData(String[] args) throws MalformedURLException {
        if (args.length != 0) {
            tempPath = args[0];
        } else {
            tempPath = System.getProperty("java.io.tmpdir");
        }
        File mpgFile = new File(tempPath, "auto-mpg.data");
        // Announce the download before starting it, so the message is
        // visible while the transfer is in progress.
        System.out.println("Downloading auto-mpg dataset to: " + mpgFile);
        BotUtil.downloadPage(new URL(AutoMPGRegression.DATA_URL), mpgFile);
        return mpgFile;
    }

    /**
     * Runs the complete example: download, column definition, training with
     * cross-validation, error reporting, and prediction over the source file.
     * Any exception is printed to standard error rather than propagated. The
     * downloaded file is deleted and Encog is shut down whether or not the
     * run succeeds.
     *
     * @param args command line arguments, forwarded to
     *             {@link #downloadData(String[])}
     */
    public void run(String[] args) {
        File filename = null;
        try {
            // Download the data that we will attempt to model.
            filename = downloadData(args);

            // Define the format of the data file.
            // This area will change, depending on the columns and
            // format of the file that you are trying to model.
            CSVFormat format = new CSVFormat('.', ' '); // decimal point and space separated
            VersatileDataSource source = new CSVDataSource(filename, false, format);
            VersatileMLDataSet data = new VersatileMLDataSet(source);
            data.getNormHelper().setFormat(format);

            ColumnDefinition columnMPG = data.defineSourceColumn("mpg", 0, ColumnType.continuous);
            ColumnDefinition columnCylinders = data.defineSourceColumn("cylinders", 1, ColumnType.ordinal);
            // It is very important to predefine ordinals, so that the order is known.
            columnCylinders.defineClass(new String[] {"3", "4", "5", "6", "8"});
            data.defineSourceColumn("displacement", 2, ColumnType.continuous);
            ColumnDefinition columnHorsePower = data.defineSourceColumn("horsepower", 3, ColumnType.continuous);
            data.defineSourceColumn("weight", 4, ColumnType.continuous);
            data.defineSourceColumn("acceleration", 5, ColumnType.continuous);
            ColumnDefinition columnModelYear = data.defineSourceColumn("model_year", 6, ColumnType.ordinal);
            columnModelYear.defineClass(new String[] {
                "70", "71", "72", "73", "74", "75", "76", "77", "78", "79", "80", "81", "82"});
            data.defineSourceColumn("origin", 7, ColumnType.nominal);

            // Define how missing values are represented.
            data.getNormHelper().defineUnknownValue("?");
            data.getNormHelper().defineMissingHandler(columnHorsePower, new MeanMissingHandler());

            // Analyze the data, determine the min/max/mean/sd of every column.
            data.analyze();

            // Map the prediction column to the output of the model, and all
            // other columns to the input.
            data.defineSingleOutputOthersInput(columnMPG);

            // Create feedforward neural network as the model type.
            // MLMethodFactory.TYPE_FEEDFORWARD.
            // You could also use other model types, such as:
            // MLMethodFactory.TYPE_SVM: Support Vector Machine (SVM)
            // MLMethodFactory.TYPE_RBFNETWORK: RBF Neural Network
            // MLMethodFactory.TYPE_NEAT: NEAT Neural Network
            // MLMethodFactory.TYPE_PNN: Probabilistic Neural Network
            EncogModel model = new EncogModel(data);
            model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD);

            // Send any output to the console.
            model.setReport(new ConsoleStatusReportable());

            // Now normalize the data. Encog will automatically determine the
            // correct normalization type based on the model you chose in the
            // last step.
            data.normalize();

            // Hold back some data for a final validation.
            // Shuffle the data into a random ordering.
            // Use a seed of 1001 so that we always use the same holdback and
            // will get more consistent results.
            model.holdBackValidation(0.3, true, 1001);

            // Choose whatever is the default training type for this model.
            model.selectTrainingType(data);

            // Use a 5-fold cross-validated train. Return the best method found.
            MLRegression bestMethod = (MLRegression) model.crossvalidate(5, true);

            // Display the training and validation errors.
            System.out.println("Training error: "
                    + model.calculateError(bestMethod, model.getTrainingDataset()));
            System.out.println("Validation error: "
                    + model.calculateError(bestMethod, model.getValidationDataset()));

            // Display our normalization parameters.
            NormalizationHelper helper = data.getNormHelper();
            System.out.println(helper.toString());

            // Display the final model.
            System.out.println("Final model: " + bestMethod);

            // Loop over the entire, original, dataset and feed it through the
            // model. This also shows how you would process new data, that was
            // not part of your training set. You do not need to retrain,
            // simply use the NormalizationHelper class. After you train, you
            // can save the NormalizationHelper to later normalize and
            // denormalize your data.
            printPredictions(filename, format, helper, bestMethod);
        } catch (Exception ex) {
            ex.printStackTrace();
        } finally {
            // Delete data file and shut down. Done in finally so that a
            // failure above does not leave the downloaded file behind or
            // skip the Encog shutdown.
            if (filename != null && filename.exists() && !filename.delete()) {
                System.err.println("Could not delete data file: " + filename);
            }
            Encog.getInstance().shutdown();
        }
    }

    /**
     * Reads every row of the data file, feeds the input columns through the
     * trained model and prints the predicted and the actual mpg value.
     *
     * @param dataFile the downloaded data file
     * @param format   the CSV format the file is written in
     * @param helper   normalization helper used to normalize the inputs and
     *                 denormalize the output
     * @param method   the trained regression method
     */
    private void printPredictions(File dataFile, CSVFormat format,
            NormalizationHelper helper, MLRegression method) {
        ReadCSV csv = new ReadCSV(dataFile, false, format);
        try {
            String[] line = new String[INPUT_COLUMN_COUNT];
            MLData input = helper.allocateInputVector();

            while (csv.next()) {
                // Source column 0 is the value being predicted; columns
                // 1..7 are the model inputs.
                for (int i = 0; i < line.length; i++) {
                    line[i] = csv.get(i + 1);
                }
                String correct = csv.get(0);

                helper.normalizeInputVector(line, input.getData(), false);
                MLData output = method.compute(input);
                String predictedMPG = helper.denormalizeOutputVectorToString(output)[0];

                System.out.println(Arrays.toString(line)
                        + " -> predicted: " + predictedMPG
                        + "(correct: " + correct + ")");
            }
        } finally {
            // Release the reader so the caller can delete the file.
            csv.close();
        }
    }

    /**
     * Program entry point.
     *
     * @param args optional; the first argument is the directory to download
     *             the data file into
     */
    public static void main(String[] args) {
        AutoMPGRegression prg = new AutoMPGRegression();
        prg.run(args);
    }
}