/*
 * Copyright 2004-2010 Information & Software Engineering Group (188/1)
 *                     Institute of Software Technology and Interactive Systems
 *                     Vienna University of Technology, Austria
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.ifs.tuwien.ac.at/dm/somtoolbox/license.html
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package at.tuwien.ifs.somtoolbox.layers;

import java.util.ArrayList;
import java.util.Random;
import java.util.logging.Logger;

import at.tuwien.ifs.somtoolbox.SOMToolboxException;
import at.tuwien.ifs.somtoolbox.data.InputData;
import at.tuwien.ifs.somtoolbox.data.InputDatum;
import at.tuwien.ifs.somtoolbox.data.TemplateVector;
import at.tuwien.ifs.somtoolbox.layers.metrics.MetricException;
import at.tuwien.ifs.somtoolbox.models.GrowingSOM;
import at.tuwien.ifs.somtoolbox.util.StringUtils;
import at.tuwien.ifs.somtoolbox.util.VectorTools;

/**
 * Represents a unit on a map. It has a position in terms of x and y coordinates and an n-dimensional weight vector.
 * Data can be mapped onto a <code>Unit</code>. Labels can be assigned to a <code>Unit</code> to describe the mapped
 * data. A <code>Unit</code> can also have an assigned map for use in hierarchical models.<br>
 * TODO: The type of <code>mappedSOM</code> should be made more general than it is now.<br>
 * FIXME: similar to what should be done with {@link GrowingLayer}, we should make a Unit3D version out of this for 3D
 * SOMs, to keep the memory fingerprint small, and avoid overloading many methods with 2D and 3D params, which makes it
 * very hard to read.
 *
 * @author Michael Dittenbach
 * @version $Id: Unit.java 3621 2010-07-07 13:33:28Z mayer $
 */
public class Unit extends InputContainer {

    /**
     * Types of feature weighting modes
     *
     * @author Rudolf Mayer
     */
    public enum FeatureWeightMode {
        GLOBAL, // GLOBAL weighting as in Nürnberger & Detyniecki
        LOCAL, // LOCAL weighting as in Nürnberger & Detyniecki
        GENERAL, // GENERAL weighting as in Nürnberger & Detyniecki
    }

    public static final String KEYWORDS = "KEYWORDS";

    public static final String GATE = "GATE";

    public static final String CONTEXT = "CONTEXT";

    public static final String LABELSOM = "LabelSOM";

    private static final int INIT_RANDOM = 10;

    private static final int INIT_INTERVAL_INTERPOLATE = 20;

    private static final int INIT_VECTOR = 30;

    private static final int INIT_PCA = 40;

    private int dim = 0;

    // TODO: find a more generic way for labelling
    private Label[] labels = null; // LabelSOM labels

    private Label[] kaskiGateLabels = null;

    private Label[] kaskiLabels = null;

    private Label[] gateWeights = null;

    private Label[] bestcontextWeights = null;

    private Label[] contextGateLabels = null;

    private Layer layer = null;

    private GrowingSOM mappedSOM = null; // TODO: Should be NetworkModel or similar in the future

    private double quantizationError = 0;

    private double[] weightVector = null;

    private double[] featureWeights = null;

    private int xPos = -1;

    private int yPos = -1;

    private int zPos = -1;

    private ArrayList<InputDatum> batchSomNeighbourhood;

    /**
     * Constructs a <code>Unit</code> on <code>Layer</code> specified by argument <code>layer</code> at position
     * <code>x</code>/<code>y</code> (depth position 0) with a given weight vector <code>vec</code>.
     *
     * @param l the layer that contains this <code>Unit</code>.
     * @param x the horizontal position on the <code>layer</code>.
     * @param y the vertical position on the <code>layer</code>.
     * @param vec the weight vector.
     */
    public Unit(Layer l, int x, int y, double[] vec) {
        this(l, x, y, 0, vec);
    }

