/*
* Copyright 2008-2013, ETH Zürich, Samuel Welten, Michael Kuhn, Tobias Langner,
* Sandro Affentranger, Lukas Bossard, Michael Grob, Rahul Jain,
* Dominic Langenegger, Sonia Mayor Alonso, Roger Odermatt, Tobias Schlueter,
* Yannick Stucki, Sebastian Wendland, Samuel Zehnder, Samuel Zihlmann,
* Samuel Zweifel
*
* This file is part of Jukefox.
*
* Jukefox is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or any later version. Jukefox is
* distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
* without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
* PARTICULAR PURPOSE. See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License along with
* Jukefox. If not, see <http://www.gnu.org/licenses/>.
*/
package ch.ethz.dcg.jukefox.commons.utils;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
* Class to represent the music taste of a subject based on a collection of
* songs or artists. It works by taking a set of weighted music similarity
* coordinates that fit the music taste of the subject. It then trains a model
* based on kmeans that represents these coordinates. Once the music taste is
* defined one can query a point in the music similarity space and get a rating
* for it.
*
* @author swelten
*
*/
public class SimpleMusicTaste {

    public static final String TAG = SimpleMusicTaste.class.getSimpleName();
    public static final int DEFAULT_NUM_ITERATIONS = 100;

    /** Number of randomly initialized kmeans runs tried per candidate number of centers. */
    private static final int NUM_RANDOM_RESTARTS = 10;
    /** Relative reduction of the maximal distance that is required to accept one more center. */
    private static final double MIN_RELATIVE_IMPROVEMENT = 0.05;
    /** A point farther from its center than this multiple of the cluster's mean distance is an outlier. */
    private static final double OUTLIER_FACTOR = 2;

    private final Random random = new Random();

    // The model that getRating() reads
    private float[][] classCenters;
    private float[] classCentersDist; // per center: largest distance of an assigned point
    private float[] classCentersNum; // per center: number of assigned points
    private int numCenters;
    private float maxDist = 0; // largest value in classCentersDist

    // Scratch buffers used while training
    private int[] assignedClassCenter;
    private float[] weightSums;

    /**
     * Construct and initialize a music taste. May take some time, depending on
     * the number of coordinates and maxNumCenters
     * 
     * @param weightedPreferences
     *            A List of points in the music similarity space the subject
     *            likes. The higher the weight of a coordinate the more
     *            important it is. If the list is empty no model is built and
     *            {@link #getRating(float[])} returns {@link Float#MAX_VALUE}.
     * @param maxNumCenters
     *            The maximal number of classes kmeans uses to model the taste.
     *            It is capped at the number of preferences; values smaller
     *            than 1 result in no model.
     */
    public SimpleMusicTaste(List<Pair<float[], Integer>> weightedPreferences, int maxNumCenters) {
        if (weightedPreferences.isEmpty() || maxNumCenters < 1) {
            return;
        }

        // The start centers are drawn from the points without replacement, so
        // there can never be more centers than points.
        int centerLimit = Math.min(maxNumCenters, weightedPreferences.size());

        assignedClassCenter = new int[weightedPreferences.size()];
        weightSums = new float[centerLimit];
        classCentersDist = new float[centerLimit];
        classCentersNum = new float[centerLimit];

        // Find optimal number of class centers. Therefore increase the number until we don't make enough progress
        Snapshot accepted = null;
        double prevMaxDist = Float.MAX_VALUE;
        for (int i = 1; i <= centerLimit; i++) {
            double bestMaxDist = computeMusicTaste(weightedPreferences, i);
            Log.v(TAG, "Tried " + i + " classcenters: MaxDist: " + bestMaxDist + " prevMaxDist " + prevMaxDist);

            // If we don't improve by 5% => cancel and go back to the model
            // that was accepted for the previous number of centers
            boolean improved = bestMaxDist < prevMaxDist - (prevMaxDist * MIN_RELATIVE_IMPROVEMENT);
            if (accepted != null && !improved) {
                restore(accepted);
                break;
            }
            accepted = takeSnapshot();
            prevMaxDist = bestMaxDist;
        }
    }

