// Copyright (C) 2014 Guibing Guo
//
// This file is part of LibRec.
//
// LibRec is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// LibRec is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with LibRec. If not, see <http://www.gnu.org/licenses/>.
//
package librec.undefined;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.data.VectorEntry;
import librec.intf.SocialRecommender;
/**
* Our ongoing testing algorithm
*
* @author guoguibing
*
*/
/**
 * Our ongoing testing algorithm: an experimental trust-aware factorization model.
 *
 * <p>
 * As implemented by {@link #predict(int, int)}, the predicted rating of user {@code u} on item
 * {@code j} is the global mean plus the user and item biases plus the inner product of the item
 * factors {@code Q[j]} with each of the following terms:
 * <ul>
 * <li>the user factors {@code P[u]};</li>
 * <li>the sum of the factors {@code Y[i]} over the items {@code i} in row {@code u} of the training
 * matrix, divided by the square root of their count;</li>
 * <li>{@code alpha} times the sum of the factors {@code W[v]} over the users {@code v} in row
 * {@code u} of the social matrix (trustees of {@code u}), divided by the square root of their count;</li>
 * <li>{@code (1 - alpha)} times the sum of the factors {@code P[k]} over the users {@code k} in
 * column {@code u} of the social matrix (trusters of {@code u}), divided by the square root of their
 * count.</li>
 * </ul>
 *
 * @author guoguibing
 *
 */
public class TrustSVDPlusPlus extends SocialRecommender {

    /** {@code W}: per-user factors used on the trustee side; {@code Y}: per-item implicit factors. */
    private DenseMatrix W, Y;

    /**
     * Weighted lambda regularization (wlr) coefficients, each {@code 1/sqrt(count)} (or 1 when the
     * count is 0): {@code wlr_j} by item column size in the training matrix, {@code wlr_tc} by
     * social-matrix column size and {@code wlr_tr} by social-matrix row size.
     */
    private DenseVector wlr_j, wlr_tc, wlr_tr;

    /** Blend between the trustee term (weight alpha) and the truster term (weight 1 - alpha). */
    private float alpha = -1;

    /** Indicators: {@code delta_a} is 1 when alpha > 0, {@code delta_1_a} is 1 when 1 - alpha > 0. */
    double delta_a, delta_1_a;

    /**
     * Creates the recommender and reads its settings.
     *
     * @param trainMatrix training ratings
     * @param testMatrix  test ratings
     * @param fold        fold index, passed through to the superclass
     */
    public TrustSVDPlusPlus(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
        super(trainMatrix, testMatrix, fold);

        // NOTE(review): the branches below are mutually exclusive, so "TrustSVD++.alpha" in params
        // is ignored whenever "val.reg" or "val.reg.social" is also present (alpha then falls back
        // to the configuration value). Confirm whether independent checks were intended.
        if (params.containsKey("val.reg")) {
            float reg = RecUtils.getMKey(params, "val.reg");

            regB = reg;
            regU = reg;
            regI = reg;
            regS = reg;
        } else if (params.containsKey("val.reg.social")) {
            regS = RecUtils.getMKey(params, "val.reg.social");
        } else if (params.containsKey("TrustSVD++.alpha")) {
            alpha = RecUtils.getMKey(params, "TrustSVD++.alpha");
        }

        // alpha keeps its -1 sentinel unless it was read from params above
        if (alpha < 0) {
            alpha = cf.getFloat("TrustSVD++.alpha");
        }

        algoName = "TrustSVD++";
    }

    /**
     * Allocates and initializes the biases and the {@code W}/{@code Y} factor matrices, then
     * precomputes the weighted-regularization coefficients and the alpha indicators.
     *
     * @throws Exception propagated from the superclass initialization
     */
    @Override
    protected void initModel() throws Exception {
        super.initModel();

        userBias = new DenseVector(numUsers);
        itemBias = new DenseVector(numItems);

        W = new DenseMatrix(numUsers, numFactors);
        Y = new DenseMatrix(numItems, numFactors);

        if (initByNorm) {
            userBias.init(initMean, initStd);
            itemBias.init(initMean, initStd);
            W.init(initMean, initStd);
            Y.init(initMean, initStd);
        } else {
            userBias.init();
            itemBias.init();
            W.init();
            Y.init();
        }

        // weighted lambda regularization (wlr)
        wlr_tc = new DenseVector(numUsers);
        wlr_tr = new DenseVector(numUsers);
        wlr_j = new DenseVector(numItems);

        for (int u = 0; u < numUsers; u++) {
            wlr_tc.set(u, inverseSqrtOrOne(socialMatrix.columnSize(u)));
            wlr_tr.set(u, inverseSqrtOrOne(socialMatrix.rowSize(u)));
        }

        for (int j = 0; j < numItems; j++) {
            wlr_j.set(j, inverseSqrtOrOne(trainMatrix.columnSize(j)));
        }

        delta_a = alpha > 0 ? 1.0 : 0.0;
        delta_1_a = 1 - alpha > 0 ? 1.0 : 0.0;
    }

