/*
* RapidMiner
*
* Copyright (C) 2001-2008 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.tools.math.som;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
/**
 * Trains a Kohonen net (self-organizing map) on the examples provided by a
 * {@link KohonenTrainingsData} source.
 * <p>
 * Life cycle: {@link #init(int, int[], boolean)} creates randomly initialised nodes,
 * {@link #train()} adapts them to the data, and afterwards {@link #apply(double[])} maps a
 * data vector to the net coordinates of its best fitting node. The current phase
 * (0 = created, 1 = initialised, 2 = trained) decides which operations have an effect.
 * <p>
 * Net coordinates wrap around in every dimension, i.e. neighbourhood lookups treat the net
 * as a torus.
 * <p>
 * Instances are not thread-safe: the state of the neighbourhood iteration is kept in fields.
 *
 * @author Sebastian Land
 * @version $Id: KohonenNet.java,v 1.7 2008/07/18 15:50:46 ingomierswa Exp $
 */
public class KohonenNet implements Serializable {

    private static final long serialVersionUID = -5445606750204819559L;

    private long randomSeed = 19091982;

    private int netDimension;

    private int[] netDimensions;

    /** 0 = created, 1 = initialised by init(), 2 = trained by train(). */
    private int phase;

    private int trainingSteps = 80;

    private ArrayList<ProgressListener> progressListener = new ArrayList<ProgressListener>();

    private KohonenNode[] nodes;

    private DistanceFunction distanceFunction;

    private AdaptationFunction adaptationFunction;

    private KohonenTrainingsData data;

    private Random randomGenerator = new Random(randomSeed);

    // State of the current neighbourhood ("hypercube") iteration, see cube() / cubeNext().
    private int cubeNodeCounter = 0;

    private int cubeSize = 0;

    private int[] cubeEdgeLengths;

    private int[] cubeOffset;

    /**
     * Creates a net that will be trained on the given data, using Euclidean distance and
     * Ritter adaptation by default.
     *
     * @param data the training examples
     */
    public KohonenNet(KohonenTrainingsData data) {
        this.distanceFunction = new EuclideanDistance();
        this.adaptationFunction = new RitterAdaptation();
        this.data = data;
    }

    /**
     * Creates the nodes of the net and fills their weights with random values in [0, 1).
     * Moves the net into phase 1 (initialised).
     *
     * @param dataDimension number of weights per node (dimension of the data vectors)
     * @param netDimensions size of the net in each of its dimensions; copied, so later
     *        changes to the caller's array do not affect the net
     * @param hexagonal requested hexagonal topology; currently ignored
     */
    public void init(int dataDimension, int[] netDimensions, boolean hexagonal) {
        // TODO: hexagonal topologies (2-dimensional nets only) are not supported yet.
        updateProgressListener(0);
        this.netDimensions = netDimensions.clone();
        this.netDimension = this.netDimensions.length;

        // the net holds one node per grid cell
        int nodeNumber = 1;
        for (int size : this.netDimensions) {
            nodeNumber *= size;
        }
        this.nodes = new KohonenNode[nodeNumber];

        for (int i = 0; i < nodes.length; i++) {
            // a fresh array per node, so nodes never share one weight vector
            double[] randomTuple = new double[dataDimension];
            for (int j = 0; j < dataDimension; j++) {
                randomTuple[j] = randomGenerator.nextDouble();
            }
            nodes[i] = new KohonenNode(randomTuple);
        }

        // initialisation was successful
        phase = 1;
        updateProgressListener(10);
    }

    /**
     * Trains an initialised net for the configured number of rounds. Each example adapts
     * all nodes in a neighbourhood centred on its best fitting node. After training, the
     * reference to the training data is dropped and the net moves into phase 2 (trained).
     * Does nothing unless the net is in phase 1.
     */
    public void train() {
        if (phase != 1) {
            return;
        }
        data.setRandomGenerator(this.randomGenerator);
        for (int step = 1; step <= this.trainingSteps; step++) {
            updateProgressListener(10 + (step - 1) * 80 / trainingSteps);
            data.reset();
            // The radius only depends on the step, so it is computed once per round
            // (assumes getAdaptationRadius is side-effect free).
            int radius = (int) Math.round(adaptationFunction.getAdaptationRadius(null, step, trainingSteps));
            // odd edge length: the cube spans -radius..+radius around the stimulus
            int edgeLength = 2 * radius + 1;
            for (int example = 0; example < data.countData(); example++) {
                double[] exampleWeights = data.getNext();
                // net coordinates of the best fitting node
                int fittingNode = getBestFittingNode(exampleWeights);
                int[] stimulusCoords = getCoordinatesOfIndex(fittingNode);
                // adapt every node in range to the stimulus
                cube(edgeLength, stimulusCoords);
                while (cubeHasNext()) {
                    int currentNode = cubeNext();
                    double currentDistance = distanceFunction.getDistance(stimulusCoords, getCoordinatesOfIndex(currentNode), netDimensions);
                    nodes[currentNode].setWeights(adaptationFunction.adapt(exampleWeights, nodes[currentNode].getWeights(), currentDistance, step,
                            trainingSteps));
                }
            }
        }
        // training has been successful: the data is not needed anymore
        data = null;
        phase = 2;
        updateProgressListener(90);
        informProgressExit();
    }

    /** Returns whether the current hypercube iteration has nodes left. */
    private boolean cubeHasNext() {
        return cubeNodeCounter < cubeSize;
    }