    /**
     * Returns the rating of a certain point in the music similarity space based
     * on the music taste. A rating of smaller than 1 means that it fits the
     * taste. Larger than 1 means that it is outside of the current music taste.
     * The larger the value is, the more distant from the current taste it is.
     * 
     * @param position
     *            The position in the music similarity space that should be
     *            rated
     * @return a rating between 0 and infinity; {@link Float#MAX_VALUE} if no
     *         class center can rate the position
     */
    public float getRating(float[] position) {
        float nearestDist = Float.MAX_VALUE;
        for (int centerNr = 0; centerNr < numCenters; centerNr++) {
            if (classCentersNum[centerNr] == 0) {
                // No preference belongs to this center, so it models nothing
                continue;
            }
            float dist = distance(position, classCenters[centerNr]);
            float radius = classCentersDist[centerNr];
            float rating;
            if (radius > 0) {
                rating = dist / radius;
            } else {
                // All points of this class lie exactly on its center. Avoid
                // 0 / 0 = NaN, which would hide a perfect match.
                rating = dist == 0 ? 0 : Float.POSITIVE_INFINITY;
            }
            if (rating < nearestDist) {
                nearestDist = rating;
            }
        }
        return nearestDist;
    }

    /**
     * Trains the model with the given number of centers. Several randomly
     * initialized kmeans runs are performed and the run with the smallest
     * maximal point-to-center distance is kept as the current model.
     * 
     * @return the maximal point-to-center distance of the kept model
     */
    private double computeMusicTaste(List<Pair<float[], Integer>> weightedPreferences, int numCenters) {
        this.numCenters = numCenters;
        int dimensionality = weightedPreferences.get(0).first.length;

        Snapshot best = null;
        for (int i = 0; i < NUM_RANDOM_RESTARTS; i++) {
            performKmeans(weightedPreferences, numCenters, dimensionality, DEFAULT_NUM_ITERATIONS);
            if (best == null || maxDist < best.maxDist) {
                best = takeSnapshot();
            }
        }
        if (best != null) {
            restore(best);
        }
        return maxDist;
    }

    /**
     * Logs the current class centers and, for each center whose radius equals
     * maxDist, the distance of every preference to its assigned center.
     * Debugging aid only.
     */
    private void printDebugOutput(int numCenters, List<Pair<float[], Integer>> weightedPreferences) {
        Log.v(TAG, "Centers: " + numCenters + ", MaxDist: " + maxDist);
        for (int centerNr = 0; centerNr < numCenters; centerNr++) {
            Log.v(TAG, "C " + centerNr + ": " + classCentersDist[centerNr] + ", " + classCentersNum[centerNr]);
            if (classCentersDist[centerNr] == maxDist) {
                for (int prefPos = 0; prefPos < weightedPreferences.size(); prefPos++) {
                    Pair<float[], Integer> entry = weightedPreferences.get(prefPos);
                    int centerNum = assignedClassCenter[prefPos];
                    float dist = distance(entry.first, classCenters[centerNum]);
                    Log.v(TAG, "D: " + dist);
                }
            }
        }
    }

    /**
     * Performs one complete training run: kmeans on all preferences, outlier
     * removal, and a second kmeans on the remaining preferences. If too few
     * preferences survive the outlier removal the result of the first pass is
     * kept.
     */
    private void performKmeans(List<Pair<float[], Integer>> weightedPreferences, int numCenters, int dimensionality,
            int numIterations) {
        runKmeans(weightedPreferences, numCenters, dimensionality, numIterations);

        // Remove Outliers
        List<Pair<float[], Integer>> weightedPreferencesCleaned = removeOutliers(weightedPreferences, numCenters);
        if (weightedPreferencesCleaned.size() < numCenters) {
            // Not enough points left to draw distinct start centers from
            return;
        }
        runKmeans(weightedPreferencesCleaned, numCenters, dimensionality, numIterations);
    }

    /**
     * Runs kmeans from random start centers for at most numIterations
     * iterations and leaves assignments, radii and maxDist consistent with the
     * final class centers.
     */
    private void runKmeans(List<Pair<float[], Integer>> points, int numCenters, int dimensionality,
            int numIterations) {
        classCenters = getRandomClassCenters(numCenters, dimensionality, points);
        for (int it = 0; it < numIterations; it++) {
            boolean changed = doKmeansIteration(points, numCenters);
            // The assignment buffer still holds values of an earlier run in
            // the first iteration, so only later iterations can signal
            // convergence. Once no assignment changes, the centers don't move
            // anymore and further iterations would be no-ops.
            if (it > 0 && !changed) {
                break;
            }
        }
        // The statistics of the last iteration were taken before its center
        // update; recompute them for the final centers.
        assignToNearestCenters(points, numCenters);
    }

