/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.minhash;

import java.text.NumberFormat;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.HashSet;

import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
 * Evaluates the quality of MinHash clusters built over Last.fm listener data by
 * sampling clusters and measuring pair-wise Jaccard similarity between the
 * listener vectors that landed in the same cluster.
 */
public final class LastfmClusterEvaluator {

  private LastfmClusterEvaluator() {
  }

  /** Report used JVM memory, formatted for the progress line. */
  private static String usedMemory() {
    Runtime runtime = Runtime.getRuntime();
    return "Used Memory: [" + (runtime.totalMemory() - runtime.freeMemory()) / (1024 * 1024) + " MB] ";
  }

  /**
   * Compute the Jaccard coefficient over two sets:
   * (A intersect B) / (A union B).
   *
   * @param listenerVector1 first item's listener ids
   * @param listenerVector2 second item's listener ids
   * @return similarity in [0.0, 1.0]; 0.0 when both inputs are empty
   */
  private static double computeSimilarity(Iterable<Integer> listenerVector1, Iterable<Integer> listenerVector2) {
    Set<Integer> first = new HashSet<Integer>();
    for (Integer ele : listenerVector1) {
      first.add(ele);
    }
    Collection<Integer> second = new HashSet<Integer>();
    for (Integer ele : listenerVector2) {
      second.add(ele);
    }
    Collection<Integer> intersection = new HashSet<Integer>(first);
    intersection.retainAll(second);
    double intersectSize = intersection.size();
    // Reuse 'first' as the union to avoid another allocation.
    first.addAll(second);
    double unionSize = first.size();
    return unionSize == 0 ? 0.0 : intersectSize / unionSize;
  }

  /**
   * Count the pairs of listener vectors in one cluster whose Jaccard
   * similarity is at or above the threshold.
   *
   * @param vectors   all listener vectors belonging to a single cluster
   * @param threshold minimum similarity for a pair to count as a true positive
   * @return number of sufficiently-similar pairs
   */
  private static long countSimilarPairs(List<List<Integer>> vectors, double threshold) {
    long count = 0;
    int numListeners = vectors.size();
    for (int i = 0; i < numListeners; i++) {
      List<Integer> listenerVector1 = vectors.get(i);
      for (int j = i + 1; j < numListeners; j++) {
        List<Integer> listenerVector2 = vectors.get(j);
        if (computeSimilarity(listenerVector1, listenerVector2) >= threshold) {
          count++;
        }
      }
    }
    return count;
  }

  /**
   * Calculate the overall cluster precision by sampling clusters. Precision is
   * calculated as follows :-
   *
   * 1. For a sample of all the clusters calculate the pair-wise similarity
   * (Jaccard coefficient) for items in the same cluster.
   *
   * 2. Count true positives as items whose similarity is above specified
   * threshold.
   *
   * 3. Precision = (true positives) / (total items in clusters sampled).
   *
   * @param clusterFile
   *          The file containing cluster information
   * @param threshold
   *          Similarity threshold for containing two items in a cluster to be
   *          relevant. Must be between 0.0 and 1.0
   * @param samplePercentage
   *          Percentage of clusters to sample. Must be between 0.0 and 1.0
   */
  private static void testPrecision(Path clusterFile, double threshold, double samplePercentage) {
    Configuration conf = new Configuration();
    Random rand = RandomUtils.getRandom();
    Text prevCluster = new Text();
    List<List<Integer>> listenerVectors = Lists.newArrayList();
    long similarListeners = 0;
    long allListeners = 0;
    int clustersProcessed = 0;
    for (Pair<Text,VectorWritable> record
         : new SequenceFileIterable<Text,VectorWritable>(clusterFile, true, conf)) {
      Text cluster = record.getFirst();
      VectorWritable point = record.getSecond();
      if (!cluster.equals(prevCluster)) {
        // Cluster boundary: decide whether to evaluate the cluster we just
        // finished accumulating. Note the current record belongs to the NEW
        // cluster and must still be added below (the original code dropped it
        // via 'continue' when the previous cluster was sampled out).
        prevCluster.set(cluster.toString());
        if (rand.nextDouble() <= samplePercentage) {
          allListeners += listenerVectors.size();
          similarListeners += countSimilarPairs(listenerVectors, threshold);
          clustersProcessed++;
          System.out.print('\r' + usedMemory() + " Clusters processed: " + clustersProcessed);
        }
        listenerVectors.clear();
      }
      List<Integer> listeners = Lists.newArrayList();
      for (Vector.Element ele : point.get()) {
        listeners.add((int) ele.get());
      }
      listenerVectors.add(listeners);
    }
    // The loop only evaluates a cluster when the NEXT one begins, so the last
    // cluster in the file still sits in 'listenerVectors'; sample it too.
    if (!listenerVectors.isEmpty() && rand.nextDouble() <= samplePercentage) {
      allListeners += listenerVectors.size();
      similarListeners += countSimilarPairs(listenerVectors, threshold);
      clustersProcessed++;
      System.out.print('\r' + usedMemory() + " Clusters processed: " + clustersProcessed);
    }
    System.out.println("\nTest Results");
    System.out.println("=============");
    System.out.println(" (A) Listeners in same cluster with similarity above threshold ("
                       + threshold + ") : " + similarListeners);
    System.out.println(" (B) All listeners: " + allListeners);
    NumberFormat format = NumberFormat.getInstance();
    format.setMaximumFractionDigits(2);
    // Guard against NaN when sampling selected no clusters.
    double precision = allListeners == 0 ? 0.0 : (double) similarListeners / allListeners * 100.0;
    System.out.println(" Average cluster precision: A/B = " + format.format(precision));
  }

  /**
   * CLI entry point.
   *
   * Usage: LastfmClusterEvaluation &lt;cluster-file&gt; &lt;threshold&gt; &lt;sample-percentage&gt;
   */
  public static void main(String[] args) {
    if (args.length < 3) {
      System.out.println("LastfmClusterEvaluation <cluster-file> <threshold> <sample-percentage>");
      System.out.println("   <cluster-file>: Absolute Path of file containing cluster information in DEBUG format");
      System.out.println("   <threshold>: Minimum threshold for jaccard co-efficient for considering two items");
      System.out.println("                in a cluster to be really similar. Should be between 0.0 and 1.0");
      System.out.println("   <sample-percentage>: Percentage of clusters to sample. Should be between 0.0 and 1.0");
      return;
    }
    Path clusterFile = new Path(args[0]);
    double threshold = Double.parseDouble(args[1]);
    double samplePercentage = Double.parseDouble(args[2]);
    testPrecision(clusterFile, threshold, samplePercentage);
  }
}