    /**
     * Constructs a <code>Unit</code> on <code>Layer</code> specified by argument <code>layer</code> at position
     * <code>x</code>/<code>y</code>/<code>z</code> with a given weight vector <code>vec</code>. The vector is stored
     * by reference (not copied) and its length becomes the dimensionality of this unit.
     *
     * @param l the layer that contains this <code>Unit</code>.
     * @param x the horizontal position on the <code>layer</code>.
     * @param y the vertical position on the <code>layer</code>.
     * @param z the depth position on the <code>layer</code>.
     * @param vec the weight vector.
     */
    public Unit(Layer l, int x, int y, int z, double[] vec) {
        layer = l;
        xPos = x;
        yPos = y;
        zPos = z;
        weightVector = vec;
        dim = vec.length;
        // FIXME: don't initialise this here - if we don't use the batchmode, we just waste memory...
        batchSomNeighbourhood = new ArrayList<InputDatum>();
    }

    /**
     * Constructs a <code>Unit</code> on <code>Layer</code> specified by argument <code>layer</code> at position
     * <code>x</code>/<code>y</code> with a randomly initialized weight vector of dimension <code>d</code>. Argument
     * <code>norm</code> determines whether the weight vector should be normalized or not. TODO: This might be change in
     * the future due to unflexibility regarding hard coded normalization methods.
     *
     * @param l the layer that contains this <code>Unit</code>.
     * @param x the horizontal position on the <code>layer</code>.
     * @param y the vertical position on the <code>layer</code>.
     * @param d the dimensionality of the weight vector.
     * @param rand a random number generator provided by the caller.
     * @param norm the type of normalization (see text above).
     */
    public Unit(Layer l, int x, int y, int d, Random rand, boolean norm) {
        this(l, x, y, 0, d, rand, norm, INIT_RANDOM);
    }

    /**
     * Constructs a <code>Unit</code> on <code>Layer</code> specified by argument <code>layer</code> at position
     * <code>x</code>/<code>y</code>/<code>z</code> with a randomly initialized weight vector of dimension
     * <code>d</code>. Argument <code>norm</code> determines whether the weight vector should be normalized or not.
     * TODO: This might be change in the future due to unflexibility regarding hard coded normalization methods.
     *
     * @param l the layer that contains this <code>Unit</code>.
     * @param x the horizontal position on the <code>layer</code>.
     * @param y the vertical position on the <code>layer</code>.
     * @param z the depth position on the <code>layer</code>.
     * @param d the dimensionality of the weight vector.
     * @param rand a random number generator provided by the caller.
     * @param norm the type of normalization (see text above).
     */
    public Unit(Layer l, int x, int y, int z, int d, Random rand, boolean norm) {
        this(l, x, y, z, d, rand, norm, INIT_RANDOM);
    }

    /**
     * Constructs a <code>Unit</code> at position <code>x</code>/<code>y</code> (depth position 0) whose weight vector
     * is initialised according to <code>initialisationMode</code>.
     *
     * @param l the layer that contains this <code>Unit</code>.
     * @param x the horizontal position on the <code>layer</code>.
     * @param y the vertical position on the <code>layer</code>.
     * @param d the dimensionality of the weight vector.
     * @param rand a random number generator provided by the caller.
     * @param norm whether the weight vector is normalised to unit length after initialisation.
     * @param initialisationMode one of the internal initialisation mode codes (10 = random, 20 = interpolate within the
     *            data intervals, 30 = copy a randomly chosen input vector, 40 = PCA, not implemented yet).
     */
    public Unit(Layer l, int x, int y, int d, Random rand, boolean norm, int initialisationMode) {
        this(l, x, y, 0, d, rand, norm, initialisationMode);
    }

