/*
* Encog(tm) Java Examples v3.4
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-examples
*
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.examples.proben;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.Locale;
import java.util.StringTokenizer;
import org.encog.Encog;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.util.Format;
import org.encog.util.csv.CSVFormat;
import org.encog.util.simple.EncogUtility;
public class ProBenData {
private File sourceFile;
private int boolIn = 0;
private int realIn = 0;
private int boolOut = 0;
private int realOut = 0;
private int trainingExamples = 0;
private int validationExamples = 0;
private int testExamples = 0;
private MLDataSet trainingDataSet;
private MLDataSet validationDataSet;
private MLDataSet testDataSet;
private boolean mergeTest;
public ProBenData(File file, boolean mergeTest) {
this.sourceFile = file;
this.mergeTest = mergeTest;
}
public static String obtainProbenPath(String[] args) {
if (args.length > 0) {
return args[0];
} else {
System.out
.println("To run this program, it is necessary to download the Proben1\n"
+ "datasets and pass their path as the first agrument to this\n"
+ "program. Proben1 can be downloaded\n"
+ "from: https://github.com/jeffheaton/proben1");
System.exit(1);
return null;
}
}
public void processHeaderLine(String line) {
// find the name/value
int index = line.indexOf('=');
String name = line.substring(0,index).trim().toLowerCase();
String value = line.substring(index+1).trim();
int valueInt = Integer.parseInt(value);
// fill in the correct value
if( name.equals("bool_in")) {
this.boolIn = valueInt;
} else if( name.equals("real_in")) {
this.realIn = valueInt;
} else if( name.equals("bool_out")) {
this.boolOut = valueInt;
} else if( name.equals("real_out")) {
this.realOut = valueInt;
} else if( name.equals("training_examples")) {
this.trainingExamples = valueInt;
} else if( name.equals("validation_examples")) {
this.validationExamples = valueInt;
} else if( name.equals("test_examples")) {
this.testExamples = valueInt;
} else {
throw new ProBenError("Unknown header element: " + name);
}
}
public int getInputCount() {
return boolIn+realIn;
}
public int getIdealCount() {
return boolOut+realOut;
}
public void processDataLine(String line) {
MLData inputData = new BasicMLData(getInputCount());
MLData idealData = new BasicMLData(getIdealCount());
StringTokenizer tok = new StringTokenizer(line, " ");
for(int i=0;i<inputData.size();i++) {
inputData.setData(i, Double.parseDouble(tok.nextToken()));
}
for(int i=0;i<idealData.size();i++) {
idealData.setData(i, Double.parseDouble(tok.nextToken()));
}
if( this.trainingDataSet.getRecordCount()<this.trainingExamples) {
this.trainingDataSet.add(inputData, idealData);
} else if( this.validationDataSet.getRecordCount()<this.validationExamples) {
this.validationDataSet.add(inputData, idealData);
} else if( this.testDataSet.getRecordCount()<this.testExamples) {
if( this.mergeTest ) {
this.trainingDataSet.add(inputData, idealData);
} else {
this.testDataSet.add(inputData, idealData);
}
}
}
public void load() {
this.trainingDataSet = new BasicMLDataSet();
this.validationDataSet = new BasicMLDataSet();
this.testDataSet = new BasicMLDataSet();
try {
BufferedReader in = new BufferedReader(new FileReader(
this.sourceFile));
String str;
while ((str = in.readLine()) != null) {
if (str.indexOf('=') != -1) {
processHeaderLine(str);
} else {
processDataLine(str);
}
}
in.close();
} catch (IOException ex) {
throw new ProBenError(ex);
}
}
public int getTrainingExamples() {
return trainingExamples;
}
public MLDataSet getTrainingDataSet() {
return trainingDataSet;
}
public MLDataSet getValidationDataSet() {
return validationDataSet;
}
public MLDataSet getTestDataSet() {
return testDataSet;
}
public String toString() {
StringBuilder result = new StringBuilder();
result.append( "bool_in = " + boolIn + "\n");
result.append( "real_in = " + realIn + "\n");
result.append( "bool_out = " + boolOut + "\n");
result.append( "real_out = " + realOut + "\n");
result.append( "training examples: " + Format.formatInteger((int)trainingDataSet.getRecordCount())+"\n");
result.append( "validation examples: " + Format.formatInteger((int)validationDataSet.getRecordCount())+"\n");
result.append( "test examples: " + Format.formatInteger((int)testDataSet.getRecordCount())+"\n");
return result.toString();
}
public String getName() {
return this.sourceFile.getName();
}
private void center(MLDataSet dataset, double inputCenter, double outputCenter) {
if( dataset.size()==0 ) {
return;
}
// Calculate MEAN
double[] mean = new double[this.getInputCount()+this.getIdealCount()];
int count = 0;
for(MLDataPair pair: dataset) {
count++;
int meanIndex = 0;
for(int i=0;i<getInputCount();i++) {
mean[meanIndex++]+=pair.getInput().getData(i);
}
for(int i=0;i<getIdealCount();i++) {
mean[meanIndex++]+=pair.getIdeal().getData(i);
}
}
for(int i=0;i<mean.length;i++) {
mean[i]/=count;
}
// Calculate the variance (on the way to standard deviation)
double[] sdev = new double[this.getInputCount()+this.getIdealCount()];
for(MLDataPair pair: dataset) {
int varIndex = 0;
for(int i=0;i<getInputCount();i++) {
sdev[varIndex]+=Math.pow(mean[varIndex]-pair.getInput().getData(i),2);
varIndex++;
}
for(int i=0;i<getIdealCount();i++) {
sdev[varIndex]+=Math.pow(mean[varIndex]-pair.getIdeal().getData(i),2);
varIndex++;
}
}
// Take square root of variance, and get standard deviation
for(int i=0;i<sdev.length;i++) {
sdev[i]=Math.sqrt(sdev[i]);
}
// Now use the zscore, centered at requested value
for(MLDataPair pair: dataset) {
int index = 0;
for(int i=0;i<getInputCount();i++) {
double zscore = 0;
// If zscore is undefined (zero variance) then just use zero so that this
// value is at the center.
if(sdev[i]>Encog.DEFAULT_DOUBLE_EQUAL) {
zscore = (pair.getInput().getData(i) - mean[index])/sdev[i];
}
pair.getInput().setData(i, inputCenter + zscore);
index++;
}
for(int i=0;i<getIdealCount();i++) {
double zscore = 0;
// If zscore is undefined (zero variance) then just use zero so that this
// value is at the center.
if(sdev[i]>Encog.DEFAULT_DOUBLE_EQUAL) {
zscore = (pair.getIdeal().getData(i) - mean[index])/sdev[i];
}
pair.getIdeal().setData(i, outputCenter + zscore);
index++;
}
}
}
public void center(double inputCenter, double outputCenter) {
center(this.getTrainingDataSet(),inputCenter,outputCenter);
center(this.getValidationDataSet(),inputCenter,outputCenter);
center(this.getTestDataSet(),inputCenter,outputCenter);
}
}