/* * Copyright 2004-2010 Information & Software Engineering Group (188/1) * Institute of Software Technology and Interactive Systems * Vienna University of Technology, Austria * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.ifs.tuwien.ac.at/dm/somtoolbox/license.html * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package at.tuwien.ifs.somtoolbox.data; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Hashtable; import java.util.LinkedHashMap; import java.util.Random; import java.util.Set; import java.util.logging.Logger; import cern.colt.Sorting; import cern.colt.function.IntComparator; import cern.colt.matrix.DoubleMatrix1D; import cern.colt.matrix.impl.DenseDoubleMatrix1D; import cern.jet.math.Functions; import at.tuwien.ifs.somtoolbox.SOMToolboxException; import at.tuwien.ifs.somtoolbox.data.distance.LeightWeightMemoryInputVectorDistanceMatrix; import at.tuwien.ifs.somtoolbox.layers.metrics.DistanceMetric; import at.tuwien.ifs.somtoolbox.layers.metrics.MetricException; import at.tuwien.ifs.somtoolbox.util.StdErrProgressWriter; import at.tuwien.ifs.somtoolbox.util.comparables.InputDistance; /** * This abstract implementation provides basic support for operating on a {@link InputData}. Sub-classes have to * implement constructors and methods to read input vectors and create an <code>InputData</code> object, for example by * reading from a file or a database. * * @author Rudolf Mayer * @version $Id: AbstractSOMLibSparseInputData.java 3883 2010-11-02 17:13:23Z frank $ */ public abstract class AbstractSOMLibSparseInputData implements InputData { protected static final String ERROR_MESSAGE_FILE_FORMAT_CORRUPT = "Input vector file format corrupt. Aborting."; /** Where this input data was read from, e.g. a file or database table */ protected String source; /** * Any class label information attached to the input vectors. */ protected SOMLibClassInformation classInfo = null; /** * The label/name of the vector. */ public String[] dataNames = null; /** * The content type of the vectors ("text", "audio", ...). <br> * <p> * An input file should use the following header format for content types: <br> * <code>$DATA_TYPE text</code> <br> * or <br> * <code>$DATA_TYPE audio-rp</code> * </p> */ protected String content_type = ""; /** * The specific subtype of content type (user-definable, for example "rp", "rh", or "ssd" for Rhythm Patterns, * Rhythm Histograms or Statistical Spectrum Descriptor audio feature types). */ protected String content_subtype = ""; /** * Row dimension of the feature matrix before having been vectorized to input vector. */ protected int featureMatrixRows = -1; /** * Column dimension of the feature matrix before having been vectorized to input vector. */ protected int featureMatrixCols = -1; /** * The dimension of the input vectors, i.e. the number of attributes */ protected int dim = 0; /** * Indicates whether the input data has been normalised. */ protected boolean isNormalized = true; /** * The mean of all the input vectors. */ protected DenseDoubleMatrix1D meanVector = null; protected double mqe0 = -1; // value of -1 means the mqe was not yet calculated /** * The number of vectors in this input data collection. */ protected int numVectors = 0; protected Random rand = null; /** * A {@link TemplateVector} attached to this input data. */ protected TemplateVector templateVector = null; /** * A transformation of the input vectors. This can be used to perform for example a transformation of the input data * for distance calculations once for all vectors to improve performance. */ private double[][] transformedVectors; /** * A matrix containing the pairwise distances between two vectors.<br/> * FIXME: use {@link LeightWeightMemoryInputVectorDistanceMatrix} instead */ private double[][] distanceMatrix; /** * A mapping from the name to the index of an input vector, for faster access. */ protected LinkedHashMap<String, Integer> nameCache = null; private double[][] intervals; protected AbstractSOMLibSparseInputData(String[] dataNames, int dim, boolean norm, Random rand, TemplateVector tv, SOMLibClassInformation clsInfo) { this(norm, rand); this.dataNames = dataNames; this.dim = dim; this.numVectors = dataNames.length; meanVector = new DenseDoubleMatrix1D(dim); this.templateVector = tv; this.classInfo = clsInfo; } protected AbstractSOMLibSparseInputData(boolean norm, Random random) { this.isNormalized = norm; this.rand = random; } protected AbstractSOMLibSparseInputData() { } @Override public int dim() { return dim; } @Override public String getContentType() { return content_type; } @Override public String getContentSubType() { return content_subtype; } @Override public int getFeatureMatrixRows() { return featureMatrixRows; } @Override public int getFeatureMatrixColumns() { return featureMatrixCols; } @Override public DoubleMatrix1D getMeanVector() { return meanVector; } @Override public DoubleMatrix1D getMeanVector(String[] labels) { if (labels.length == 0) { return null; } InputDatum[] vectors = getInputDatum(labels); meanVector = new DenseDoubleMatrix1D(dim); for (int i = 0; i < labels.length; i++) { meanVector.assign(vectors[i].getVector(), Functions.plus); // add to mean vector } meanVector.assign(Functions.div(labels.length)); // calculating mean vector return meanVector; } @Override public boolean isNormalizedToUnitLength() { return isNormalized; } @Override public int numVectors() { return numVectors; } @Override public TemplateVector templateVector() { return templateVector; } @Override public SOMLibClassInformation classInformation() { return classInfo; } @Override public void setTemplateVector(TemplateVector templateVector) { this.templateVector = templateVector; } @Override public InputDatum getInputDatum(String label) { if (nameCache.get(label) != null) { return getInputDatum(nameCache.get(label).intValue()); } else { return null; } } public int getInputDatumIndex(String label) { if (nameCache.get(label) != null) { return nameCache.get(label).intValue(); } else { return -1; } } @Override public InputDatum getRandomInputDatum(int iteration, int numIterations) { // Get a random number int randIndex = rand.nextInt(numVectors); return this.getInputDatum(randIndex); } @Override public InputDatum[] getInputDatum(String[] labels) { if (labels == null) { return null; } else { InputDatum[] res = new InputDatum[labels.length]; int[] indices = new int[labels.length]; for (int i = 0; i < labels.length; i++) { indices[i] = nameCache.get(labels[i]).intValue(); } IntComparator comp = new IntComparator() { /** * @see cern.colt.function.IntComparator#compare(int, int) */ @Override public int compare(int o1, int o2) { return o1 < o2 ? -1 : o1 == o2 ? 0 : 1; } }; Sorting.quickSort(indices, 0, indices.length - 1, comp); for (int i = 0; i < labels.length; i++) { res[i] = this.getInputDatum(indices[i]); } return res; } } /** * Calculates the matrix of {@link #transformedVectors} using {@link DistanceMetric#transformVector(double[])} of * the given metric. * * @param metric the metric to be used to transform the values. */ public void transformValues(DistanceMetric metric) { transformedVectors = new double[numVectors()][dim()]; for (int i = 0; i < numVectors(); i++) { transformedVectors[i] = metric.transformVector(getInputDatum(i).getVector().toArray()); } } /** * Calculates the {@link #distanceMatrix} - careful, this is a lengthy process and should be done only if needed. * Requires the matrix of {@link #transformedVectors} being initialised (e.g. via * {@link #transformValues(DistanceMetric)}). * * @param metric the metric to use for calculating the distances. * @throws MetricException if {@link DistanceMetric#distance(double[], double[])} encounters a problem. */ public void initDistanceMatrix(DistanceMetric metric) throws MetricException { distanceMatrix = new double[numVectors()][numVectors()]; if (transformedVectors == null) { Logger.getLogger("at.tuwien.ifs.somtoolbox").info("Empty transformed matrix, taking vector values"); transformedVectors = new double[numVectors()][dim()]; for (int i = 0; i < numVectors(); i++) { transformedVectors[i] = getInputDatum(i).getVector().toArray(); } } StdErrProgressWriter progress = new StdErrProgressWriter(numVectors(), "pre-calculating distances: ", numVectors() / 10); for (int i = 0; i < numVectors(); i++) { for (int j = i + 1; j < numVectors(); j++) { distanceMatrix[i][j] = metric.distance(transformedVectors[i], transformedVectors[j]); distanceMatrix[j][i] = distanceMatrix[i][j]; } progress.progress(i); } } /** * Returns the n nearest input vectors for the index of the given vector of the dataset. Uses a pre-calculated * distance metric, if existing, otherwise calculates the distances as needed. * * @param inputIndex the index of the vector. * @param metric the metric to use for the distance comparison. Only used when the {@link #distanceMatrix} is not * pre-calculated. * @param number the number of nearest input vectors desired. * @return the n nearest input vectors. * @throws MetricException if {@link DistanceMetric#distance(DoubleMatrix1D, double[])} encounters a problem. */ public InputDatum[] getNearestN(int inputIndex, DistanceMetric metric, int number) throws MetricException { return getNNearest(number, getDistances(inputIndex, metric)); } /** * Returns the distances to the index of the given vector of the dataset. Uses a pre-calculated distance metric, if * existing, otherwise calculates the distances as needed. * * @param inputIndex the index of the vector. * @param metric the metric to use for the distance comparison. Only used when the {@link #distanceMatrix} is not * pre-calculated. * @return the n nearest input vectors. * @throws MetricException if {@link DistanceMetric#distance(DoubleMatrix1D, double[])} encounters a problem. */ public ArrayList<InputDistance> getDistances(int inputIndex, DistanceMetric metric) throws MetricException { InputDatum input = getInputDatum(inputIndex); ArrayList<InputDistance> distances = new ArrayList<InputDistance>(numVectors() - 1); if (distanceMatrix != null) { for (int i = 0; i < distanceMatrix[inputIndex].length; i++) { if (inputIndex != i) { distances.add(new InputDistance(distanceMatrix[inputIndex][i], getInputDatum(i))); } } } else { for (int i = 0; i < numVectors(); i++) { if (!getInputDatum(i).equals(input)) { if (transformedVectors != null) { distances.add(new InputDistance(metric.distance(input.getVector(), transformedVectors[i]), getInputDatum(i))); } else { distances.add(new InputDistance( metric.distance(input.getVector(), getInputDatum(i).getVector()), getInputDatum(i))); } } } } return distances; } private InputDatum[] getNNearest(ArrayList<InputDistance> distances) { return getNNearest(distances.size(), distances); } private InputDatum[] getNNearest(int number, ArrayList<InputDistance> distances) { Collections.sort(distances); InputDatum[] result = new InputDatum[number]; for (int i = 0; i < number; i++) { result[i] = distances.get(i).getInput(); } return result; } // FIXME: what's the difference to #getNearestN ? public InputDatum[] getNearestNUnsorted(int inputIndex, DistanceMetric metric, int number) throws MetricException { InputDatum input = getInputDatum(inputIndex); double longestDistance = Double.MAX_VALUE; if (distanceMatrix != null) { ArrayList<InputDatum> distances = new ArrayList<InputDatum>(); for (int i = 0; i < 6 && i < distanceMatrix[inputIndex].length; i++) { distances.add(getInputDatum(i)); if (distanceMatrix[inputIndex][i] < longestDistance) { longestDistance = distanceMatrix[inputIndex][i]; } } for (int i = 6; i < distanceMatrix[inputIndex].length; i++) { if (inputIndex != i) { if (distanceMatrix[inputIndex][i] < longestDistance) { distances.add(getInputDatum(i)); } } } return distances.toArray(new InputDatum[distances.size()]); } else { ArrayList<InputDistance> distances = new ArrayList<InputDistance>(); for (int i = 0; i < numVectors(); i++) { if (!getInputDatum(i).equals(input)) { distances.add(new InputDistance(metric.distance(input.getVector(), transformedVectors[i]), getInputDatum(i))); } } return getNNearest(number, distances); } } /** Retrieves the given number of {@link InputDatum} that are closest to the given vector. */ public InputDatum[] getNearestN(double[] vector, DistanceMetric metric, int number) throws MetricException { ArrayList<InputDistance> distances = new ArrayList<InputDistance>(); for (int i = 0; i < numVectors(); i++) { distances.add(new InputDistance(metric.distance(vector, getInputDatum(i).getVector().toArray()), getInputDatum(i))); } return getNNearest(number, distances); } /** * Retrieves the {@link InputDatum} corresponding to the given input names, and sorted by their distance to the * given vector. */ public InputDatum[] getByNameDistanceSorted(double[] vector, Collection<String> inputNames, DistanceMetric metric) throws MetricException { ArrayList<InputDistance> distances = new ArrayList<InputDistance>(); for (String string : inputNames) { distances.add(new InputDistance(metric.distance(vector, getInputDatum(string).getVector().toArray()), getInputDatum(string))); } return getNNearest(distances); } @Override public double[][] getData() { double[][] result = new double[numVectors][dim]; for (int i = 0; i < numVectors; i++) { DoubleMatrix1D v = getInputDatum(i).getVector(); for (int j = 0; j < v.size(); j++) { result[i][j] = v.get(j); } } return result; } @Override public double[][] getData(String className) throws SOMToolboxException { if (classInfo != null) { String[] dataNames = classInfo.getDataNamesInClass(className); double[][] result = new double[dataNames.length][dim]; for (int i = 0; i < dataNames.length; i++) { DoubleMatrix1D v = getInputDatum(dataNames[i]).getVector(); for (int j = 0; j < v.size(); j++) { result[i][j] = v.get(j); } } return result; } else { throw new SOMToolboxException("No class information file loaded!"); } } @Override public void setClassInfo(SOMLibClassInformation classInfo) { this.classInfo = classInfo; } public double[][] getDistanceMatrix() { return distanceMatrix; } @Override public double[][] getDataIntervals() { if (intervals == null) { intervals = new double[dim()][2]; for (int i = 0; i < intervals.length; i++) { double min = Integer.MAX_VALUE; double max = Integer.MIN_VALUE; for (int j = 0; j < numVectors(); j++) { double value = getValue(j, i); if (value > max) { max = value; } if (value < min) { min = value; } } intervals[i][0] = min; intervals[i][1] = max; } // DEBUG info // System.out.println("\n\nnmin/max matrix: "); // for (int i = 0; i < intervals.length; i++) { // System.out.println(VectorTools.printVector(intervals[i])); // } } return intervals; } /** * Returns feature densities statistics of the input data, namely a mapping from the number of input objects a * specific feature is not zero in, to the total number of features with that density . */ public Hashtable<Integer, Integer> getFeatureDensities() { Hashtable<Integer, Integer> densities = new Hashtable<Integer, Integer>(); for (int i = 0; i < numVectors; i++) { int featureDensitiy = getInputDatum(i).getFeatureDensity(); Integer count = densities.get(new Integer(featureDensitiy)); if (count == null) { count = new Integer(1); } else { count = new Integer(count.intValue() + 1); } densities.put(new Integer(featureDensitiy), count); } return densities; } @Override public String[] getLabels() { Set<String> names = nameCache.keySet(); return names.toArray(new String[names.size()]); } @Override public String getLabel(int index) { return dataNames[index]; } @Override public boolean equals(Object obj) { if (obj instanceof AbstractSOMLibSparseInputData) { AbstractSOMLibSparseInputData data = (AbstractSOMLibSparseInputData) obj; Object[][] assertions = { { "dim", data.dim(), dim() }, { "meanVector", data.meanVector, meanVector }, { "data names", data.dataNames, dataNames }, { "meanVec", data.meanVector, meanVector } }; for (int i = 0; i < assertions.length; i++) { if (!assertEqual(assertions[i][0], assertions[i][1], assertions[i][2])) { return false; } } for (int i = 0; i < numVectors(); i++) { if (!assertEqual("input element " + i, data.getInputDatum(i), getInputDatum(i))) { return false; } } return true; } else { return false; } } private boolean assertEqual(Object name, Object i1, Object i2) { boolean equals = false; if (i1 instanceof Object[] && i2 instanceof Object[]) { equals = Arrays.equals((Object[]) i1, (Object[]) i2); } else { equals = i1.equals(i2); } if (!equals) { System.out.println(name + " not equal: " + i1 + "<->" + i2); return false; } else { return true; } } public static AbstractSOMLibSparseInputData create(InputDatum[] inputData, SOMLibClassInformation classInfo) { return new SOMLibSparseInputData(inputData, classInfo); } public static String getFormatName() { return "SOMLib"; } public static String getFileNameSuffix() { return ".vec"; } @Override public String getDataSource() { return source; } }