    /**
     * Returns the preferences that are at most OUTLIER_FACTOR times the mean
     * distance of their class away from their assigned class center. The
     * comparison is inclusive so that classes whose points all coincide with
     * the center (mean distance 0) keep their points.
     */
    private List<Pair<float[], Integer>> removeOutliers(List<Pair<float[], Integer>> weightedPreferences,
            int numCenters) {
        List<Pair<float[], Integer>> weightedPreferencesCleaned = new ArrayList<Pair<float[], Integer>>(
                weightedPreferences.size());
        double[] means = computeMeanDists(weightedPreferences, numCenters);

        for (int prefPos = 0; prefPos < weightedPreferences.size(); prefPos++) {
            Pair<float[], Integer> entry = weightedPreferences.get(prefPos);
            int centerNum = assignedClassCenter[prefPos];
            float dist = distance(entry.first, classCenters[centerNum]);
            // Only keep it if it is not too far from the center
            if (dist <= OUTLIER_FACTOR * means[centerNum]) {
                weightedPreferencesCleaned.add(entry);
            }
        }
        return weightedPreferencesCleaned;
    }

    /**
     * Computes for every class the (unweighted) mean distance of its assigned
     * preferences to the class center. Classes without preferences get 0.
     */
    private double[] computeMeanDists(List<Pair<float[], Integer>> weightedPreferences, int numCenters) {
        double[] means = new double[numCenters];
        for (int prefPos = 0; prefPos < weightedPreferences.size(); prefPos++) {
            Pair<float[], Integer> entry = weightedPreferences.get(prefPos);
            int centerNum = assignedClassCenter[prefPos];
            means[centerNum] += distance(entry.first, classCenters[centerNum]);
        }
        for (int i = 0; i < numCenters; i++) {
            if (classCentersNum[i] > 0) {
                means[i] /= classCentersNum[i];
            }
        }
        return means;
    }

    /**
     * Performs one kmeans iteration: assigns every preference to its nearest
     * class center and then moves the class centers to the weighted mean of
     * their preferences.
     * 
     * @return true if at least one entry of the assignment buffer changed
     */
    private boolean doKmeansIteration(List<Pair<float[], Integer>> weightedPreferences, int numCenters) {
        // Expectation (Assign preferences to class centers)
        boolean changed = assignToNearestCenters(weightedPreferences, numCenters);
        // Maximization (Recompute class centers)
        recomputeClassCenters(weightedPreferences, numCenters);
        return changed;
    }

    /**
     * Assigns every preference to its nearest class center and updates the
     * per-class point counts, the per-class radii and maxDist accordingly.
     * 
     * @return true if at least one entry of the assignment buffer changed
     */
    private boolean assignToNearestCenters(List<Pair<float[], Integer>> weightedPreferences, int numCenters) {
        reset(classCentersDist);
        reset(classCentersNum);
        maxDist = 0;

        boolean changed = false;
        for (int prefPos = 0; prefPos < weightedPreferences.size(); prefPos++) {
            float[] coords = weightedPreferences.get(prefPos).first;
            int nearestCenter = 0;
            float nearestDist = Float.MAX_VALUE;
            for (int centerNr = 0; centerNr < numCenters; centerNr++) {
                float dist = distance(coords, classCenters[centerNr]);
                if (dist < nearestDist) {
                    nearestCenter = centerNr;
                    nearestDist = dist;
                }
            }
            if (assignedClassCenter[prefPos] != nearestCenter) {
                assignedClassCenter[prefPos] = nearestCenter;
                changed = true;
            }
            classCentersNum[nearestCenter] += 1;
            if (nearestDist > classCentersDist[nearestCenter]) {
                classCentersDist[nearestCenter] = nearestDist;
            }
            if (nearestDist > maxDist) {
                maxDist = nearestDist;
            }
        }
        return changed;
    }