    /**
     * Trains the model for at most {@code numIters} iterations, stopping early when
     * {@code isConverged(iter)} returns true.
     *
     * <p>
     * In each iteration the biases and {@code Y} are updated immediately per rating, whereas the
     * gradients of {@code P}, {@code Q} and {@code W} are accumulated over all ratings and trust
     * links and applied once at the end of the iteration.
     *
     * @throws Exception declared for subclasses/superclass compatibility
     */
    protected void buildModel() throws Exception {
        for (int iter = 1; iter <= numIters; iter++) {
            loss = 0;
            errs = 0;

            // gradient accumulators for the batched updates of P, Q and W
            DenseMatrix PS = new DenseMatrix(numUsers, numFactors);
            DenseMatrix QS = new DenseMatrix(numItems, numFactors);
            DenseMatrix WS = new DenseMatrix(numUsers, numFactors);

            // ratings
            for (MatrixEntry me : trainMatrix) {
                int u = me.row(); // user
                int j = me.column(); // item
                double ruj = me.get();

                if (ruj <= 0.0) {
                    continue;
                }

                // To speed up, directly access the prediction (same formula as predict(u, j));
                // the index arrays gathered here are reused by the updates below.
                double bu = userBias.get(u), bj = itemBias.get(j);
                double pred = globalMean + bu + bj + DenseMatrix.rowMult(P, u, Q, j);

                // Y
                SparseVector ru = trainMatrix.row(u); // row u
                int[] Iu = ru.getIndex(); // rated items
                if (ru.getCount() > 0) {
                    double sum = 0;
                    for (int i : Iu) {
                        sum += DenseMatrix.rowMult(Y, i, Q, j);
                    }
                    pred += sum / Math.sqrt(ru.getCount());
                }

                // Tur
                SparseVector tr = socialMatrix.row(u); // trustees of user u
                int[] tur = tr.getIndex();
                if (tr.getCount() > 0) {
                    double sum = 0.0;
                    for (int v : tur) {
                        sum += DenseMatrix.rowMult(W, v, Q, j);
                    }
                    pred += alpha * (sum / Math.sqrt(tr.getCount()));
                }

                // Tuc
                SparseVector tc = socialMatrix.column(u); // trusters of user u
                int[] tuc = tc.getIndex();
                if (tc.getCount() > 0) {
                    double sum = 0.0;
                    for (int k : tuc) {
                        sum += DenseMatrix.rowMult(P, k, Q, j);
                    }
                    pred += (1 - alpha) * (sum / Math.sqrt(tc.getCount()));
                }

                double euj = pred - ruj;

                errs += euj * euj;
                loss += euj * euj;

                // update factors
                double reg_u = Iu.length > 0 ? 1.0 / Math.sqrt(Iu.length) : 1.0;
                double reg_ur = wlr_tr.get(u);
                double reg_uc = wlr_tc.get(u);
                double reg_j = wlr_j.get(j);

                // biases are updated immediately
                double sgd = euj + regB * reg_u * bu;
                userBias.add(u, -lRate * sgd);

                sgd = euj + regB * reg_j * bj;
                itemBias.add(j, -lRate * sgd);

                loss += regB * reg_u * bu * bu;
                loss += regB * reg_j * bj * bj;

                // Snapshot the normalised neighbourhood sums before Y is modified below.
                double[] sum_ys = scaledColumnSums(Y, Iu, reg_u);
                double[] sum_trs = scaledColumnSums(W, tur, reg_ur);
                double[] sum_tcs = scaledColumnSums(P, tuc, reg_uc);

                // regularization weight of P[u]; does not depend on the factor index
                double sgd_u = regU * reg_u + regS * (delta_a * reg_ur + delta_1_a * reg_uc);

                for (int f = 0; f < numFactors; f++) {
                    double puf = P.get(u, f);
                    double qjf = Q.get(j, f);

                    double delta_u = euj * qjf + sgd_u * puf;
                    double delta_j = euj * (puf + sum_ys[f] + alpha * sum_trs[f] + (1 - alpha) * sum_tcs[f]) + regI
                            * reg_j * qjf;

                    PS.add(u, f, delta_u);
                    QS.add(j, f, delta_j);

                    loss += sgd_u * puf * puf + regI * reg_j * qjf * qjf;

                    // update Y (applied immediately, unlike P, Q and W)
                    for (int i : Iu) {
                        double yif = Y.get(i, f);

                        double reg_yi = wlr_j.get(i);
                        double delta_y = euj * reg_u * qjf + regI * reg_yi * yif;
                        Y.add(i, f, -lRate * delta_y);

                        loss += regI * reg_yi * yif * yif;
                    }

                    // update W
                    for (int v : tur) {
                        double wvf = W.get(v, f);

                        double reg_vr = wlr_tr.get(v);
                        double sgd_v = regU * delta_a * reg_vr;
                        double delta_v = euj * alpha * reg_ur * qjf + sgd_v * wvf;
                        WS.add(v, f, delta_v);

                        loss += sgd_v * wvf * wvf;
                    }

                    // update Pkf
                    for (int k : tuc) {
                        double pkf = P.get(k, f);

                        double reg_kc = wlr_tc.get(k);
                        double sgd_k = regU * delta_1_a * reg_kc;
                        double delta_k = euj * (1 - alpha) * reg_uc * qjf + sgd_k * pkf;
                        PS.add(k, f, delta_k);

                        loss += sgd_k * pkf * pkf;
                    }
                }
            }

            // trust
            accumulateTrustGradients(PS, WS);

            // apply the batched gradients
            P = P.add(PS.scale(-lRate));
            Q = Q.add(QS.scale(-lRate));
            W = W.add(WS.scale(-lRate));

            errs *= 0.5;
            loss *= 0.5;

            if (isConverged(iter)) {
                break;
            }

        }// end of training
    }

