/***********************************************************************
This file is part of KEEL-software, the Data Mining tool for regression,
classification, clustering, pattern mining and so on.
Copyright (C) 2004-2010
F. Herrera (herrera@decsai.ugr.es)
L. S�nchez (luciano@uniovi.es)
J. Alcal�-Fdez (jalcala@decsai.ugr.es)
S. Garc�a (sglopez@ujaen.es)
A. Fern�ndez (alberto.fernandez@ujaen.es)
J. Luengo (julianlm@decsai.ugr.es)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see http://www.gnu.org/licenses/
**********************************************************************/
package keel.Algorithms.Neural_Networks.gmdh;
import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import keel.Dataset.Attributes;
/**
* <p>
* Class for the algorithm sonn
* </p>
* @author Written by Nicolas Garcia Pedrajas (University of Cordoba) 27/02/2007
* @version 0.1
* @since JDK1.5
*/
public class sonn {
double T;
node nodes[];
int n_nodes;
int output; // Output node. It is the node with the smallest SEC
/**
* <p>
* Constructor
* </p>
* @param global Global parameters
* @param data Input data
*/
public sonn(SetupParameters global, Data data) {
double old_sec, new_sec;
nodes = new node[global.max_nodes];
// Initialize T and S
n_nodes = 0;
T = global.To;
// Generate input nodes
for (int i = 0; i < global.Ninputs; i++) {
NewBasicNode();
}
do {
int new_nodes = 0;
do {
old_sec = StateEnergy(global);
// Sj = generate (Si)
NewRandomNode(global, data);
new_sec = StateEnergy(global);
if (Accept(global, new_sec, old_sec) == false) {
// Delete last node
n_nodes--;
}
else {
new_nodes++;
nodes[nodes[n_nodes - 1].terminal[0]].front_node = false;
nodes[nodes[n_nodes - 1].terminal[1]].front_node = false;
}
}
while (new_nodes <= global.omega && n_nodes < global.max_nodes);
// Number of new nodes exceds limit
T *= global.alpha; // decrease temperature T
}
while (T > global.Tend && n_nodes < global.max_nodes); // Temperature is below Tend
}
/**
* <p>
* Adds a new node
* </p>
*/
public void NewBasicNode() {
node new_node = new node();
new_node.basic_node = true;
new_node.terminal[0] = n_nodes;
new_node.SEC = 1.0e20;
new_node.k = 1.0;
nodes[n_nodes] = new_node;
n_nodes++;
}
/**
* Acept or not a new SEC
* @param global Global parameters
* @param sec_new New Sec value
* @param sec_old Old Sec value
* @return boolean indicating if the sec has been accepted
*/
private boolean Accept(Parameters global, double sec_new, double sec_old) {
if (sec_new < sec_old) {
return true;
}
else {
double p = Math.exp( - (sec_new - sec_old) / T);
if (Genesis.frandom(0.0, 1.0) < p) {
return true;
}
else {
return false;
}
}
}
/**
* <p>
* Returns the minimum SEC of all the front nodes of the state
* </p>
* @param global Global parameters
*/
private double StateEnergy(SetupParameters global) {
double min_sec = 1.0e20;
for (int i = global.Ninputs; i < n_nodes; i++) {
if (nodes[i].SEC < min_sec && nodes[i].front_node) {
min_sec = nodes[i].SEC;
output = i;
}
}
return min_sec;
}
/**
* <p>
* Creates a new random node
* </p>
* @param global Global parameters
* @param data Input data
*/
public void NewRandomNode(SetupParameters global, Data data) {
node next, best;
double sec, covar[][], alpha[][], chisq[], alamda[], x[][], y[],
ochisq[], atry[], beta[], da[], oneda[][], sd[], min, old_chisq;
int mfit, sw[];
boolean sing;
// Create new node
next = new node();
best = new node();
next.front_node = true;
ochisq = new double[1];
chisq = new double[1];
alamda = new double[1];
best = new node();
sw = new int[node.TERMS];
covar = new double[node.TERMS][node.TERMS];
alpha = new double[node.TERMS][node.TERMS];
x = new double[global.n_train_patterns][2];
y = new double[global.n_train_patterns];
atry = new double[node.TERMS];
beta = new double[node.TERMS];
da = new double[node.TERMS];
oneda = new double[node.TERMS][1];
sd = new double[global.n_train_patterns];
for (int i = 0; i < global.n_train_patterns; i++) {
sd[i] = 1.0;
}
/////// Creation of inputs
// Test if inputs are constant
double variance[] = new double[2];
double mean[] = new double[2];
double sosq[] = new double[2];
do {
mean[0] = mean[1] = sosq[0] = sosq[1] = 0.0;
// Randomly select two terminals
int one = Genesis.irandom( 0, n_nodes);
int two = Genesis.irandom( 0, n_nodes);
// If the two terminals form another front node destroy the old node
// Not implemented yet
next.terminal[0] = one;
next.terminal[1] = two;
// Set desired output and x1 and x2 inputs
for (int i = 0; i < global.n_train_patterns; i++) {
y[i] = data.train[i][global.Ninputs];
x[i][0] = nodes[one].NodeOutput(data.train[i], nodes);
x[i][1] = nodes[two].NodeOutput(data.train[i], nodes);
mean[0] += x[i][0];
mean[1] += x[i][1];
sosq[0] += x[i][0] * x[i][0];
sosq[1] += x[i][1] * x[i][1];
}
mean[0] /= global.n_train_patterns;
mean[1] /= global.n_train_patterns;
variance[0] = sosq[0] / global.n_train_patterns -
mean[0] * mean[0];
variance[1] = sosq[1] / global.n_train_patterns -
mean[1] * mean[1];
}
while (variance[0] < 0.00001 || variance[1] < 0.00001);
// For each prototype surface in F
min = 1e20;
for (int i = 1; i <= 4; i++) {
// next.type = i;
for (int j = 0; j < node.TERMS; j++) {
next.a[j] = Genesis.frandom(-global.aRange, global.aRange);
}
// Fit the surface
switch (i) {
case 1:
// Surfaces y = a0 + a1x1 + a2x2
sw[0] = sw[1] = sw[2] = 1;
sw[3] = sw[4] = sw[5] = 0;
next.a[3] = next.a[4] = next.a[5] = 0.0;
mfit = 3;
break;
case 2:
// Surface y = a0 + a1x1 + a2x2 + a3x1x2
sw[0] = sw[1] = sw[2] = sw[3] = 1;
sw[4] = sw[5] = 0;
next.a[4] = next.a[5] = 0.0;
mfit = 4;
break;
case 3:
// Surface y = a0 + a1x1 + a2x1**2
sw[0] = sw[1] = sw[4] = 1;
sw[2] = sw[3] = sw[5] = 0;
next.a[2] = next.a[3] = next.a[5] = 0.0;
mfit = 3;
break;
case 4:
// Surface y = a0 + a1x1 + a2x2 + a3x1x2 + a4x1**2 + a5x2**2
sw[0] = sw[1] = sw[2] = sw[3] = sw[4] = sw[5] = 1;
mfit = 6;
break;
default:
mfit = 0;
break;
}
// Levenberg - Marquardt algorithm
alamda[0] = -1.0;
ochisq[0] = chisq[0] = 1.0e20;
int ileven = 10; // Maximum number of Levenberg iterations.
do { // Repeat till convergence
old_chisq = chisq[0];
sing = LM.mrqmin(x, y, sd, global.n_train_patterns, next.a, sw,
node.TERMS,
covar,
alpha, chisq, alamda, mfit, ochisq, atry, beta, da,
oneda, global);
if (sing)
ileven--;
}
while ( (old_chisq - chisq[0]) > global.LM_convergence && ileven > 0);
/*alamda[0] = 0.0;
LM.mrqmin(x, y, sd, global.n_train_patterns, next.a, sw, node.TERMS,
covar,
alpha, chisq, alamda, mfit, ochisq, atry, beta, da, oneda);*/
sec = next.StructureEstimationCriterion(nodes, data, global);
// Choose the surface with smallest SEC
if (min > sec) {
min = sec;
next.CopyTo(best);
}
}
/* Test if all the coefficients are 0
boolean all_0 = true;
for (int i = 0; i < node.TERMS; i++) {
if (best.a[i] != 0.0) {
all_0 = false;
}
}*/
// Construct the node using the prototype surface chosen if the SEC of the node
// is smaller than the SEC of parents
if (best.SEC < nodes[best.terminal[0]].SEC &&
best.SEC < nodes[best.terminal[1]].SEC /*&& !all_0 &&
best.SEC < 1e20*/) {
// Add new node
nodes[n_nodes] = best;
n_nodes++;
if (global.verbose) {
System.err.println("Added node " + n_nodes);
}
}
}
/**
* <p>
* Saves the network to a file, including the seed
* </p>
* @param file_name The name of the file
* @param seed Random seed
* @param append Boolean for appending or replacing the file
* @throws IOException
*/
public void SaveNetwork(String file_name, long seed, boolean append) throws IOException {
String line;
try {
// Result file
FileOutputStream file = new FileOutputStream(file_name, append);
BufferedWriter f = new BufferedWriter(new OutputStreamWriter(file));
f.write("Random seed: " + seed);
f.newLine();
// For all nodes
for (int i = 0; i < n_nodes; i++) {
if (nodes[i].basic_node) {
f.write("y(" + Integer.toString(i) + ") = x(" + Integer.toString(i) +
")");
}
else {
f.write("y(" + Integer.toString(i) + ") = ");
if (nodes[i].a[0] != 0.0) {
f.write(new PrintfFormat("%6.4g ").sprintf(nodes[i].a[0]));
}
if (nodes[i].a[1] != 0.0) {
f.write(new PrintfFormat("%+6.4g").sprintf(nodes[i].a[1]));
f.write(" y(" +
Integer.toString(nodes[i].terminal[0]) + ") ");
}
if (nodes[i].a[2] != 0.0) {
f.write(new PrintfFormat("%+6.4g").sprintf(nodes[i].a[2]));
f.write(" y(" +
Integer.toString(nodes[i].terminal[1]) + ") ");
}
if (nodes[i].a[3] != 0.0) {
f.write(new PrintfFormat("%+6.4g").sprintf(nodes[i].a[3]));
f.write(" y(" +
Integer.toString(nodes[i].terminal[0]) + ") y(" +
Integer.toString(nodes[i].terminal[1]) + ") ");
}
if (nodes[i].a[4] != 0.0) {
f.write(new PrintfFormat("%+6.4g").sprintf(nodes[i].a[4]));
f.write(" y(" +
Integer.toString(nodes[i].terminal[0]) + ")^2");
}
if (nodes[i].a[5] != 0.0) {
f.write(new PrintfFormat("%+6.4g").sprintf(nodes[i].a[5]));
f.write(" y(" +
Integer.toString(nodes[i].terminal[1]) + ")^2");
}
}
f.newLine();
}
f.write("Output node: y(" + output + ")\n");
f.close();
file.close();
}
catch (FileNotFoundException e) {
System.err.println("Cannot created output file");
}
}
/**
* <p>
* Obtains the fitness of the sonn
* </p>
* @param input Input value
* @return fitness value
*/
public double GenerateOutput(double input[]) {
return nodes[output].NodeOutput(input, nodes);
}
public double TestSONNInRegression(SetupParameters global, double data[][],
int npatterns) {
double fitness, RMS = 0.0, error, out;
for (int i = 0; i < npatterns; i++) {
// Obtain network output
out = GenerateOutput(data[i]);
// Obtain RMS error
error = Math.pow(out - data[i][global.Ninputs], 2.0);
RMS += Math.sqrt(error);
}
fitness = RMS / (npatterns * global.Noutputs);
return fitness;
}
/**
* <p>
* Obtains fitness for a classification problem
* </p>
* @param global Global parameters
* @param data Input data
* @param npatterns Number of patterns
* @return fitness
*/
public double TestSONNInClassification(SetupParameters global, double data[][],
int npatterns) {
double ok = 0.0;
double fitness, out;
for (int i = 0; i < npatterns; i++) {
// Obtain network output
out = GenerateOutput(data[i]);
if ( (data[i][global.Ninputs] == 0 && out < 0.5) ||
(data[i][global.Ninputs] == 1 && out > 0.5)) {
ok++;
}
}
fitness = ok / npatterns;
return fitness;
}
/**
* <p>
* Saves the output to a file
* </p<
* @param file_name Name of the file
* @param data Input data
* @param n Number of patterns
* @param global Global parameters
* @throws IOException
*/
public void SaveOutputFile(String file_name, double data[][], int n,
SetupParameters global) throws
IOException {
String line;
double out;
try {
// Result file
FileOutputStream file = new FileOutputStream(file_name);
BufferedWriter f = new BufferedWriter(new OutputStreamWriter(file));
// File header
f.write("@relation "+Attributes.getRelationName()+"\n");
f.write(Attributes.getInputAttributesHeader());
f.write(Attributes.getOutputAttributesHeader());
f.write(Attributes.getInputHeader()+"\n");
f.write(Attributes.getOutputHeader()+"\n");
f.write("@data\n");
// For all patterns
for (int i = 0; i < n; i++) {
// Classification
if (global.problem.compareToIgnoreCase("Classification") == 0) {
// Obtain network output
out = GenerateOutput(data[i]);
/* Original output using numbers (Deprecated)
f.write(Integer.toString( (int) data[i][global.Ninputs]) + " ");
if (out < 0.5) {
f.write(Integer.toString(0));
}
else {
f.write(Integer.toString(1));
}
*/
// Current output using nominal values.
f.write(Attributes.getOutputAttributes()[0].getNominalValue((int) data[i][global.Ninputs]) + " ");
if (out < 0.5) {
f.write(Attributes.getOutputAttributes()[0].getNominalValue(0));
}
else {
f.write(Attributes.getOutputAttributes()[0].getNominalValue(1));
}
}
// Regression
else {
f.write(Double.toString(data[i][global.Ninputs]) + " ");
out = GenerateOutput(data[i]);
f.write(Double.toString(out));
}
f.newLine();
}
f.close();
file.close();
}
catch (FileNotFoundException e) {
System.err.println("Cannot created output file");
}
}
}