    /**
     * Constructs a <code>Unit</code> at position <code>x</code>/<code>y</code>/<code>z</code> whose weight vector is
     * initialised according to <code>initialisationMode</code>. For the interval-interpolation and vector modes the
     * layer <code>l</code> is cast to {@link GrowingLayer} to access its input data. For an unknown mode (and for the
     * not yet implemented PCA mode) the weight vector stays all-zero.
     *
     * @param l the layer that contains this <code>Unit</code>.
     * @param x the horizontal position on the <code>layer</code>.
     * @param y the vertical position on the <code>layer</code>.
     * @param z the depth position on the <code>layer</code>.
     * @param d the dimensionality of the weight vector.
     * @param rand a random number generator provided by the caller.
     * @param norm whether the weight vector is normalised to unit length after initialisation.
     * @param initialisationMode one of the internal initialisation mode codes (10 = random, 20 = interpolate within the
     *            data intervals, 30 = copy a randomly chosen input vector, 40 = PCA, not implemented yet).
     */
    public Unit(Layer l, int x, int y, int z, int d, Random rand, boolean norm, int initialisationMode) {
        layer = l;
        xPos = x;
        yPos = y;
        zPos = z;
        dim = d;
        batchSomNeighbourhood = new ArrayList<InputDatum>();
        weightVector = new double[dim];

        if (initialisationMode == INIT_RANDOM) {
            for (int i = 0; i < dim; i++) {
                weightVector[i] = rand.nextDouble();
            }
        } else if (initialisationMode == INIT_INTERVAL_INTERPOLATE) {
            // The intervals do not change while we iterate, so fetch them only once.
            final double[][] intervals = ((GrowingLayer) l).getData().getDataIntervals();
            for (int i = 0; i < dim; i++) {
                double r = rand.nextDouble();
                weightVector[i] = intervals[i][0] + (intervals[i][1] - intervals[i][0]) * r;
            }
        } else if (initialisationMode == INIT_VECTOR) {
            double r = rand.nextDouble();
            // r is in [0, 1), hence the index is always a valid position in the input data
            int index = (int) (((GrowingLayer) l).getData().numVectors() * r);
            weightVector = ((GrowingLayer) l).getData().getInputDatum(index).getVector().toArray();
        } else if (initialisationMode == INIT_PCA) {
            // TODO: do PCA
        }

        if (norm) {
            VectorTools.normaliseVectorToUnitLength(weightVector);
        }
    }

    /**
     * Adds a single input datum to the unit. The method also calculates the distance between the unit's weight vector
     * and the datum, using the metric of this unit's layer. If the metric fails, the error is logged and the JVM is
     * terminated via <code>System.exit(-1)</code>.
     *
     * @param datum the input datum to be added.
     * @param calcQE determines if the quantization error should be recalculated.
     * @see #addMappedInput(String, double, boolean)
     */
    public void addMappedInput(InputDatum datum, boolean calcQE) {
        try {
            double dist = layer.getMetric().distance(datum.getVector(), this.weightVector);
            addMappedInput(datum.getLabel(), dist, calcQE);
        } catch (MetricException e) {
            Logger.getLogger("at.tuwien.ifs.somtoolbox").severe(e.getMessage());
            System.exit(-1); // TODO: EXCEPTION HANDLING!!
        }
    }

    /**
     * Map all the input vectors contained in specified <code>InputData</code> object onto this unit.
     *
     * @param data The container for input vector
     * @param calcQE determines if the quantization error should be recalculated.
     * @see #addMappedInput(InputDatum, boolean)
     */
    public void addMappedInput(InputData data, boolean calcQE) {
        for (int d = 0; d < data.numVectors(); d++) {
            addMappedInput(data.getInputDatum(d), calcQE);
        }
    }

    /**
     * Convenience method to add an input datum specified by its name and distance. The quantization error is
     * recalculated if argument <code>calcQE</code> is <code>true</code>.
     *
     * @param name the name of the input datum.
     * @param dist the precalculated distance between input datum and weight vector
     * @param calcQE determines if the quantization error should be recalculated.
     */
    public void addMappedInput(String name, double dist, boolean calcQE) {
        // Double.valueOf replaces the deprecated boxing constructor new Double(double)
        super.addMappedInput(name, Double.valueOf(dist));
        if (calcQE) {
            calculateQuantizationError();
        }
    }

    /**
     * Removes the mapped input with the given label and recalculates the quantization error.
     *
     * @param label the name of the input datum to be removed.
     */
    @Override
    public void removeMappedInput(String label) {
        super.removeMappedInput(label);
        calculateQuantizationError();
    }

    /**
     * Recalculates the quantization error for this unit, i.e. the sum of the distances of all mapped inputs.
     */
    public void calculateQuantizationError() {
        quantizationError = 0;
        for (int i = 0; i < getNumberOfMappedInputs(); i++) {
            quantizationError += getMappedInputDistance(i);
        }
    }

