// Copyright (C) 2014 Guibing Guo
//
// This file is part of LibRec.
//
// LibRec is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// LibRec is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with LibRec. If not, see <http://www.gnu.org/licenses/>.
//
package librec.undefined;
import java.util.ArrayList;
import java.util.List;
import librec.data.DenseMatrix;
import librec.data.DenseVector;
import librec.data.MatrixEntry;
import librec.data.SparseMatrix;
import librec.data.SparseVector;
import librec.intf.IterativeRecommender;
public class BaseMF extends IterativeRecommender {
protected boolean isPosOnly;
protected double minSim;
public BaseMF(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) {
super(trainMatrix, testMatrix, fold);
algoName = "BaseMF";
isPosOnly = cf.isOn("is.similarity.pos");
minSim = isPosOnly ? 0.0 : Double.NEGATIVE_INFINITY;
}
@Override
protected void initModel() {
// re-use it as another item-factor matrix
P = new DenseMatrix(numItems, numFactors);
Q = new DenseMatrix(numItems, numFactors);
// initialize model
if (isPosOnly) {
P.init(0.01);
Q.init(0.01);
} else {
P.init(initMean, initStd);
Q.init(initMean, initStd);
}
// set to 0 for items without any ratings
for (int j = 0, jm = numItems; j < jm; j++) {
if (trainMatrix.columnSize(j) == 0) {
P.setRow(j, 0.0);
Q.setRow(j, 0.0);
}
}
userBias = new DenseVector(numUsers);
itemBias = new DenseVector(numItems);
// initialize user bias
userBias.init(initMean, initStd);
itemBias.init(initMean, initStd);
}
@Override
protected void buildModel() {
for (int iter = 1; iter <= numIters; iter++) {
loss = 0;
errs = 0;
for (MatrixEntry me : trainMatrix) {
int u = me.row(); // user
int j = me.column(); // item
double ruj = me.get(); // rate
if (ruj <= 0.0)
continue;
double pred = predict(u, j);
double euj = ruj - pred;
errs += euj * euj;
loss += euj * euj;
// update bias factors
double bu = userBias.get(u);
double sgd = euj - regU * bu;
userBias.add(u, lRate * sgd);
loss += regU * bu * bu;
double bj = itemBias.get(j);
sgd = euj - regI * bj;
itemBias.add(j, lRate * sgd);
loss += regI * bj * bj;
// rated items by user u
SparseVector uv = trainMatrix.row(u, j);
List<Integer> items = new ArrayList<>();
for (int i : uv.getIndex()) {
if (i != j) {
double sji = DenseMatrix.rowMult(P, j, Q, i);
if (sji > minSim)
items.add(i);
}
}
double w = Math.sqrt(items.size());
// compute P's gradients
double[] sgds = new double[numFactors];
for (int f = 0; f < numFactors; f++) {
double pjf = P.get(j, f);
double sum = 0.0;
for (int i : items)
sum += Q.get(i, f);
sgds[f] = euj * (w > 0.0 ? sum / w : 0.0) - regU * pjf;
loss += regU * pjf * pjf;
}
// update Q's factors
for (int i : items) {
for (int f = 0; f < numFactors; f++) {
double pjf = P.get(j, f);
double qif = Q.get(i, f);
sgd = euj * pjf - regI * qif;
Q.add(i, f, lRate * sgd);
loss += regI * qif * qif;
}
}
// update P's factors
for (int f = 0; f < numFactors; f++)
P.add(j, f, lRate * sgds[f]);
}
errs *= 0.5;
loss *= 0.5;
if (isConverged(iter))
break;
}// end of training
}
@Override
protected double predict(int u, int j) {
double pred = userBias.get(u) + itemBias.get(j);
int k = 0;
double sum = 0.0f;
SparseVector uv = trainMatrix.row(u);
for (int i : uv.getIndex()) {
if (i != j) {
double sji = DenseMatrix.rowMult(P, j, Q, i);
if (sji > minSim) {
sum += sji;
k++;
}
}
}
if (k > 0)
pred += sum / Math.sqrt(k);
return pred;
}
@Override
public String toString() {
return super.toString() + "," + isPosOnly;
}
}