// Copyright (C) 2014 Guibing Guo
//
// This file is part of LibRec.
//
// LibRec is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// LibRec is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with LibRec. If not, see <http://www.gnu.org/licenses/>.
//
package librec.undefined;
import java.io.BufferedReader;
import java.util.HashMap;
import java.util.Map;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import happy.coding.io.FileIO;
import happy.coding.io.Logs;
import happy.coding.io.Strings;
import happy.coding.math.Randoms;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.RatingContext;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.ContextRecommender;
/**
* Koren, <strong>Collaborative Filtering with Temporal Dynamics</strong>, KDD 2009.
*
*
 * <p>Thanks to Bin Wu for sharing a version of the timeSVD++ source code.</p>
*
* @author guoguibing
*
*/
public class TimeSVDPlusPlus extends ContextRecommender {

	// milliseconds per day; the long literal (24L) is required — the former
	// int expression 24 * 60 * 60 * 1000 * 1000 overflowed int arithmetic
	private static final long MS_PER_DAY = 24L * 60 * 60 * 1000;

	// number of distinct rating days (span + 1, so that day(max) maps into a valid bin)
	private static int numDays;

	// {user, mean rating timestamp}
	private static Map<Integer, Long> dateMean;

	// minimum/maximum rating timestamp over all context data
	private static long min, max;

	// exponent of the time-deviation function dev(u, t)
	private static float beta;

	// number of time bins shared by all items
	private static int numBins;

	// item's implicit influence (y_i of the SVD++ part)
	private DenseMatrix Y;

	// {item, bin(t)} item bias matrix: b_{i,Bin(t)}
	private DenseMatrix Bit;

	// {user, day} day-specific user bias: b_{u,t}
	private Table<Integer, Integer, Double> But;

	// user drift slopes: alpha_u
	private DenseVector userAlpha;

	// {user, day} day-specific user scaling: c_{u,t}
	private Table<Integer, Integer, Double> Cut;

	// stable user scaling: c_u
	private DenseVector userScaling;

	// {user, factor} drift slopes of the user factors: alpha_{u,f}
	private DenseMatrix Auf;

	// {user -> {factor, day, value}} day-specific user factors: p_{u,f}(t)
	private Map<Integer, Table<Integer, Integer, Double>> Puft;