    /**
     * Removes the labels of this unit.
     */
    public void clearLabels() {
        labels = null;
    }

    /**
     * Removes the mapped input data and sets this units quantization error to 0.
     */
    public void clearMappedInput() {
        super.clearMappedInputs();
        quantizationError = 0;
    }

    /**
     * Returns the labels of the given type, or <code>null</code> if the type is unknown (or <code>null</code>) or no
     * labels of that type are assigned to this unit. The type is compared by value against {@link #LABELSOM},
     * {@link #KEYWORDS}, {@link #GATE} and {@link #CONTEXT}.
     *
     * @param type the label type.
     * @return an array of labels or null.
     */
    public Label[] getLabels(String type) {
        // Compare by value: callers may pass strings that are not the very same (interned) constant instances.
        if (LABELSOM.equals(type)) {
            return getLabels();
        } else if (KEYWORDS.equals(type)) {
            return getKaskiLabels();
        } else if (GATE.equals(type)) {
            return getKaskiGateLabels();
        } else if (CONTEXT.equals(type)) {
            return getBestContextWeights();
        } else {
            return null;
        }
    }

    /**
     * Returns the LabelSOM labels, or <code>null</code> if no labels are assigned to this unit.
     *
     * @return an array of labels or null.
     */
    public Label[] getLabels() {
        return labels;
    }

    public Label[] getKaskiGateLabels() {
        return kaskiGateLabels;
    }

    public Label[] getKaskiLabels() {
        return kaskiLabels;
    }

    public Label[] getGateWeights() {
        return gateWeights;
    }

    public Label[] getBestContextWeights() {
        return bestcontextWeights;
    }

    public Label[] getContextGateLabels() {
        return contextGateLabels;
    }

    /**
     * Returns the layer of units this unit is part of.
     *
     * @return the layer of units this unit is part of.
     */
    public Layer getLayer() {
        return layer;
    }

    /**
     * Returns the map identification string of this unit's layer.
     *
     * @return the map identification string of this unit's layer.
     */
    public String getMapIdString() {
        return layer.getIdString();
    }

    /**
     * Returns the level of this unit's layer in a hierarchy of maps.
     *
     * @return the level of this unit's layer in a hierarchy of maps.
     */
    public int getMapLevel() {
        return layer.getLevel();
    }

    /**
     * Returns the map assigned to this unit or <code>null</code> otherwise.
     *
     * @return the map assigned to this unit or <code>null</code> otherwise.
     */
    public GrowingSOM getMappedSOM() {
        return mappedSOM;
    }

    /**
     * Assigns a map to this unit.
     *
     * @param mappedSOM a map to be assigned to this unit.
     */
    public void setMappedSOM(GrowingSOM mappedSOM) {
        this.mappedSOM = mappedSOM;
    }

    /**
     * Returns the width of this unit's map.
     *
     * @return the width of this unit's map.
     */
    public int getMapXSize() {
        return layer.getXSize();
    }

    /**
     * Returns the height of this unit's map.
     *
     * @return the height of this unit's map.
     */
    public int getMapYSize() {
        return layer.getYSize();
    }

    /**
     * Returns the depth of this unit's map.
     *
     * @return the depth of this unit's map.
     */
    public int getMapZSize() {
        return layer.getZSize();
    }

    /**
     * Calculates and returns the mean quantization error of this unit. This is 0, if no input is mapped onto this unit.
     *
     * @return the mean quantization error for this unit.
     */
    /*
     * public double getMeanQuantizationError() { if (mappedInputs.getNumberOfMappedInputs()>0) { return
     * (quantizationError/mappedInputs.getNumberOfMappedInputs()); } else { return 0; } }
     */

    /**
     * Returns the quantization error of this unit.
     *
     * @return the quantization error of this unit.
     */
    /*
     * public double getQuantizationError() { return quantizationError; }
     */

    /**
     * Returns the weight vector of this unit. The internal array is returned, not a copy.
     *
     * @return the weight vector of this unit.
     */
    public double[] getWeightVector() {
        return weightVector;
    }

