// Copyright (C) 2014 Guibing Guo // // This file is part of LibRec. // // LibRec is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // LibRec is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with LibRec. If not, see <http://www.gnu.org/licenses/>. // package librec.undefined; import java.util.ArrayList; import java.util.List; import librec.data.DenseMatrix; import librec.data.MatrixEntry; import librec.data.SparseMatrix; import librec.data.SparseVector; public class DMF extends BaseMF { // diversity parameter private float alpha; public DMF(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) { super(trainMatrix, testMatrix, fold); algoName = "DMF"; alpha = RecUtils.getMKey(params, "val.diverse.alpha"); } @Override protected void buildModel() { for (int iter = 1; iter <= numIters; iter++) { loss = 0; errs = 0; for (MatrixEntry me : trainMatrix) { int u = me.row(); // user int j = me.column(); // item double ruj = me.get(); if (ruj <= 0.0) continue; double pred = predict(u, j); double euj = ruj - pred; errs += euj * euj; loss += euj * euj; // update bias factors double bu = userBias.get(u); double sgd = euj - regU * bu; userBias.add(u, lRate * sgd); loss += regU * bu * bu; double bj = itemBias.get(j); sgd = euj - regI * bj; itemBias.add(j, lRate * sgd); loss += regI * bj * bj; // rated items by user u SparseVector uv = trainMatrix.row(u, j); List<Integer> items = new ArrayList<>(); for (int i : uv.getIndex()) { if (i != j) { double sji = DenseMatrix.rowMult(P, j, Q, i); if (sji > minSim) items.add(i); } } double w = Math.sqrt(items.size()); // compute P's gradients double[] sgds = new double[numFactors]; for (int f = 0; f < numFactors; f++) { double pjf = P.get(j, f); sgds[f] = -regU * pjf; double sum_q = 0.0, sum_s = 0.0; for (int i : items) { double qif = Q.get(i, f); double pif = P.get(i, f); sum_q += qif; double sji = DenseMatrix.rowMult(P, j, Q, i); sum_s += 2 * (1 - sji) * (pjf - pif) - qif * Math.pow(pjf - pif, 2); } if (w > 0) sgds[f] += euj * (sum_q / w) + 0.5 * alpha * (sum_s / w); loss += regU * pjf * pjf; } // update Q's factors for (int i : items) { for (int f = 0; f < numFactors; f++) { double pjf = P.get(j, f); double qif = Q.get(i, f); sgd = euj * pjf - regI * qif; sgd += -0.5 * alpha * pjf * Math.pow(pjf - P.get(i, f), 2); Q.add(i, f, lRate * sgd); loss += regI * qif * qif; } } // update P's factors for (int f = 0; f < numFactors; f++) P.add(j, f, lRate * sgds[f]); } errs *= 0.5; loss *= 0.5; if (isConverged(iter)) break; }// end of training } @Override public String toString() { return super.toString() + "," + (float) alpha; } }