/*
* Copyright 2004-2010 Information & Software Engineering Group (188/1)
* Institute of Software Technology and Interactive Systems
* Vienna University of Technology, Austria
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.ifs.tuwien.ac.at/dm/somtoolbox/license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package at.tuwien.ifs.somtoolbox.layers.quality;
import java.util.logging.Logger;
import at.tuwien.ifs.somtoolbox.data.InputData;
import at.tuwien.ifs.somtoolbox.data.InputDatum;
import at.tuwien.ifs.somtoolbox.layers.Layer;
import at.tuwien.ifs.somtoolbox.layers.LayerAccessException;
import at.tuwien.ifs.somtoolbox.layers.Unit;
import at.tuwien.ifs.somtoolbox.layers.metrics.MetricException;
/**
 * Implementation of SOM Distortion Measure Quality.
 * <p>
 * For every input vector the squared Euclidean distance to the weight vector of every unit is computed and the unit
 * with the smallest distance is taken as best-matching unit (BMU). Each unit's total distortion is then increased by
 * its squared distance, weighted by {@link #neighborhoodFunction(double)} applied to the map distance between the BMU
 * and that unit.
 * <p>
 * Provided measures:
 * <ul>
 * <li>map quality {@code "distortion"}: sum of all unit totals divided by the number of input vectors (this is
 * {@code NaN} when the data set contains no vectors, because the sum 0.0 is divided by 0),</li>
 * <li>unit quality {@code "unitTotal"}: the accumulated, neighbourhood-weighted distortion per unit,</li>
 * <li>unit quality {@code "unitAverage"}: {@code unitTotal} divided by the number of input vectors for which the unit
 * was the BMU; 0 for units that never were a BMU.</li>
 * </ul>
 *
 * @author Michael Dittenbach
 * @version $Id: SOMDistortion.java 3883 2010-11-02 17:13:23Z frank $
 */
public class SOMDistortion extends AbstractQualityMeasure {

    /** Name of the single map-level quality measure. */
    private static final String MAP_DISTORTION = "distortion";

    /** Name of the per-unit average distortion measure. */
    private static final String UNIT_AVERAGE = "unitAverage";

    /** Name of the per-unit total distortion measure. */
    private static final String UNIT_TOTAL = "unitTotal";

    /** Name of the logger used for reporting failures during the computation. */
    private static final String LOGGER_NAME = "at.tuwien.ifs.somtoolbox";

    /**
     * Denominator of the exponent in {@link #neighborhoodFunction(double)}, i.e. the {@code 2*sigma^2} term of
     * {@code e^-(d^2 / (2*sigma^2))}.
     * <p>
     * NOTE(review): an earlier, disabled variant of the function used {@code 2*0.01*0.01} (= 0.0002); the active value
     * 0.002 is kept unchanged here - confirm which one is intended.
     */
    private static final double NEIGHBORHOOD_DENOMINATOR = 0.002;

    /** Map-level distortion: sum of all unit totals divided by the number of input vectors. */
    private double distortion;

    /** Number of input vectors for which the unit at {@code [x][y]} was the BMU. */
    private final double[][] hits;

    /** Per-unit total distortion divided by the unit's hit count (0 where the unit has no hits). */
    private final double[][] unitAverage;

    /** Per-unit accumulated, neighbourhood-weighted squared distance over all input vectors. */
    private final double[][] unitTotal;