    /**
     * Sets the weight vector of this unit.
     *
     * @param vector the weight vector.
     * @throws SOMToolboxException if <code>vector</code> is <code>null</code> or its length differs from the
     *             dimensionality of this unit.
     */
    public void setWeightVector(double[] vector) throws SOMToolboxException {
        if (vector != null && vector.length == dim) {
            weightVector = vector;
        } else {
            throw new SOMToolboxException("Vector is null or has wrong dimensionality.");
        }
    }

    /**
     * Returns the horizontal position of this unit on the map it is part of.
     *
     * @return the horizontal position of this unit on the map it is part of.
     */
    public int getXPos() {
        return xPos;
    }

    /**
     * Returns the vertical position of this unit on the map it is part of.
     *
     * @return the vertical position of this unit on the map it is part of.
     */
    public int getYPos() {
        return yPos;
    }

    /**
     * Returns the depth position of this unit on the map it is part of.
     *
     * @return the depth position of this unit on the map it is part of.
     */
    public int getZPos() {
        return zPos;
    }

    /**
     * Sets this unit's weight vector to the vector of the input datum specified by argument <code>datum</code>.
     *
     * @param datum the input datum.
     */
    public void initWeightVectorBySample(InputDatum datum) {
        weightVector = datum.getVector().toArray();
    }

    /**
     * Restores the labels of a unit based on the information provided by the arguments. The value of argument
     * <code>nrUnitLabels</code> must be equal to the dimensionalities of the arrays specified in the other arguments.
     * If this is not the case, no labels will be restored.
     *
     * @param nrUnitLabels the number of labels.
     * @param unitLabels an array of strings containing the labels' names.
     * @param unitLabelsQe an array of double values containing the quantization errors for the single labels.
     * @param unitLabelsWgt an array of double values containing the actual values for the single labels.
     */
    public void restoreLabels(int nrUnitLabels, String[] unitLabels, double[] unitLabelsQe, double[] unitLabelsWgt) {
        if (nrUnitLabels > 0 && unitLabels.length == nrUnitLabels && unitLabelsQe.length == nrUnitLabels
                && unitLabelsWgt.length == nrUnitLabels) {
            labels = new Label[nrUnitLabels];
            for (int i = 0; i < nrUnitLabels; i++) {
                labels[i] = new Label(unitLabels[i], unitLabelsWgt[i], unitLabelsQe[i]);
            }
        }
    }

    public void restoreContextGateLabels(int nrContextGate, String[] contextGateUnitLabels) {
        contextGateLabels = new Label[nrContextGate];
        for (int i = 0; i < nrContextGate; i++) {
            contextGateLabels[i] = new Label(contextGateUnitLabels[i]);
        }
    }

    public void restoreKaskiLabels(int nrKaski, String[] kaskiUnitLabels, double[] kaskiUnitLabelsWgt) {
        kaskiLabels = new Label[nrKaski];
        for (int i = 0; i < nrKaski; i++) {
            kaskiLabels[i] = new Label(kaskiUnitLabels[i], kaskiUnitLabelsWgt[i]);
        }
    }

    public void restoreKaskiGateLabels(int nrKaskiGate, String[] kaskiGateUnitabels) {
        kaskiGateLabels = new Label[nrKaskiGate];
        for (int i = 0; i < nrKaskiGate; i++) {
            kaskiGateLabels[i] = new Label(kaskiGateUnitabels[i]);
        }
    }

    public void restoreGateWeightLabels(int nrgateweights, String[] gateWeightUnitLabels) {
        gateWeights = new Label[nrgateweights];
        for (int i = 0; i < nrgateweights; i++) {
            gateWeights[i] = new Label(gateWeightUnitLabels[i]);
        }
    }

    public void restoreBestContextWeightLabels(int nrbestcontext, String[] bestContextWeightUnitLabels) {
        bestcontextWeights = new Label[nrbestcontext];
        for (int i = 0; i < nrbestcontext; i++) {
            bestcontextWeights[i] = new Label(bestContextWeightUnitLabels[i]);
        }
    }

