// Copyright (C) 2014 Guibing Guo // // This file is part of LibRec. // // LibRec is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // LibRec is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with LibRec. If not, see <http://www.gnu.org/licenses/>. // package librec.undefined; import librec.data.DenseMatrix; import librec.data.DenseVector; import librec.data.MatrixEntry; import librec.data.SparseMatrix; import librec.data.SparseVector; import librec.intf.SocialRecommender; /** * Our ongoing testing algorithm * * @author guoguibing * */ public class TrustSVD_DT extends SocialRecommender { private DenseMatrix W, Y, F; private DenseVector wlr_j, wlr_tc, wlr_tr, wlr_dtc, wlr_dtr; private static double reg_dt = 0, neg = 0.05; private static SparseMatrix T, DT; static { T = socialMatrix.clone(); DT = socialMatrix.clone(); for (MatrixEntry me : T) { double trust = me.get(); if (trust < 0) me.set(0.0); } for (MatrixEntry me : DT) { double distrust = me.get(); if (distrust > 0) me.set(0.0); else me.set(-distrust); } } public TrustSVD_DT(SparseMatrix trainMatrix, SparseMatrix testMatrix, int fold) { super(trainMatrix, testMatrix, fold); if (params.containsKey("val.reg")) { double reg = RecUtils.getMKey(params, "val.reg"); regU = (float) reg; regI = (float) reg; regS = (float)reg; } else { neg = RecUtils.getMKey(params, "val.reg.neg"); reg_dt = cf.getDouble("val.reg.distrust"); regS = cf.getFloat("val.reg.social"); } } @Override protected void initModel() throws Exception { super.initModel(); userBias = new DenseVector(numUsers); itemBias = new DenseVector(numItems); W = new DenseMatrix(numUsers, numFactors); F = new DenseMatrix(numUsers, numFactors); Y = new DenseMatrix(numItems, numFactors); if (initByNorm) { userBias.init(initMean, initStd); itemBias.init(initMean, initStd); W.init(initMean, initStd); F.init(initMean, initStd); Y.init(initMean, initStd); } else { userBias.init(); itemBias.init(); W.init(); F.init(); Y.init(); } wlr_tc = new DenseVector(numUsers); wlr_tr = new DenseVector(numUsers); wlr_dtc = new DenseVector(numUsers); wlr_dtr = new DenseVector(numUsers); wlr_j = new DenseVector(numItems); for (int u = 0; u < numUsers; u++) { int count = T.columnSize(u); wlr_tc.set(u, count > 0 ? 1.0 / Math.sqrt(count) : 1.0); count = T.rowSize(u); wlr_tr.set(u, count > 0 ? 1.0 / Math.sqrt(count) : 1.0); count = DT.columnSize(u); wlr_dtc.set(u, count > 0 ? 1.0 / Math.sqrt(count) : 1.0); count = DT.rowSize(u); wlr_dtr.set(u, count > 0 ? 1.0 / Math.sqrt(count) : 1.0); } for (int j = 0; j < numItems; j++) { int count = trainMatrix.columnSize(j); wlr_j.set(j, count > 0 ? 1.0 / Math.sqrt(count) : 1.0); } } protected void buildModel() throws Exception { for (int iter = 1; iter <= numIters; iter++) { loss = 0; errs = 0; DenseMatrix PS = new DenseMatrix(numUsers, numFactors); DenseMatrix WS = new DenseMatrix(numUsers, numFactors); DenseMatrix FS = new DenseMatrix(numUsers, numFactors); for (MatrixEntry me : trainMatrix) { int u = me.row(); // user int j = me.column(); // item double ruj = me.get(); if (ruj <= 0.0) continue; double bu = userBias.get(u); double bj = itemBias.get(j); double pred = globalMean + bu + bj + DenseMatrix.rowMult(P, u, Q, j); // Y SparseVector uv = trainMatrix.row(u); int[] nu = uv.getIndex(); if (uv.getCount() > 0) { double sum = 0; for (int i : nu) sum += DenseMatrix.rowMult(Y, i, Q, j); pred += sum / Math.sqrt(uv.getCount()); } // T SparseVector tr = T.row(u); int[] tu = tr.getIndex(); if (tr.getCount() > 0) { double sum = 0.0; for (int v : tu) sum += DenseMatrix.rowMult(W, v, Q, j); pred += sum / Math.sqrt(tr.getCount()); } // DT SparseVector dtr = DT.row(u); int[] dtu = dtr.getIndex(); if (dtr.getCount() > 0) { double sum = 0.0; for (int k : dtu) sum += DenseMatrix.rowMult(F, k, Q, j); pred += sum / Math.sqrt(dtr.getCount()) * neg; } double euj = pred - ruj; errs += euj * euj; loss += euj * euj; double w_nu = Math.sqrt(nu.length); double w_tu = Math.sqrt(tu.length); double w_dtu = Math.sqrt(dtu.length); // update factors double reg_u = 1.0 / w_nu; double reg_j = wlr_j.get(j); double sgd = euj + regU * reg_u * bu; userBias.add(u, -lRate * sgd); sgd = euj + regI * reg_j * bj; itemBias.add(j, -lRate * sgd); loss += regU * reg_u * bu * bu; loss += regI * reg_j * bj * bj; double[] sum_ys = new double[numFactors]; for (int f = 0; f < numFactors; f++) { double sum = 0; for (int i : nu) sum += Y.get(i, f); sum_ys[f] = w_nu > 0 ? sum / w_nu : sum; } double[] sum_ts = new double[numFactors]; for (int f = 0; f < numFactors; f++) { double sum = 0; for (int v : tu) sum += W.get(v, f); sum_ts[f] = w_tu > 0 ? sum / w_tu : sum; } double[] sum_dts = new double[numFactors]; for (int f = 0; f < numFactors; f++) { double sum = 0; for (int v : dtu) sum += F.get(v, f); sum_dts[f] = w_dtu > 0 ? sum / w_dtu : sum; } for (int f = 0; f < numFactors; f++) { double puf = P.get(u, f); double qjf = Q.get(j, f); double delta_u = euj * qjf + regU * reg_u * puf; double delta_j = euj * (puf + sum_ys[f] + sum_ts[f] + neg * sum_dts[f]) + regI * reg_j * qjf; PS.add(u, f, delta_u); Q.add(j, f, -lRate * delta_j); loss += regU * reg_u * puf * puf + regI * reg_j * qjf * qjf; for (int i : nu) { double yif = Y.get(i, f); double reg_yi = wlr_j.get(i); double delta_y = euj * qjf / w_nu + regI * reg_yi * yif; Y.add(i, f, -lRate * delta_y); loss += regI * reg_yi * yif * yif; } // update Wvf for (int v : tu) { double tvf = W.get(v, f); double reg_v = wlr_tc.get(v); double delta_t = euj * qjf / w_tu + regU * reg_v * tvf; WS.add(v, f, delta_t); loss += regU * reg_v * tvf * tvf; } // update Fkf for (int k : dtu) { double tkf = F.get(k, f); double reg_k = wlr_dtc.get(k); double delta_t = neg * euj * qjf / w_dtu + regU * reg_k * tkf; FS.add(k, f, delta_t); loss += regU * reg_k * tkf * tkf; } } } for (MatrixEntry me : T) { int u = me.row(); int v = me.column(); double tuv = me.get(); if (tuv == 0) continue; double pred = DenseMatrix.rowMult(P, u, W, v); double euv = pred - tuv; loss += regS * euv * euv; double csgd = regS * euv; double reg_u = wlr_tr.get(u); for (int f = 0; f < numFactors; f++) { double puf = P.get(u, f); double wvf = W.get(v, f); PS.add(u, f, csgd * wvf + regS * reg_u * puf); WS.add(v, f, csgd * puf); loss += regS * reg_u * puf * puf; } } for (MatrixEntry me : DT) { int u = me.row(); int k = me.column(); double duk = me.get(); if (duk == 0) continue; double pred = DenseMatrix.rowMult(P, u, F, k); double euk = pred - duk; loss += reg_dt * euk * euk; double csgd = reg_dt * euk; double reg_u = wlr_dtr.get(u); for (int f = 0; f < numFactors; f++) { double puf = P.get(u, f); double fkf = F.get(k, f); PS.add(u, f, csgd * fkf + reg_dt * reg_u * puf); FS.add(k, f, csgd * puf); loss += reg_dt * reg_u * puf * puf; } } P = P.add(PS.scale(-lRate)); W = W.add(WS.scale(-lRate)); F = F.add(FS.scale(-lRate)); errs *= 0.5; loss *= 0.5; if (isConverged(iter)) break; }// end of training } @Override protected double predict(int u, int j) { double pred = globalMean + userBias.get(u) + itemBias.get(j) + DenseMatrix.rowMult(P, u, Q, j); // Y SparseVector uv = trainMatrix.row(u); if (uv.getCount() > 0) { double sum = 0; for (int i : uv.getIndex()) sum += DenseMatrix.rowMult(Y, i, Q, j); pred += sum / Math.sqrt(uv.getCount()); } // T SparseVector tr = T.row(u); if (tr.getCount() > 0) { double sum = 0.0; for (int v : tr.getIndex()) sum += DenseMatrix.rowMult(W, v, Q, j); pred += sum / Math.sqrt(tr.getCount()); } // DT SparseVector dtr = DT.row(u); if (dtr.getCount() > 0) { double sum = 0.0; for (int k : dtr.getIndex()) sum += DenseMatrix.rowMult(F, k, Q, j); pred += sum / Math.sqrt(dtr.getCount()) * neg; } return pred; } @Override public String toString() { return super.toString() + "," + reg_dt + "," + neg; } }