    /**
     * Moves every class center to the weighted mean of its assigned
     * preferences. A center whose assigned weights sum up to 0 (e.g. because
     * no preference is assigned to it) keeps its position instead of becoming
     * NaN through a division by zero.
     */
    private void recomputeClassCenters(List<Pair<float[], Integer>> weightedPreferences, int numCenters) {
        reset(weightSums);
        for (int prefPos = 0; prefPos < weightedPreferences.size(); prefPos++) {
            weightSums[assignedClassCenter[prefPos]] += weightedPreferences.get(prefPos).second;
        }
        for (int centerNr = 0; centerNr < numCenters; centerNr++) {
            if (weightSums[centerNr] != 0) {
                reset(classCenters[centerNr]);
            }
        }
        for (int prefPos = 0; prefPos < weightedPreferences.size(); prefPos++) {
            Pair<float[], Integer> entry = weightedPreferences.get(prefPos);
            int assignedCenter = assignedClassCenter[prefPos];
            if (weightSums[assignedCenter] != 0) {
                add(classCenters[assignedCenter], entry.first, entry.second);
            }
        }
        for (int centerNr = 0; centerNr < numCenters; centerNr++) {
            if (weightSums[centerNr] != 0) {
                divide(classCenters[centerNr], weightSums[centerNr]);
            }
        }
    }

    /**
     * Draws numberOfCenters distinct entries of weightedPreferences at random
     * and returns copies of their coordinates as start centers. The caller has
     * to make sure that weightedPreferences contains at least numberOfCenters
     * entries.
     */
    private float[][] getRandomClassCenters(int numberOfCenters, int dimensionality,
            List<Pair<float[], Integer>> weightedPreferences) {
        List<Pair<float[], Integer>> potentialCenters = new ArrayList<Pair<float[], Integer>>(weightedPreferences);
        float[][] centers = new float[numberOfCenters][dimensionality];
        for (int i = 0; i < numberOfCenters; i++) {
            Pair<float[], Integer> center = potentialCenters.remove(random.nextInt(potentialCenters.size()));
            System.arraycopy(center.first, 0, centers[i], 0, dimensionality);
        }
        return centers;
    }

    /**
     * Immutable copy of the trained model (class centers, their radii and
     * point counts).
     */
    private static final class Snapshot {
        final float[][] centers;
        final float[] dists;
        final float[] nums;
        final int numCenters;
        final float maxDist;

        Snapshot(float[][] centers, float[] dists, float[] nums, int numCenters, float maxDist) {
            this.centers = centers;
            this.dists = dists;
            this.nums = nums;
            this.numCenters = numCenters;
            this.maxDist = maxDist;
        }
    }

    /**
     * Returns a deep copy of the current model.
     */
    private Snapshot takeSnapshot() {
        return new Snapshot(copy(classCenters), classCentersDist.clone(), classCentersNum.clone(), numCenters,
                maxDist);
    }

    /**
     * Makes the given snapshot the current model. The arrays are copied again
     * because training resets the model arrays in place.
     */
    private void restore(Snapshot snapshot) {
        classCenters = copy(snapshot.centers);
        classCentersDist = snapshot.dists.clone();
        classCentersNum = snapshot.nums.clone();
        numCenters = snapshot.numCenters;
        maxDist = snapshot.maxDist;
    }

    /**
     * return a deep copy of matrix
     */
    private static float[][] copy(float[][] matrix) {
        float[][] result = new float[matrix.length][];
        for (int i = 0; i < matrix.length; i++) {
            result[i] = matrix[i].clone();
        }
        return result;
    }

    /**
     * return the euclidean distance between v1 and v2
     */
    private static float distance(float[] v1, float[] v2) {
        double dist = 0;
        for (int i = 0; i < v1.length; i++) {
            double diff = v1[i] - v2[i];
            dist += diff * diff;
        }
        return (float) Math.sqrt(dist);
    }

    /**
     * return v1 = v1 + weight * v2
     */
    private static float[] add(float[] v1, float[] v2, float weight) {
        for (int i = 0; i < v1.length; i++) {
            v1[i] = v1[i] + weight * v2[i];
        }
        return v1;
    }

    /**
     * return v1 = v1 / divisor
     */
    private static float[] divide(float[] v1, float divisor) {
        for (int i = 0; i < v1.length; i++) {
            v1[i] = v1[i] / divisor;
        }
        return v1;
    }

    /**
     * return v1 = 0
     */
    private static float[] reset(float[] v1) {
        for (int i = 0; i < v1.length; i++) {
            v1[i] = 0;
        }
        return v1;
    }
}