// Copyright (C) 2014 Guibing Guo
//
// This file is part of LibRec.
//
// LibRec is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// LibRec is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with LibRec. If not, see <http://www.gnu.org/licenses/>.
//
package librec.undefined;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import librec.data.DenseMatrix;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.ranking.CLiMF;
public class DRM extends CLiMF {
protected double alpha;
public DRM(SparseMatrix rm, SparseMatrix tm, int fold) {
super(rm, tm, fold);
algoName = "DRMPlus";
alpha = RecUtils.getMKey(params, "val.diverse.alpha");
initStd = 0.1;
}
@Override
protected void buildModel() {
for (int iter = 1; iter <= numIters; iter++) {
loss = 0;
errs = 0;
for (int u = 0; u < numUsers; u++) {
// all user u's ratings
SparseVector uv = trainMatrix.row(u);
int[] items = uv.getIndex();
double w = Math.sqrt(uv.getCount());
// compute sgd for user u
double[] sgds = new double[numFactors];
for (int f = 0; f < numFactors; f++) {
double sgd = -regU * P.get(u, f);
for (int j : items) {
double fuj = predict(u, j);
double qjf = Q.get(j, f);
sgd += g(-fuj) * qjf;
for (int k : items) {
if (k == j)
continue;
double fuk = predict(u, k);
double qif = Q.get(k, f);
double x = fuk - fuj;
sgd += gd(x) / (1 - g(x)) * (qjf - qif);
}
}
sgds[f] = sgd;
}
// compute sgds for items rated by user u
Map<Integer, List<Double>> itemSgds = new HashMap<>();
for (int j = 0; j < numItems; j++) {
double fuj = predict(u, j);
List<Double> jSgds = new ArrayList<>();
for (int f = 0; f < numFactors; f++) {
double puf = P.get(u, f);
double qjf = Q.get(j, f);
double yuj = uv.contains(j) ? 1.0 : 0.0;
double sgd = yuj * g(-fuj) * puf - regI * qjf;
for (int k : items) {
if (k == j)
continue;
double fuk = predict(u, k);
double x = fuk - fuj;
sgd += gd(-x) * (1.0 / (1 - g(x)) - 1.0 / (1 - g(-x))) * puf;
double qkf = Q.get(k, f);
double sji = DenseMatrix.rowMult(Q, j, Q, k);
double sgd_d = 2 * (1 - sji) * (qjf - qkf) - qkf * Math.pow(qjf - qkf, 2);
sgd += 0.5 * alpha * sgd_d / w;
}
jSgds.add(sgd);
}
itemSgds.put(j, jSgds);
}
// update factors
for (int f = 0; f < numFactors; f++)
P.add(u, f, lRate * sgds[f]);
for (int j = 0; j < numItems; j++) {
List<Double> jSgds = itemSgds.get(j);
for (int f = 0; f < numFactors; f++)
Q.add(j, f, lRate * jSgds.get(f));
}
// compute loss
for (int j = 0; j < numItems; j++) {
if (uv.contains(j)) {
double fuj = predict(u, j);
double ruj = trainMatrix.get(u, j);
errs += (ruj - fuj) * (ruj - fuj);
loss += Math.log(g(fuj));
for (int i : items) {
double fui = predict(u, i);
loss += Math.log(1 - g(fui - fuj));
double sji = DenseMatrix.rowMult(Q, j, Q, i);
double sum = 0;
for (int f = 0; f < numFactors; f++)
sum += Math.pow(Q.get(j, f) - Q.get(i, f), 2);
loss += 0.5 * alpha * (1 - sji) * sum / w;
}
}
for (int f = 0; f < numFactors; f++) {
double puf = P.get(u, f);
double qjf = Q.get(j, f);
loss += -0.5 * (regU * puf * puf + regI * qjf * qjf);
}
}
}
errs *= 0.5;
if (isConverged(iter))
break;
}// end of training
}
@Override
public String toString() {
return super.toString() + "," + (float) alpha;
}
}