// Copyright (C) 2014 Guibing Guo // // This file is part of LibRec. // // LibRec is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // LibRec is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with LibRec. If not, see <http://www.gnu.org/licenses/>. // package librec.undefined; import librec.data.DenseMatrix; import librec.data.DenseVector; import librec.data.MatrixEntry; import librec.data.SparseMatrix; import librec.data.SparseVector; import librec.intf.SocialRecommender; /** * Our ongoing testing algorithm * * @author guoguibing * */ public class TrustSVD2 extends SocialRecommender { private DenseMatrix W, Y; private DenseVector wlr_j, wlr_tc, wlr_tr; public TrustSVD2(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) { super(trainMatrix, testMatrix, fold); algoName = "TrustSVD++"; if (params.containsKey("val.reg")) { double reg = RecUtils.getMKey(params, "val.reg"); regU = (float) reg; regI = (float) reg; regS = (float)reg; } else { regS = (float)RecUtils.getMKey(params, "val.reg.social"); } } @Override protected void initModel() throws Exception { super.initModel(); userBias = new DenseVector(numUsers); itemBias = new DenseVector(numItems); W = new DenseMatrix(numUsers, numFactors); Y = new DenseMatrix(numItems, numFactors); if (initByNorm) { userBias.init(initMean, initStd); itemBias.init(initMean, initStd); W.init(initMean, initStd); Y.init(initMean, initStd); } else { userBias.init(); itemBias.init(); W.init(); Y.init(); } wlr_tc = new DenseVector(numUsers); wlr_tr = new DenseVector(numUsers); wlr_j = new DenseVector(numItems); for (int u = 0; u < numUsers; u++) { int count = socialMatrix.columnSize(u); wlr_tc.set(u, count > 0 ? 1.0 / Math.sqrt(count) : 1.0); count = socialMatrix.rowSize(u); wlr_tr.set(u, count > 0 ? 1.0 / Math.sqrt(count) : 1.0); } for (int j = 0; j < numItems; j++) { int count = trainMatrix.columnSize(j); wlr_j.set(j, count > 0 ? 1.0 / Math.sqrt(count) : 1.0); } } protected void buildModel() throws Exception { for (int iter = 1; iter <= numIters; iter++) { loss = 0; errs = 0; DenseMatrix PS = new DenseMatrix(numUsers, numFactors); DenseMatrix WS = new DenseMatrix(numUsers, numFactors); for (MatrixEntry me : trainMatrix) { int u = me.row(); // user int j = me.column(); // item double ruj = me.get(); if (ruj <= 0.0) continue; //double pred = predict(u, j); // To speed up, directly access the prediction double bu = userBias.get(u); double bj = itemBias.get(j); double pred = globalMean + bu + bj + DenseMatrix.rowMult(P, u, Q, j); // Y SparseVector uv = trainMatrix.row(u); int[] nu = uv.getIndex(); if (uv.getCount() > 0) { double sum = 0; for (int i : nu) sum += DenseMatrix.rowMult(Y, i, Q, j); pred += sum / Math.sqrt(uv.getCount()); } // W SparseVector tr = socialMatrix.row(u); int[] tu = tr.getIndex(); if (tr.getCount() > 0) { double sum = 0.0; for (int v : tu) sum += DenseMatrix.rowMult(W, v, Q, j); pred += sum / Math.sqrt(tr.getCount()); } double euj = pred - ruj; errs += euj * euj; loss += euj * euj; double w_nu = Math.sqrt(nu.length); double w_tu = Math.sqrt(tu.length); // update factors double reg_u = 1.0 / w_nu; double reg_j = wlr_j.get(j); double sgd = euj + regU * reg_u * bu; userBias.add(u, -lRate * sgd); sgd = euj + regI * reg_j * bj; itemBias.add(j, -lRate * sgd); loss += regU * reg_u * bu * bu; loss += regI * reg_j * bj * bj; double[] sum_ys = new double[numFactors]; for (int f = 0; f < numFactors; f++) { double sum = 0; for (int i : nu) sum += Y.get(i, f); sum_ys[f] = w_nu > 0 ? sum / w_nu : sum; } double[] sum_ts = new double[numFactors]; for (int f = 0; f < numFactors; f++) { double sum = 0; for (int v : tu) sum += W.get(v, f); sum_ts[f] = w_tu > 0 ? sum / w_tu : sum; } for (int f = 0; f < numFactors; f++) { double puf = P.get(u, f); double qjf = Q.get(j, f); double delta_u = euj * qjf + regU * reg_u * puf; double delta_j = euj * (puf + sum_ys[f] + sum_ts[f]) + regI * reg_j * qjf; PS.add(u, f, delta_u); Q.add(j, f, -lRate * delta_j); loss += regU * reg_u * puf * puf + regI * reg_j * qjf * qjf; for (int i : nu) { double yif = Y.get(i, f); double reg_yi = wlr_j.get(i); double delta_y = euj * qjf / w_nu + regI * reg_yi * yif; Y.add(i, f, -lRate * delta_y); loss += regI * reg_yi * yif * yif; } // update wvf for (int v : tu) { double tvf = W.get(v, f); double reg_v = wlr_tc.get(v); double delta_t = euj * qjf / w_tu + regU * reg_v * tvf; WS.add(v, f, delta_t); loss += regU * reg_v * tvf * tvf; } } } for (MatrixEntry me : socialMatrix) { int u = me.row(); int v = me.column(); double tuv = me.get(); if (tuv <= 0) continue; double pred = DenseMatrix.rowMult(P, u, W, v); double eut = pred - tuv; loss += regS * eut * eut; double csgd = regS * eut; double reg_u = wlr_tr.get(u); for (int f = 0; f < numFactors; f++) { double puf = P.get(u, f); double wvf = W.get(v, f); PS.add(u, f, csgd * wvf + regS * reg_u * puf); WS.add(v, f, csgd * puf); loss += regS * reg_u * puf * puf; } } P = P.add(PS.scale(-lRate)); W = W.add(WS.scale(-lRate)); errs *= 0.5; loss *= 0.5; if (isConverged(iter)) break; }// end of training } @Override protected double predict(int u, int j) { double pred = globalMean + userBias.get(u) + itemBias.get(j) + DenseMatrix.rowMult(P, u, Q, j); // Y SparseVector uv = trainMatrix.row(u); if (uv.getCount() > 0) { double sum = 0; for (int i : uv.getIndex()) sum += DenseMatrix.rowMult(Y, i, Q, j); pred += sum / Math.sqrt(uv.getCount()); } // W SparseVector tr = socialMatrix.row(u); if (tr.getCount() > 0) { double sum = 0.0; for (int v : tr.getIndex()) sum += DenseMatrix.rowMult(W, v, Q, j); pred += sum / Math.sqrt(tr.getCount()); } return pred; } }