	// read context information once, before any fold is built
	static {
		try {
			readContext();
			preset();
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	public TimeSVDPlusPlus(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
		super(trainMatrix, testMatrix, fold);

		algoName = "timeSVD++";
	}

	@Override
	protected void initModel() throws Exception {
		super.initModel();

		userBias = new DenseVector(numUsers);
		userBias.init();

		itemBias = new DenseVector(numItems);
		itemBias.init();

		userAlpha = new DenseVector(numUsers);
		userAlpha.init();

		Bit = new DenseMatrix(numItems, numBins);
		Bit.init();

		userScaling = new DenseVector(numUsers); // c_u
		userScaling.init();

		Y = new DenseMatrix(numItems, numFactors);
		Y.init();

		Auf = new DenseMatrix(numUsers, numFactors);
		Auf.init();

		// day-specific parameters are lazily initialized during training
		But = HashBasedTable.create();
		Cut = HashBasedTable.create();
		Puft = new HashMap<>();
	}

	@Override
	protected void buildModel() throws Exception {
		for (int iter = 1; iter <= numIters; iter++) {
			errs = 0;
			loss = 0;

			for (MatrixEntry me : trainMatrix) {
				int u = me.row();
				int i = me.column();
				double rui = me.get();
				if (rui <= 0)
					continue;

				long t = ratingContexts.get(u, i).getTimestamp();
				int bin = bin(t);
				int day = day(t);
				double dev = dev(u, t);

				double bi = itemBias.get(i);
				double bit = Bit.get(i, bin);
				double bu = userBias.get(u);
				double but = getOrZero(But, u, day); // 0 until trained; avoids unboxing NPE
				double au = userAlpha.get(u); // alpha_u
				double cu = userScaling.get(u);
				double cut = getOrZero(Cut, u, day); // 0 until trained; avoids unboxing NPE

				// mu + bi(t): time-dependent item bias, scaled by the user scaling
				double pui = globalMean + (bi + bit) * (cu + cut);
				// bu(t): time-dependent user bias
				pui += bu + au * dev + but;

				// + |Ru|^{-1/2} * sum_k (y_k * q_i): SVD++ implicit feedback term
				SparseVector Ru = trainMatrix.row(u);
				double sum_y = 0;
				for (VectorEntry ve : Ru) {
					int k = ve.index();
					sum_y += DenseMatrix.rowMult(Y, k, Q, i);
				}
				if (Ru.getCount() > 0)
					pui += sum_y * Math.pow(Ru.getCount(), -0.5);

				// + q_i * p_u(t): time-dependent user factors
				Table<Integer, Integer, Double> data = Puft.get(u);
				if (data == null) {
					data = HashBasedTable.create();
					Puft.put(u, data);
				}

				for (int f = 0; f < numFactors; f++) {
					double qif = Q.get(i, f);

					// late (lazy) initialization of the day-specific factor
					if (!data.contains(f, day))
						data.put(f, day, Randoms.random());

					double puf = P.get(u, f) + Auf.get(u, f) * dev + data.get(f, day);
					pui += puf * qif;
				}

				double eui = pui - rui;
				errs += eui * eui;
				loss += eui * eui;

				// update bu by stochastic gradient descent
				double sgd = eui + regB * bu;
				userBias.add(u, -lRate * sgd);

				// TODO: add codes here to update other variables
			}

			if (isConverged(iter))
				break;
		}
	}

	@Override
	protected double predict(int u, int j) {
		// retrieve the test rating timestamp
		long t = ratingContexts.get(u, j).getTimestamp();
		int bin = bin(t);
		int day = day(t);
		double dev = dev(u, t);

		double pred = globalMean;

		// bi(t): time-dependent item bias, scaled by the user scaling
		pred += (itemBias.get(j) + Bit.get(j, bin)) * (userScaling.get(u) + getOrZero(Cut, u, day));

		// bu(t): time-dependent user bias
		pred += userBias.get(u) + userAlpha.get(u) * dev + getOrZero(But, u, day);

		// + |Ru|^{-1/2} * sum_k (y_k * q_j): SVD++ implicit feedback term
		SparseVector Ru = trainMatrix.row(u);
		double sum_y = 0;
		for (VectorEntry ve : Ru) {
			int k = ve.index();
			sum_y += DenseMatrix.rowMult(Y, k, Q, j);
		}
		if (Ru.getCount() > 0)
			pred += sum_y * Math.pow(Ru.getCount(), -0.5);

		// + q_j * p_u(t); unseen {factor, day} entries contribute 0
		Table<Integer, Integer, Double> data = Puft.get(u);
		for (int f = 0; f < numFactors; f++) {
			double qjf = Q.get(j, f);
			double puft = (data != null && data.contains(f, day)) ? data.get(f, day) : 0;
			double puf = P.get(u, f) + Auf.get(u, f) * dev + puft;
			pred += puf * qjf;
		}

		return pred;
	}

	@Override
	public String toString() {
		return super.toString() + "," + Strings.toString(new Object[] { beta, numBins });
	}

	/**
	 * Read rating timestamps from the context file, filling {@code ratingContexts}
	 * and computing the overall timestamp range ({@code min}, {@code max}) and the
	 * day span ({@code numDays}).
	 */
	protected static void readContext() throws Exception {
		String contextPath = cf.getPath("dataset.social");
		Logs.debug("Context dataset: {}", Strings.last(contextPath, 38));

		ratingContexts = HashBasedTable.create();
		min = Long.MAX_VALUE;
		max = Long.MIN_VALUE;

		BufferedReader br = FileIO.getReader(contextPath);
		try {
			String line = null;
			while ((line = br.readLine()) != null) {
				String[] data = line.split("[ \t,]");
				String user = data[0];
				String item = data[1];
				long timestamp = Long.parseLong(data[3]);

				int userId = rateDao.getUserId(user);
				int itemId = rateDao.getItemId(item);

				RatingContext rc = new RatingContext(userId, itemId);
				rc.setTimestamp(timestamp);

				ratingContexts.put(userId, itemId, rc);

				if (min > timestamp)
					min = timestamp;
				if (max < timestamp)
					max = timestamp;
			}
		} finally {
			// always release the file handle, even on a parse failure
			br.close();
		}

		// "+ 1" so that day(max) still falls into a valid bin index (see bin(t));
		// it also guards the division in bin(t) when all ratings share one day
		numDays = days(max - min) + 1;
	}

	/**
	 * Preset the timeSVD++ hyper-parameters and compute each user's mean rating date.
	 */
	protected static void preset() {
		beta = cf.getFloat("timeSVD++.beta");
		numBins = cf.getInt("timeSVD++.item.bins");

		// compute each user's mean rating timestamp
		dateMean = new HashMap<>();
		for (int u = 0; u < numUsers; u++) {
			Map<Integer, RatingContext> rcs = ratingContexts.row(u);

			// fall back to the middle of the global time range for users without
			// any rating context (avoids a division by zero below)
			if (rcs.isEmpty()) {
				dateMean.put(u, min + (max - min) / 2);
				continue;
			}

			long sum = 0;
			for (RatingContext rc : rcs.values())
				sum += rc.getTimestamp();

			dateMean.put(u, sum / rcs.size());
		}
	}

	/***************************************************************** Functional Methods *******************************************/

	/**
	 * @return the value stored at cell {u, day}, or 0 if absent; Guava's
	 *         {@code Table.get} returns {@code null} for missing cells, which
	 *         would otherwise throw an NPE when auto-unboxed to double
	 */
	private static double getOrZero(Table<Integer, Integer, Double> table, int u, int day) {
		Double val = table.get(u, day);
		return val == null ? 0.0 : val;
	}

	/**
	 * @return the time deviation dev_u(t) = sign(t - tu) * |days(t - tu)|^beta
	 *         for timestamp t w.r.t. user u's mean rating date tu
	 */
	protected double dev(int u, long t) {
		long tu = dateMean.get(u);

		// date difference in milliseconds
		long diff = t - tu;

		// date difference in days
		int days = days(Math.abs(diff));

		return Math.signum(diff) * Math.pow(days, beta);
	}

	/**
	 * @return the bin number (in [0, numBins)) for a specific timestamp t
	 */
	protected static int bin(long t) {
		return day(t) * numBins / numDays;
	}

	/**
	 * @return the number of days between timestamp t and the earliest timestamp
	 */
	protected static int day(long t) {
		return days(t - min);
	}

	/**
	 * @return the number of whole days covered by a timestamp difference
	 */
	private static int days(long diff) {
		return (int) (diff / MS_PER_DAY);
	}
}