/* * Copyright 2004-2010 Information & Software Engineering Group (188/1) * Institute of Software Technology and Interactive Systems * Vienna University of Technology, Austria * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.ifs.tuwien.ac.at/dm/somtoolbox/license.html * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package at.tuwien.ifs.somtoolbox.layers; import java.util.ArrayList; import java.util.logging.Logger; import cern.colt.matrix.DoubleMatrix1D; import at.tuwien.ifs.somtoolbox.SOMToolboxException; import at.tuwien.ifs.somtoolbox.data.InputDatum; import at.tuwien.ifs.somtoolbox.layers.metrics.MetricException; import at.tuwien.ifs.somtoolbox.util.StdErrProgressWriter; import at.tuwien.ifs.somtoolbox.util.comparables.UnitDistance; /** * A mnemonic layer is a rectangular layer that might not have all grid positions occupied by units. * * @version $Id: MnemonicGrowingLayer.java 3590 2010-05-21 10:43:45Z mayer $ * @author Rudolf Mayer */ public class MnemonicGrowingLayer extends GrowingLayer { Integer[][][][][][] distanceMatrix_; private int unitCount; public void countDistances(int distanceFromStart, Unit startUnit, Unit currentUnit) { int xpos = currentUnit.getXPos(); int ypos = currentUnit.getYPos(); int zpos = currentUnit.getZPos(); // the unit has not been reached yet or on a longer path if (distanceMatrix_[startUnit.getXPos()][startUnit.getYPos()][startUnit.getZPos()][xpos][ypos][zpos] == null || distanceMatrix_[startUnit.getXPos()][startUnit.getYPos()][startUnit.getZPos()][xpos][ypos][zpos].intValue() > distanceFromStart) { distanceMatrix_[startUnit.getXPos()][startUnit.getYPos()][startUnit.getZPos()][xpos][ypos][zpos] = new Integer( distanceFromStart); // now we check all neighbours. if (xpos > 0 && units[xpos - 1][ypos][zpos] != null) { countDistances(distanceFromStart + 1, startUnit, units[xpos - 1][ypos][zpos]); } if (xpos + 1 < units.length && units[xpos + 1][ypos][zpos] != null) { countDistances(distanceFromStart + 1, startUnit, units[xpos + 1][ypos][zpos]); } if (ypos > 0 && units[xpos][ypos - 1][zpos] != null) { countDistances(distanceFromStart + 1, startUnit, units[xpos][ypos - 1][zpos]); } if (ypos + 1 < units[0].length && units[xpos][ypos + 1][zpos] != null) { countDistances(distanceFromStart + 1, startUnit, units[xpos][ypos + 1][zpos]); } if (zpos > 0 && units[xpos][ypos][zpos - 1] != null) { countDistances(distanceFromStart + 1, startUnit, units[xpos][ypos][zpos - 1]); } if (zpos + 1 < units[0][0].length && units[xpos][ypos][zpos + 1] != null) { countDistances(distanceFromStart + 1, startUnit, units[xpos][ypos][zpos + 1]); } } } public void initDistances() { distanceMatrix_ = new Integer[units.length][units[0].length][units[0][0].length][units.length][units[0].length][units[0][0].length]; Logger.getLogger("at.tuwien.ifs.somtoolbox").info("Calculating unit distances"); int totalUnitNum = units.length * units[0].length * units[0][0].length; StdErrProgressWriter progressWriter = new StdErrProgressWriter(totalUnitNum, "Restoring state of unit ", 10); int currentUnitNum = 0; for (int col = 0; col < units.length; col++) { for (int row = 0; row < units[0].length; row++) { for (int slice = 0; slice < units[0][0].length; slice++) { currentUnitNum++; if (units[col][row][slice] != null) { progressWriter.progress("Calculating distance of unit " + col + "/" + row + "/" + slice + ", ", (currentUnitNum + 1)); countDistances(0, units[col][row][slice], units[col][row][slice]); } else { progressWriter.progress("Skipping empty unit " + col + "/" + row + ", ", (currentUnitNum + 1)); } } } } } public MnemonicGrowingLayer(int id, Unit su, int x, int y, String metricName, int d, double[][][] vectors, long seed) throws SOMToolboxException { this(id, su, x, y, 0, metricName, d, GrowingLayer.addDimension(x, y, vectors), seed); } public MnemonicGrowingLayer(int id, Unit su, int x, int y, int z, String metricName, int d, double[][][][] vectors, long seed) throws SOMToolboxException { super(id, su, x, y, z, metricName, d, vectors, seed); initDistances(); unitCount = 0; for (int j = 0; j < ySize; j++) { for (int i = 0; i < xSize; i++) { if (units[i][j][0] != null) { unitCount++; } } } } @Override public Unit getWinner(InputDatum input) { Unit winner = null; double smallestDistance = Double.MAX_VALUE; // double[] inputVector = input.getVector().toArray(); for (int k = 0; k < zSize; k++) { for (int j = 0; j < ySize; j++) { for (int i = 0; i < xSize; i++) { if (units[i][j][k] != null) { double distance = 0; try { distance = metric.distance(units[i][j][k].getWeightVector(), input); } catch (MetricException e) { Logger.getLogger("at.tuwien.ifs.somtoolbox").severe(e.getMessage()); System.exit(-1); } if (distance < smallestDistance) { smallestDistance = distance; winner = units[i][j][k]; } } } } } return winner; } @Override public void clearMappedInput() { for (int k = 0; k < zSize; k++) { for (int j = 0; j < ySize; j++) { for (int i = 0; i < xSize; i++) { if (units[i][j][k] != null) { units[i][j][k].clearMappedInput(); if (units[i][j][k].getMappedSOM() != null) { units[i][j][k].getMappedSOM().getLayer().clearMappedInput(); } } } } } } @Override public double getMapDistance(int x1, int y1, int x2, int y2) { return getMapDistance(x1, y1, 0, x2, y2, 0); } @Override public double getMapDistance(int x1, int y1, int z1, int x2, int y2, int z2) { return distanceMatrix_[x1][y1][z1][x2][y2][z2].doubleValue(); } @Override protected void updateUnitsNormal(Unit winner, InputDatum input, double learnrate, double sigma) { double unitDist, hci = 0; double opt1 = 2 * sigma * sigma; double[] unitVector = null; double[] inputVector = input.getVector().toArray(); for (int k = 0; k < zSize; k++) { for (int j = 0; j < ySize; j++) { for (int i = 0; i < xSize; i++) { if (units[i][j] != null) { // Sparse city block distance unitDist = getMapDistance(winner, units[i][j][k]); // hci = learnrate * Math.exp((-1*Math.pow((unitDist/opt1),2))); hci = learnrate * Math.exp(-1 * unitDist * unitDist / opt1); unitVector = units[i][j][k].getWeightVector(); for (int ve = 0; ve < dim; ve++) { unitVector[ve] = unitVector[ve] + hci * (inputVector[ve] - unitVector[ve]); } } } } } } @Override public Unit[] getAllUnits() { ArrayList<Unit> tempUnits = new ArrayList<Unit>(xSize * ySize * zSize / 2); for (int k = 0; k < zSize; k++) { for (int j = 0; j < ySize; j++) { for (int i = 0; i < xSize; i++) { if (units[i][j][k] != null) { tempUnits.add(units[i][j][k]); } } } } return tempUnits.toArray(new Unit[tempUnits.size()]); } @Override public UnitDistance[] getWinnersAndDistances(InputDatum input, int num) { int maxNum = 0; for (int k = 0; k < zSize; k++) { for (int j = 0; j < ySize; j++) { for (int i = 0; i < xSize; i++) { if (units[i][j][k] != null) { maxNum++; } } } } if (num > maxNum) { num = maxNum; } UnitDistance[] res = new UnitDistance[num]; DoubleMatrix1D vec = input.getVector(); for (int k = 0; k < zSize; k++) { for (int j = 0; j < ySize; j++) { for (int i = 0; i < xSize; i++) { if (units[i][j][k] != null) { double distance = 0; try { distance = metric.distance(units[i][j][k].getWeightVector(), vec); } catch (MetricException e) { Logger.getLogger("at.tuwien.ifs.somtoolbox").severe(e.getMessage()); System.exit(-1); } int element = 0; boolean inserted = false; while (inserted == false && element < num) { if (res[element] == null || distance < res[element].getDistance()) { // found place to // insert unit for (int m = num - 2; m >= element; m--) { // move units with greater distance to // right res[m + 1] = res[m]; } res[element] = new UnitDistance(units[i][j][k], distance); inserted = true; } element++; } } } } } return res; } @Override public int getUnitCount() { return unitCount; } }