    /**
     * Computes all distortion values for the given layer and data; the results are afterwards available through
     * {@link #getMapQuality(String)} and {@link #getUnitQualities(String)}.
     * <p>
     * A {@link MetricException} (weight vector and input vector of different length) is logged and terminates the JVM
     * via {@code System.exit(-1)}. A {@link LayerAccessException} is logged and aborts the computation, leaving the
     * values accumulated so far in place.
     *
     * @param layer the map layer whose units are evaluated
     * @param data the input data the distortion is computed for
     */
    public SOMDistortion(Layer layer, InputData data) {
        super(layer, data);

        mapQualityNames = new String[] { MAP_DISTORTION };
        mapQualityDescriptions = new String[] { "SOM Distortion" };
        unitQualityNames = new String[] { UNIT_AVERAGE, UNIT_TOTAL };
        unitQualityDescriptions = new String[] { "Average Unit Distortion", "Total Unit Distortion" };

        int xSize = layer.getXSize();
        int ySize = layer.getYSize();

        // Freshly allocated double arrays are already filled with 0.0, no explicit reset is needed.
        unitTotal = new double[xSize][ySize];
        unitAverage = new double[xSize][ySize];
        hits = new double[xSize][ySize];
        distortion = 0;

        // Scratch buffer holding the squared distances of the current input vector to every unit.
        double[][] dist2 = new double[xSize][ySize];

        try {
            for (int d = 0; d < data.numVectors(); d++) {
                InputDatum datum = data.getInputDatum(d);
                // Convert the input vector once per datum instead of once per unit.
                double[] inputVector = datum.getVector().toArray();

                double minDist2 = Double.MAX_VALUE;
                int bmuX = -1;
                int bmuY = -1;

                // calculate squared distances and remember BMU (first unit wins on ties because of the strict '<')
                for (int j = 0; j < ySize; j++) {
                    for (int i = 0; i < xSize; i++) {
                        double current = squaredDistance(inputVector, layer.getUnit(i, j).getWeightVector());
                        dist2[i][j] = current;
                        if (current < minDist2) {
                            minDist2 = current;
                            bmuX = i;
                            bmuY = j;
                        }
                    }
                }

                // calculate total distortion of units
                Unit bmu = layer.getUnit(bmuX, bmuY);
                for (int j = 0; j < ySize; j++) {
                    for (int i = 0; i < xSize; i++) {
                        unitTotal[i][j] += dist2[i][j]
                                * neighborhoodFunction(layer.getMapDistance(bmu, layer.getUnit(i, j)));
                    }
                }

                // increase hit variable for BMU
                hits[bmuX][bmuY]++;
            }

            // calculate average distortion per unit, sum up total distortions
            for (int j = 0; j < ySize; j++) {
                for (int i = 0; i < xSize; i++) {
                    distortion += unitTotal[i][j];
                    if (hits[i][j] > 0) {
                        unitAverage[i][j] = unitTotal[i][j] / hits[i][j];
                    }
                }
            }

            // average distortion measure
            distortion = distortion / data.numVectors();
        } catch (MetricException me) {
            Logger.getLogger(LOGGER_NAME).severe(me.getMessage());
            System.exit(-1);
        } catch (LayerAccessException lae) {
            // Not expected for in-range coordinates, but must not vanish silently: the values computed so far are
            // incomplete when this happens (e.g. when no BMU could be determined for an input vector).
            Logger.getLogger(LOGGER_NAME).severe(
                    "Could not access a map unit while computing the SOM distortion; results are incomplete: "
                            + lae.getMessage());
        }
    }

    /**
     * Returns the map-level quality value with the given name.
     *
     * @param name name of the requested measure; only {@code "distortion"} is known
     * @return the distortion value computed in the constructor
     * @throws QualityMeasureNotFoundException if {@code name} is {@code null} or not a known measure name
     */
    @Override
    public double getMapQuality(String name) throws QualityMeasureNotFoundException {
        // Constant-first comparison so that a null name is reported as "not found" instead of a NullPointerException.
        if (MAP_DISTORTION.equals(name)) {
            return distortion;
        }
        throw new QualityMeasureNotFoundException("Quality measure with name " + name + " not found.");
    }

    /**
     * Returns the per-unit quality values with the given name, indexed as {@code [x][y]}.
     * <p>
     * The returned array is the internal one, not a copy.
     *
     * @param name name of the requested measure; {@code "unitTotal"} or {@code "unitAverage"}
     * @return the matching per-unit values
     * @throws QualityMeasureNotFoundException if {@code name} is {@code null} or not a known measure name
     */
    @Override
    public double[][] getUnitQualities(String name) throws QualityMeasureNotFoundException {
        if (UNIT_TOTAL.equals(name)) {
            return unitTotal;
        }
        if (UNIT_AVERAGE.equals(name)) {
            return unitAverage;
        }
        throw new QualityMeasureNotFoundException("Quality measure with name " + name + " not found.");
    }

    /**
     * Weight applied to a unit's squared distance, depending on its map distance to the BMU:
     * {@code e^-(dist^2 / NEIGHBORHOOD_DENOMINATOR)}.
     *
     * @param dist map distance between the BMU and the unit
     * @return the weight; 1 for a distance of 0
     */
    private static double neighborhoodFunction(double dist) {
        return Math.exp(-1 * dist * dist / NEIGHBORHOOD_DENOMINATOR);
    }

    /**
     * Computes the squared Euclidean distance between two vectors.
     *
     * @param vector1 the input vector
     * @param vector2 the weight vector of a unit
     * @return the sum of the squared component-wise differences
     * @throws MetricException if the two vectors differ in length
     */
    private static double squaredDistance(double[] vector1, double[] vector2) throws MetricException {
        if (vector1.length != vector2.length) {
            throw new MetricException(
                    "Oops ... tried to calculate distance between two vectors with different dimensionalities.");
        }
        double dist = 0;
        for (int ve = 0; ve < vector1.length; ve++) {
            double diff = vector1[ve] - vector2[ve];
            dist += diff * diff;
        }
        return dist;
    }
}