    /**
     * Restores the mapped input data of a unit based on the information provided by the arguments. The first
     * <code>nrVecsMapped</code> entries of both arrays are added as mapped inputs; both arrays must therefore contain
     * at least <code>nrVecsMapped</code> elements. The quantization error is recalculated once after all inputs have
     * been added.
     *
     * @param nrVecsMapped the number of input data.
     * @param mappedVecs an array of strings containing the data identifiers.
     * @param mappedVecsDist an array of double values containing the distances between the weight vector and the
     *            respective input data.
     */
    public void restoreMappings(int nrVecsMapped, String[] mappedVecs, double[] mappedVecsDist) {
        for (int i = 0; i < nrVecsMapped; i++) {
            addMappedInput(mappedVecs[i], mappedVecsDist[i], false);
        }
        calculateQuantizationError();
    }

    /**
     * Assigns labels to this unit.
     *
     * @param labels array of labels to be assigned to this unit.
     */
    public void setLabels(Label[] labels) {
        this.labels = labels;
    }

    public void setKaskiGateLabels(Label[] kaski_gate_labels) {
        this.kaskiGateLabels = kaski_gate_labels;
    }

    public void setContextGateLabels(Label[] context_gate_labels) {
        this.contextGateLabels = context_gate_labels;
    }

    /**
     * Sets the coordinates of this unit on the map, if they have changed. This happens in architectures with growing
     * map sizes during training.
     *
     * @param x the horizontal position on the map.
     * @param y the vertical position on the map.
     * @param z the height position on the map.
     */
    public void updatePosition(int x, int y, int z) {
        xPos = x;
        yPos = y;
        zPos = z;
    }

    /**
     * Sets the coordinates of this unit on the map; the depth position is set to 0.
     *
     * @param x the horizontal position on the map.
     * @param y the vertical position on the map.
     */
    public void updatePosition(int x, int y) {
        updatePosition(x, y, 0);
    }

    /**
     * Adds an input datum to this unit's batch-SOM neighbourhood list.
     *
     * @param d the input datum to add.
     */
    public void addBatchSomNeighbour(InputDatum d) {
        batchSomNeighbourhood.add(d);
    }

    /**
     * Empties this unit's batch-SOM neighbourhood list.
     */
    public void clearBatchSomList() {
        batchSomNeighbourhood.clear();
    }

    /**
     * Sets each component of the weight vector to the mean of the corresponding component over all input data in the
     * batch-SOM neighbourhood list. If the list is empty the weight vector is left untouched, as dividing by a
     * neighbourhood size of zero would turn every component into NaN.
     */
    public void getWeightVectorFromBatchSomNeighbourhood() {
        final int neighbours = batchSomNeighbourhood.size();
        if (neighbours == 0) {
            return;
        }
        for (int i = 0; i < weightVector.length; i++) {
            double sum = 0;
            for (InputDatum datum : batchSomNeighbourhood) {
                sum += datum.getVector().get(i);
            }
            weightVector[i] = sum / neighbours;
        }
    }

    @Override
    public String toString() {
        return "Unit[" + printCoordinates() + "]";
    }

    /**
     * Returns the coordinates of this unit separated by slashes; the depth position is only included if the map has a
     * depth greater than 1.
     *
     * @return the coordinates as <code>x/y</code> or <code>x/y/z</code>.
     */
    public String printCoordinates() {
        if (getMapZSize() > 1) {
            return xPos + "/" + yPos + "/" + zPos;
        } else {
            return xPos + "/" + yPos;
        }
    }

    /**
     * Returns the coordinates of this unit separated by spaces; the depth position is only included if the map has a
     * depth greater than 1.
     *
     * @return the coordinates as <code>x y</code> or <code>x y z</code>.
     */
    public String printCoordinatesSpaceSeparated() {
        if (getMapZSize() > 1) {
            return xPos + " " + yPos + " " + zPos;
        } else {
            return xPos + " " + yPos;
        }
    }