    /**
     * Adds the gradients of the trust-reconstruction terms to the accumulators and the
     * corresponding squared errors to {@code loss}.
     *
     * <p>
     * Each trust value is approximated by the inner product of a truster row of {@code P} and a
     * trustee row of {@code W}. {@code P} and {@code W} are only read here, so the prediction error
     * of a link is computed once per link rather than once per factor, and each link contributes
     * its squared error to the loss exactly once.
     *
     * @param PS gradient accumulator for {@code P}
     * @param WS gradient accumulator for {@code W}
     */
    private void accumulateTrustGradients(DenseMatrix PS, DenseMatrix WS) {
        double trusteeWeight = regS * alpha;
        double trusterWeight = regS * (1 - alpha);

        for (int u = 0; u < numUsers; u++) {
            // wvf: links from u to its trustees v
            for (VectorEntry ve : socialMatrix.row(u)) {
                int v = ve.index();
                double tuv = ve.get();

                double puv = DenseMatrix.rowMult(P, u, W, v);
                double euv = puv - tuv;
                double step = trusteeWeight * euv;

                loss += trusteeWeight * euv * euv;

                for (int f = 0; f < numFactors; f++) {
                    PS.add(u, f, step * W.get(v, f));
                    WS.add(v, f, step * P.get(u, f));
                }
            }

            // pkf: links from trusters k to u; only P[k] receives a gradient from this term
            for (VectorEntry ve : socialMatrix.column(u)) {
                int k = ve.index();
                double tku = ve.get();

                double pku = DenseMatrix.rowMult(P, k, W, u);
                double eku = pku - tku;
                double step = trusterWeight * eku;

                loss += trusterWeight * eku * eku;

                for (int f = 0; f < numFactors; f++) {
                    PS.add(k, f, step * W.get(u, f));
                }
            }
        }
    }

    /**
     * Returns, for every factor {@code f}, {@code scale} times the sum of {@code factors[r][f]}
     * over the given rows.
     *
     * @param factors factor matrix to read
     * @param rows    row indices to sum over (may be empty, giving all zeros)
     * @param scale   multiplier applied to each column sum
     * @return array of length {@code numFactors}
     */
    private double[] scaledColumnSums(DenseMatrix factors, int[] rows, double scale) {
        double[] sums = new double[numFactors];
        for (int f = 0; f < numFactors; f++) {
            double sum = 0;
            for (int r : rows) {
                sum += factors.get(r, f);
            }
            sums[f] = scale * sum;
        }
        return sums;
    }

    /**
     * Returns {@code 1/sqrt(count)} for a positive count, and 1 otherwise.
     *
     * @param count number of entries
     * @return the regularization weight for that count
     */
    private static double inverseSqrtOrOne(int count) {
        return count > 0 ? 1.0 / Math.sqrt(count) : 1.0;
    }

    /**
     * Predicts the rating of user {@code u} on item {@code j} from the current model.
     *
     * @param u user index
     * @param j item index
     * @return the raw predicted rating (no clipping is applied here)
     */
    @Override
    protected double predict(int u, int j) {
        double pred = globalMean + userBias.get(u) + itemBias.get(j) + DenseMatrix.rowMult(P, u, Q, j);

        // Y
        SparseVector uv = trainMatrix.row(u);
        if (uv.getCount() > 0) {
            double sum = 0;
            for (int i : uv.getIndex()) {
                sum += DenseMatrix.rowMult(Y, i, Q, j);
            }

            pred += sum / Math.sqrt(uv.getCount());
        }

        // Tur: Tu row
        SparseVector tr = socialMatrix.row(u);
        if (tr.getCount() > 0) {
            double sum = 0.0;
            for (int v : tr.getIndex()) {
                sum += DenseMatrix.rowMult(W, v, Q, j);
            }

            pred += alpha * (sum / Math.sqrt(tr.getCount()));
        }

        // Tuc: Tu column
        SparseVector tc = socialMatrix.column(u);
        if (tc.getCount() > 0) {
            double sum = 0.0;
            for (int k : tc.getIndex()) {
                sum += DenseMatrix.rowMult(P, k, Q, j);
            }

            pred += (1 - alpha) * (sum / Math.sqrt(tc.getCount()));
        }

        return pred;
    }

    /**
     * Returns the superclass description followed by a comma and the value of alpha.
     */
    @Override
    public String toString() {
        return super.toString() + "," + alpha;
    }
}