    /**
     * Returns the node index of the next node in the current hypercube, or -1 when the
     * iteration is exhausted.
     */
    private int cubeNext() {
        if (cubeNodeCounter >= cubeSize) {
            return -1;
        }
        // relative position of the node inside the hypercube
        int[] coords = getCoordinatesOfIndex(cubeNodeCounter, cubeEdgeLengths);
        for (int i = 0; i < netDimension; i++) {
            // shift the cube so it is centred on the offset, then make the coordinate absolute
            coords[i] += cubeOffset[i] - cubeEdgeLengths[i] / 2;
        }
        cubeNodeCounter++;
        return getIndexOfCoordinates(coords);
    }

    /**
     * Starts a new iteration over the hypercube of the given edge length centred on
     * {@code offset}. The edge length is clamped to the net size in every dimension, so
     * that the wrap-around never visits a node twice.
     */
    private void cube(int cubeEdgeLength, int[] offset) {
        cubeEdgeLengths = new int[netDimension];
        cubeSize = 1;
        for (int i = 0; i < netDimension; i++) {
            cubeEdgeLengths[i] = Math.max(0, Math.min(cubeEdgeLength, netDimensions[i]));
            cubeSize *= cubeEdgeLengths[i];
        }
        cubeOffset = offset;
        cubeNodeCounter = 0;
    }

    /**
     * Maps a data vector to the net coordinates of its best fitting node.
     *
     * @param data the data vector
     * @return the coordinates, or an empty array if the net has not been trained yet
     */
    public int[] apply(double[] data) {
        if (phase == 2) {
            int bestNode = getBestFittingNode(data);
            return getCoordinatesOfIndex(bestNode);
        }
        return new int[] {};
    }

    /**
     * Sets the seed of the random generator used for initialisation and training.
     * Only has an effect before {@link #init(int, int[], boolean)} was called.
     */
    public void setRandomSeed(long seed) {
        if (phase == 0) {
            randomSeed = seed;
            // re-seed the generator; it was already created with the default seed
            randomGenerator.setSeed(seed);
        }
    }

    /** Sets the distance function; only has an effect before initialisation. */
    public void setDistanceFunction(DistanceFunction function) {
        if (phase == 0) {
            this.distanceFunction = function;
        }
    }

    /** Sets the adaptation function; only has an effect before initialisation. */
    public void setAdaptationFunction(AdaptationFunction function) {
        if (phase == 0) {
            this.adaptationFunction = function;
        }
    }

    /** Sets the number of training rounds; values below 1 are raised to 1. */
    public void setTrainingRounds(int rounds) {
        this.trainingSteps = Math.max(rounds, 1);
    }

    /** Returns the distance between two points according to the current distance function. */
    public double getDistance(double[] point1, double[] point2) {
        return distanceFunction.getDistance(point1, point2);
    }

    /**
     * Returns the weights of the node at the given net coordinates. The coordinates wrap
     * around and are not modified.
     */
    public double[] getNodeWeights(int[] coords) {
        return nodes[getIndexOfCoordinates(coords)].getWeights();
    }

    /**
     * Sums the distances between the weights of the given node and the weights of every
     * node in the 3-wide hypercube around it. The node itself is part of that hypercube.
     */
    public double getNodeDistance(int nodeIndex) {
        cube(3, getCoordinatesOfIndex(nodeIndex));
        double distance = 0;
        while (cubeHasNext()) {
            distance += distanceFunction.getDistance(nodes[nodeIndex].getWeights(), nodes[cubeNext()].getWeights());
        }
        return distance;
    }

    /**
     * Returns the index of the node whose weights are closest to the given vector, or -1
     * if no node has a distance below positive infinity (e.g. no nodes, or only NaN
     * distances).
     */
    private int getBestFittingNode(double[] dataVector) {
        double bestDistance = Double.POSITIVE_INFINITY;
        int best = -1;
        for (int i = 0; i < nodes.length; i++) {
            double currentDistance = distanceFunction.getDistance(dataVector, nodes[i].getWeights());
            if (currentDistance < bestDistance) {
                best = i;
                bestDistance = currentDistance;
            }
        }
        return best;
    }

    /** Converts a linear index into coordinates; the first dimension varies fastest. */
    private int[] getCoordinatesOfIndex(int index, int[] dimensions) {
        int[] coordinate = new int[dimensions.length];
        for (int i = 0; i < dimensions.length; i++) {
            coordinate[i] = index % dimensions[i];
            index = index / dimensions[i];
        }
        return coordinate;
    }

    private int[] getCoordinatesOfIndex(int index) {
        return getCoordinatesOfIndex(index, netDimensions);
    }

    /**
     * Converts net coordinates into the linear node index. Out-of-range coordinates
     * (including negative ones) wrap around; the given array is not modified.
     */
    public int getIndexOfCoordinates(int[] coordinates) {
        return getIndexOfCoordinates(coordinates, netDimensions);
    }

    private int getIndexOfCoordinates(int[] coordinates, int[] dimensions) {
        int index = 0;
        for (int i = dimensions.length - 1; i >= 0; i--) {
            // mathematical modulo, correct for arbitrarily negative coordinates
            int wrapped = coordinates[i] % dimensions[i];
            if (wrapped < 0) {
                wrapped += dimensions[i];
            }
            index = index * dimensions[i] + wrapped;
        }
        return index;
    }

    /** Registers a listener that is informed about the training progress. */
    public void addProgressListener(ProgressListener listener) {
        progressListener.add(listener);
    }

    /** Reports the given progress value to all registered listeners. */
    public void updateProgressListener(int value) {
        for (ProgressListener listener : progressListener) {
            listener.setProgress(value);
        }
    }

    /** Tells all registered listeners that the progress has finished. */
    public void informProgressExit() {
        for (ProgressListener listener : progressListener) {
            listener.progressFinished();
        }
    }
}