    /**
     * Builds a multi-line, human-readable description of this unit: an optional header row with the (truncated)
     * feature labels, the weight vector, the feature weights (if set), and one row per mapped input. The vector of a
     * mapped input is only printed if it can be looked up in <code>inputData</code>.
     *
     * @param inputData the input data used to look up the vectors of the mapped inputs; may be <code>null</code>.
     * @param tv the template vector providing the feature labels; may be <code>null</code>.
     * @return the textual description of this unit.
     */
    public String printUnitDetails(InputData inputData, TemplateVector tv) {
        // Purely local buffer, so the unsynchronised StringBuilder is sufficient.
        StringBuilder sb = new StringBuilder("Unit details for ").append(getXPos()).append("/").append(getYPos()).append(
                ", ").append(getNumberOfMappedInputs()).append(" mapped inputs:\n");
        final int numSpace = 6;
        final int firstColumnSpace = 10;
        final int numDigits = numSpace - 3;

        if (tv != null) {
            sb.append(StringUtils.getSpaces(firstColumnSpace + 1));
            for (int i = 0; i < tv.dim(); i++) {
                // truncate long labels so that each column is numSpace characters wide
                String label = tv.getLabel(i).length() >= numSpace ? tv.getLabel(i).substring(0, numSpace - 1)
                        : tv.getLabel(i);
                sb.append(label).append(StringUtils.getSpaces(numSpace - label.length()));
            }
            sb.append("\n");
        }

        final String headerWeightVector = "WeightVec";
        final String headerFeatureWeight = "FeatWeight";
        sb.append(headerWeightVector).append(StringUtils.getSpaces(firstColumnSpace - headerWeightVector.length())).append(
                StringUtils.toStringWithPrecision(getWeightVector(), numDigits)).append("\n");
        if (featureWeights != null) {
            sb.append(headerFeatureWeight).append(
                    StringUtils.getSpaces(firstColumnSpace - headerFeatureWeight.length())).append(
                    StringUtils.toStringWithPrecision(featureWeights, numDigits)).append("\n");
        }

        if (getNumberOfMappedInputs() > 0) {
            for (String label : getMappedInputNames()) {
                sb.append(label).append(StringUtils.getSpaces(firstColumnSpace - label.length()));
                if (inputData != null && inputData.getInputDatum(label) != null) {
                    sb.append(
                            StringUtils.toStringWithPrecision(inputData.getInputDatum(label).getVector().toArray(),
                                    numDigits)).append("\n");
                }
            }
        }
        return sb.toString();
    }

    public double[] getFeatureWeights() {
        return featureWeights;
    }

    public void setFeatureWeights(double[] featureWeights) {
        this.featureWeights = featureWeights;
    }

    /**
     * Copies the values of <code>featureWeights</code> into this unit's own feature weight array. If this unit has no
     * feature weights yet, a fresh copy of the given array is stored instead.
     *
     * @param featureWeights the feature weights to copy from.
     */
    public void copyFeatureWeights(double[] featureWeights) {
        if (this.featureWeights == null) {
            // nothing to copy into yet - keep an independent copy of the given values
            this.featureWeights = featureWeights.clone();
            return;
        }
        System.arraycopy(featureWeights, 0, this.featureWeights, 0, featureWeights.length);
    }

    public int getDim() {
        return dim;
    }

    /**
     * Returns the names of the LabelSOM labels of this unit as a comma-separated string, or a placeholder text if no
     * labels are available.
     *
     * @return the comma-separated label names, or <code>&lt;no labels available&gt;</code>.
     */
    public String getUnitLabels() {
        StringBuilder label = new StringBuilder();
        if (labels != null) {
            for (int i = 0; i < labels.length; i++) {
                if (i > 0) {
                    label.append(", ");
                }
                label.append(labels[i].getName());
            }
        }
        if (labels == null || labels.length == 0) {
            label.append("<no labels available>");
        }
        return label.toString();
    }

    /**
     * Sets all three coordinates of this unit.
     *
     * @param x the horizontal position on the map.
     * @param y the vertical position on the map.
     * @param z the depth position on the map.
     */
    void setPositions(int x, int y, int z) {
        this.xPos = x;
        this.yPos = y;
        this.zPos = z;
    }

    /**
     * Checks whether this unit is located at the origin of the map.
     *
     * @return <code>true</code> if all three coordinates are 0.
     */
    public boolean isTopLeftUnit() {
        return xPos == 0 && yPos == 0 && zPos == 0;
    }

}