package edu.hawaii.jmotif.text; import java.text.DecimalFormat; import java.util.Arrays; import java.util.HashMap; import java.util.Locale; /** * Implements cosine distance matrix. * * @author psenin * */ public class CosineDistanceMatrix { /** Distance matrix. */ private double[][] distances; /** Row names. */ private String[] rows; private HashMap<String, Integer> keysToIndex = new HashMap<String, Integer>(); private static final String COMMA = ","; private static final String CR = "\n"; private static final DecimalFormat df = new DecimalFormat("#0.00000"); /** * Builds a distance matrix. * * @param tfidf The data to use. */ public CosineDistanceMatrix(HashMap<String, HashMap<String, Double>> tfidf) { Locale.setDefault(Locale.US); rows = tfidf.keySet().toArray(new String[0]); Arrays.sort(rows); distances = new double[rows.length][rows.length]; for (int i = 0; i < rows.length; i++) { keysToIndex.put(rows[i], i); for (int j = 0; j < i; j++) { HashMap<String, Double> vectorA = tfidf.get(rows[i]); HashMap<String, Double> vectorB = tfidf.get(rows[j]); double distance = TextUtils.cosineDistance(vectorA, vectorB); distances[i][j] = distance; } } } /** * Get all the row names - i.e. keys. * * @return */ public String[] getRows() { return this.rows; } /** * Get the distances as matrix. * * @return */ public double[][] getDistances() { return this.distances; } /** * Prints matrix. */ @Override public String toString() { StringBuffer sb = new StringBuffer(); sb.append("\"\","); for (String s : rows) { sb.append("\"").append(s).append("\"").append(COMMA); } sb.delete(sb.length() - 1, sb.length()).append(CR); for (int i = 0; i < rows.length; i++) { sb.append("\"").append(rows[i]).append("\","); for (int j = 0; j < rows.length; j++) { sb.append(df.format(distances[i][j])).append(COMMA); } sb.delete(sb.length() - 1, sb.length()).append(CR); } return sb.toString(); } /** * get the distance value between two keys. * * @param keyA first key. * @param keyB second key. * @return the distance between vectors. */ public double distanceBetween(String keyA, String keyB) { if (keysToIndex.get(keyA) >= keysToIndex.get(keyB)) { return distances[keysToIndex.get(keyA)][keysToIndex.get(keyB)]; } return distances[keysToIndex.get(keyB)][keysToIndex.get(keyA)]; } /** * This will subtract all distance values from 1 - so distance becomes inversed - good for * clustering. */ public void transformForHC() { for (int i = 0; i < distances.length; i++) { for (int j = 0; j < distances[0].length; j++) { distances[i][j] = 1.0D - distances[i][j]; } } } }