package happy.research.cf;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
/**
* Daniel Lemire: A simple implementation of the weighted slope one
* algorithm in Java for item-based collaborative filtering. <br/>
*
* See main function for example. June 1st 2006. <br/>
* Revised by Marco Ponzi on March 29th 2007. <br/>
*
* Revised by Guibing Guo on June 7th, 2013.
*/
public class SlopeOne
{
    /**
     * Small demonstration: builds a four-user data set, prints the model and
     * runs two weighted predictions.
     */
    public static void main(String[] args)
    {
        // this is my data base
        Map<String, Map<String, Double>> data = new HashMap<>();
        // items
        String item_candy = "candy";
        String item_dog = "dog";
        String item_cat = "cat";
        String item_war = "war";
        String item_food = "strange food";
        mAllItems = new String[] { item_candy, item_dog, item_cat, item_war, item_food };
        // I'm going to fill it in
        Map<String, Double> user1 = new HashMap<>();
        Map<String, Double> user2 = new HashMap<>();
        Map<String, Double> user3 = new HashMap<>();
        Map<String, Double> user4 = new HashMap<>();
        user1.put(item_candy, 1.0);
        user1.put(item_dog, 0.5);
        user1.put(item_war, 0.1);
        data.put("Bob", user1);
        user2.put(item_candy, 1.0);
        user2.put(item_cat, 0.5);
        user2.put(item_war, 0.2);
        data.put("Jane", user2);
        user3.put(item_candy, 0.9);
        user3.put(item_dog, 0.4);
        user3.put(item_cat, 0.5);
        user3.put(item_war, 0.1);
        data.put("Jo", user3);
        user4.put(item_candy, 0.1);
        user4.put(item_war, 1.0);
        user4.put(item_food, 0.4);
        data.put("StrangeJo", user4);
        // next, I create my predictor engine
        SlopeOne so = new SlopeOne(data);
        System.out.println("Here's the data I have accumulated...");
        so.printData();
        // then, I'm going to test it out...
        Map<String, Double> user = new HashMap<>();
        System.out.println("Ok, now we predict...");
        user.put(item_food, 0.4);
        System.out.println("Inputting...");
        SlopeOne.print(user);
        System.out.println("Getting...");
        SlopeOne.print(so.predict(user));
        //
        user.put(item_war, 0.2);
        System.out.println("Inputting...");
        SlopeOne.print(user);
        System.out.println("Getting...");
        SlopeOne.print(so.predict(user));
    }

    /** Ratings keyed by user name, then by item name; stored as passed in (not copied). */
    Map<String, Map<String, Double>> mData;

    /** diffMatrix.get(a).get(b) is the average of (rating of a - rating of b) over users who rated both. */
    Map<String, Map<String, Double>> diffMatrix;

    /** freqMatrix.get(a).get(b) is the number of users who rated both a and b. */
    Map<String, Map<String, Integer>> freqMatrix;

    /**
     * Optional item list that fixes the row/column order used by {@link #printData()}.
     * It is only assigned by {@link #main(String[])}; when it is null the items
     * known to the model are printed instead.
     */
    static String[] mAllItems;

    /**
     * Creates a predictor and immediately builds the deviation and frequency
     * matrices from the given ratings.
     *
     * @param data ratings keyed by user, then by item; must not be null
     * @throws NullPointerException if {@code data} is null
     */
    public SlopeOne(Map<String, Map<String, Double>> data)
    {
        if (data == null)
        {
            throw new NullPointerException("data must not be null");
        }
        mData = data;
        buildDiffMatrix();
    }

    /**
     * Weighted slope one: predicts a rating for every known item that shares at
     * least one co-rating with an item rated by {@code user}. Each deviation is
     * weighted by the number of users it was computed from.
     *
     * Items for which no prediction is possible are omitted from the result.
     * The ratings supplied in {@code user} are copied into the result unchanged
     * and take precedence over predictions.
     *
     * @param user the known ratings of the target user, keyed by item
     * @return predicted ratings plus the user's own ratings, keyed by item
     */
    public Map<String, Double> predict(Map<String, Double> user)
    {
        Map<String, Double> weightedSums = new HashMap<>();
        Map<String, Integer> totalWeights = new HashMap<>();
        for (Entry<String, Double> rated : user.entrySet())
        {
            String ratedItem = rated.getKey();
            Double rating = rated.getValue();
            if (rating == null)
            {
                // a missing rating cannot contribute to any prediction
                continue;
            }
            for (Entry<String, Map<String, Double>> row : diffMatrix.entrySet())
            {
                String target = row.getKey();
                Map<String, Integer> freqRow = freqMatrix.get(target);
                Double diff = row.getValue().get(ratedItem);
                Integer freq = (freqRow == null) ? null : freqRow.get(ratedItem);
                if (diff == null || freq == null)
                {
                    // the two items were never rated by the same user
                    continue;
                }
                Double sumSoFar = weightedSums.get(target);
                Integer weightSoFar = totalWeights.get(target);
                double contribution = (diff + rating) * freq;
                weightedSums.put(target, (sumSoFar == null ? 0.0 : sumSoFar) + contribution);
                totalWeights.put(target, (weightSoFar == null ? 0 : weightSoFar) + freq);
            }
        }
        Map<String, Double> result = new HashMap<>();
        for (Entry<String, Double> sum : weightedSums.entrySet())
        {
            int weight = totalWeights.get(sum.getKey());
            if (weight > 0)
            {
                result.put(sum.getKey(), sum.getValue() / weight);
            }
        }
        // the user's actual ratings always win over predictions
        result.putAll(user);
        return result;
    }

    /**
     * Unweighted slope one: predicts a rating for every known item as the plain
     * average of (deviation + user rating) over the user's rated items that
     * share a co-rating with it.
     *
     * Item pairs that were never co-rated are skipped, and items without any
     * usable deviation are omitted from the result. The ratings supplied in
     * {@code user} are copied into the result unchanged and take precedence
     * over predictions.
     *
     * @param user the known ratings of the target user, keyed by item
     * @return predicted ratings plus the user's own ratings, keyed by item
     */
    public Map<String, Double> weightlesspredict(Map<String, Double> user)
    {
        Map<String, Double> sums = new HashMap<>();
        Map<String, Integer> counts = new HashMap<>();
        for (Entry<String, Double> rated : user.entrySet())
        {
            String ratedItem = rated.getKey();
            Double rating = rated.getValue();
            if (rating == null)
            {
                continue;
            }
            for (Entry<String, Map<String, Double>> row : diffMatrix.entrySet())
            {
                Double diff = row.getValue().get(ratedItem);
                if (diff == null)
                {
                    // no user rated both items, so there is no deviation to use
                    continue;
                }
                String target = row.getKey();
                Double sumSoFar = sums.get(target);
                Integer countSoFar = counts.get(target);
                sums.put(target, (sumSoFar == null ? 0.0 : sumSoFar) + (diff + rating));
                counts.put(target, (countSoFar == null ? 0 : countSoFar) + 1);
            }
        }
        Map<String, Double> result = new HashMap<>();
        for (Entry<String, Double> sum : sums.entrySet())
        {
            // every key in sums has at least one contribution, so the count is >= 1
            result.put(sum.getKey(), sum.getValue() / counts.get(sum.getKey()));
        }
        result.putAll(user);
        return result;
    }

    /**
     * Prints every user's ratings followed by one row per item showing, for each
     * other item, the average deviation and the co-rating count.
     */
    public void printData()
    {
        for (Entry<String, Map<String, Double>> userRatings : mData.entrySet())
        {
            System.out.println(userRatings.getKey());
            print(userRatings.getValue());
        }
        // mAllItems is only set by main(); fall back to the items the model knows
        String[] items = (mAllItems != null)
                ? mAllItems
                : diffMatrix.keySet().toArray(new String[0]);
        for (String item : items)
        {
            System.out.print("\n" + item + ":");
            printMatrixes(items, diffMatrix.get(item), freqMatrix.get(item));
        }
    }

    /**
     * Prints one matrix row: for each item, the average deviation followed by
     * the co-rating count. Cells without data print as "n/a" and 0.
     */
    private void printMatrixes(String[] items, Map<String, Double> ratings, Map<String, Integer> frequencies)
    {
        for (String item : items)
        {
            Double diff = (ratings == null) ? null : ratings.get(item);
            Integer freq = (frequencies == null) ? null : frequencies.get(item);
            if (diff == null)
            {
                System.out.format("%10s", "n/a");
            }
            else
            {
                System.out.format("%10.3f", diff);
            }
            System.out.print(" ");
            System.out.format("%10d", freq == null ? 0 : freq);
        }
        System.out.println();
    }

    /**
     * Prints each item of the given rating map with its value on its own line.
     *
     * @param user ratings keyed by item
     */
    public static void print(Map<String, Double> user)
    {
        for (Entry<String, Double> rating : user.entrySet())
        {
            System.out.println(" " + rating.getKey() + " --> " + rating.getValue().floatValue());
        }
    }

    /**
     * Rebuilds {@link #diffMatrix} and {@link #freqMatrix} from {@link #mData}.
     * For every ordered pair of items rated by the same user, the rating
     * difference is accumulated and counted; the sums are then divided by the
     * counts to obtain average deviations.
     */
    public void buildDiffMatrix()
    {
        diffMatrix = new HashMap<>();
        freqMatrix = new HashMap<>();
        // first pass: accumulate difference sums and co-rating counts per user
        for (Map<String, Double> user : mData.values())
        {
            for (Entry<String, Double> first : user.entrySet())
            {
                String i1 = first.getKey();
                double r1 = first.getValue();
                Map<String, Double> diffRow = diffMatrix.get(i1);
                Map<String, Integer> freqRow = freqMatrix.get(i1);
                if (diffRow == null)
                {
                    diffRow = new HashMap<>();
                    freqRow = new HashMap<>();
                    diffMatrix.put(i1, diffRow);
                    freqMatrix.put(i1, freqRow);
                }
                for (Entry<String, Double> second : user.entrySet())
                {
                    String i2 = second.getKey();
                    double r2 = second.getValue();
                    Integer count = freqRow.get(i2);
                    Double sum = diffRow.get(i2);
                    freqRow.put(i2, (count == null ? 0 : count) + 1);
                    diffRow.put(i2, (sum == null ? 0.0 : sum) + (r1 - r2));
                }
            }
        }
        // second pass: turn the accumulated sums into averages
        for (Entry<String, Map<String, Double>> row : diffMatrix.entrySet())
        {
            Map<String, Integer> freqRow = freqMatrix.get(row.getKey());
            for (Entry<String, Double> cell : row.getValue().entrySet())
            {
                int count = freqRow.get(cell.getKey());
                cell.setValue(cell.getValue() / count);
            }
        }
    }
}