// Copyright (C) 2014 Guibing Guo
//
// This file is part of LibRec.
//
// LibRec is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// LibRec is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with LibRec. If not, see <http://www.gnu.org/licenses/>.
//
package librec.undefined;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectOutputStream;
import java.nio.charset.Charset;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
/**
 * Trainer for a time-aware SVD++ rating model over Netflix-Prize-style input.
 *
 * <p>Input is read from {@link #INPUT_DATA} (one {@code user movie date rating} row per line,
 * 1-based ids that are shifted to 0-based on load) and {@link #INPUT_INDEX} (one set number per
 * line: 1-3 training, 4 probe, 5 qualifying). The predictor ({@link #predictRating}) sums a
 * global mean; static, drifting and per-day user biases; a time-binned movie bias scaled by a
 * per-user (and per-day) factor; a rating-frequency bias; and the dot product of time-dependent
 * user factors plus normalized implicit feedback ({@code mw}) with the movie factors.
 *
 * <p>All per-user loops assume each user's rows are contiguous in the loaded arrays (TODO:
 * confirm the input file is grouped by user). No synchronization is used, so an instance should
 * not be shared across threads.
 */
public class TimeSVD {
  // Initial learning rates; initializeVars() copies them into the LRATE_* fields below.
  public double LRATE_UF_INITIAL = 0.007;
  public double LRATE_MF_INITIAL = 0.007;
  public double LRATE_MW_INITIAL = 1e-3;
  public double LRATE_UB_INITIAL = 3e-3;
  public double LRATE_MB_INITIAL = 2e-3;
  public double LRATE_UDB_INITIAL = 2.5e-3;
  public double LRATE_UBA_INITIAL = 1e-5;
  public double LRATE_MBB_INITIAL = 5e-5;
  public double LRATE_MBS_INITIAL = 5.108e-3;
  public double LRATE_MBDS_INITIAL = 2e-3;
  public double LRATE_UFA_INITIAL = 1e-5;
  public double LRATE_MFB_INITIAL = 2e-3;
  // Regularization strengths, one per parameter group (same suffixes as the learning rates).
  public double K_UB = 0.01;
  public double K_MB = 3.454e-2;
  public double K_UDB = 5e-3;
  public double K_UBA = 50;
  public double K_MBB = 0.1;
  public double K_MBS = 0.01;
  public double K_MBDS = 1.56e-3;
  public double K_UFA = 50;
  public double K_MFB = 1e-3;
  public double K_UF = 8.223e-2;
  public double K_MF = 8.610e-3;
  public double K_MW = 0.05;
  // Movie-bias time bins: bin = date / MOVIE_BIN_SIZE, giving NUM_MOVIE_BINS bins over all dates.
  public int NUM_MOVIE_BINS = 30;
  public int MOVIE_BIN_SIZE = NUM_DATES / NUM_MOVIE_BINS + 1;
  // Exponent of the user date-deviation term: sign(dev) * |dev|^BETA.
  public double BETA = 0.4;
  // Per-day rating counts are bucketed into tiers floor(log_a(freq)).
  public double a = 6.76;
  public double log_a = Math.log(a);
  // Not read in this class; with a = 6.76 this frequency would land in tier 4 (< NUM_A_TIERS).
  public final int HIGHEST_DAY_FREQ = 2651;
  public final int NUM_A_TIERS = 10;
  public Random rand = new Random(0);
  // Current learning rates (reset from the *_INITIAL values, some decayed per epoch).
  public double LRATE_UF;
  public double LRATE_MF;
  public double LRATE_MW;
  public double LRATE_UB;
  public double LRATE_MB;
  public double LRATE_UDB;
  public double LRATE_UBA;
  public double LRATE_MBB;
  public double LRATE_MBS;
  public double LRATE_MBDS;
  public double LRATE_UFA;
  public double LRATE_MFB;
  // Model parameters and per-user precomputed statistics.
  public double[][] userFeatures;
  public double[][] movieFeatures;
  public double[] userBias;
  public double[] movieBias;
  public double[][] mw; // implicit-feedback vector per movie
  public double[][] sum_mw; // per user: sum of mw over the movies that user rated
  public double[] userRatingCount;
  public double[] norm; // per user: 1 / sqrt(rating count), or 1 when the user has no ratings
  public double[][] movieBiasBins;
  public double[] userBiasAlpha;
  public double[] dateMean;
  public ArrayList<Double>[] userDateBias; // Holds dates weights
  public ArrayList<Integer>[] userDateBiasIndex; // Holds dates
  public double[] movieBiasScale;
  public ArrayList<Double>[] movieBiasDateScale;
  public ArrayList<Integer>[] movieBiasDateScaleIndex;
  public double[][] userFeaturesAlpha;
  public double[][] movieFrequencyBias;
  public static final int NUM_EPOCHS_SPAN = 5;
  public static final double MIN_ERROR_DIFF = 0.00001;
  // Overwritten by precompute() with the mean of the loaded training ratings.
  public double GLOBAL_MEAN = 3.6033;
  // Reference RMSE used by logRmse() to report a percentage improvement.
  public static final double CINEMATCH_BASELINE = 0.9514;
  public static final int NUM_USERS = 458293;
  public static final int NUM_MOVIES = 17770;
  public static final int NUM_DATES = 2243;
  public static final int NUM_POINTS = 102416306;
  public static final int NUM_1_POINTS = 94362233;
  public static final int NUM_2_POINTS = 1965045;
  public static final int NUM_3_POINTS = 1964391;
  public static final int NUM_4_POINTS = 1374739;
  public static final int NUM_5_POINTS = 2749898;
  public static final int NUM_TRAINING_POINTS = NUM_1_POINTS + NUM_2_POINTS
      + NUM_3_POINTS;
  public static final int NUM_TRAINING_PROBE_POINTS = NUM_1_POINTS
      + NUM_2_POINTS + NUM_3_POINTS + NUM_4_POINTS;
  public static final String INPUT_DATA = "all.dta";
  public static final String INPUT_INDEX = "all.idx";
  public static final String INPUT_QUAL = "qual.dta";
  public static final String LOGFILE = "log.txt";
  public int NUM_FEATURES;
  public int TEST_PARAM;
  // Data that is stored in memory; the parallel training arrays are grouped by user.
  public int[] users = new int[NUM_TRAINING_PROBE_POINTS]; // 0-based user ids
  public int[] userIndex = new int[NUM_USERS]; // first row of each user in users[]
  public short[] movies = new short[NUM_TRAINING_PROBE_POINTS]; // 0-based movie ids
  public short[] dates = new short[NUM_TRAINING_PROBE_POINTS]; // 0-based day numbers
  public byte[] ratings = new byte[NUM_TRAINING_PROBE_POINTS];
  // Probe rows: {user, movie, date, rating}.
  public int[][] probeData = new int[NUM_4_POINTS][4];
  // Qualifying rows: {user, movie, date}.
  public int[][] qualData = new int[NUM_5_POINTS][3];

  /**
   * Creates a trainer with default hyper-parameters.
   *
   * @param numFeatures number of latent factors per user/movie
   */
  public TimeSVD(int numFeatures) {
    this.NUM_FEATURES = numFeatures;
    initializeVars();
  }

  /**
   * Creates a trainer with default hyper-parameters and records an extra test parameter.
   *
   * @param numFeatures number of latent factors per user/movie
   * @param testParam value stored in {@link #TEST_PARAM} (previously discarded)
   */
  public TimeSVD(int numFeatures, int testParam) {
    this(numFeatures);
    this.TEST_PARAM = testParam;
  }

  /**
   * Creates a trainer with the main SVD++ hyper-parameters overridden.
   *
   * <p>Previously only {@code LRATE_MW} was honored. Now the bias settings apply to the static
   * user/movie biases, the feature settings apply to the user/movie factors, and {@code K_MW}
   * applies to the implicit-feedback weights. The time-dependent terms keep their defaults.
   *
   * @param numFeatures number of latent factors per user/movie
   * @param LRATE_BIAS initial learning rate for userBias and movieBias
   * @param LRATE_FEATURES initial learning rate for userFeatures and movieFeatures
   * @param LRATE_MW initial learning rate for mw
   * @param K_BIAS regularization for userBias and movieBias
   * @param K_FEATURES regularization for userFeatures and movieFeatures
   * @param K_MW regularization for mw
   */
  public TimeSVD(int numFeatures, double LRATE_BIAS, double LRATE_FEATURES,
      double LRATE_MW, double K_BIAS, double K_FEATURES, double K_MW) {
    this.NUM_FEATURES = numFeatures;
    this.LRATE_UB_INITIAL = LRATE_BIAS;
    this.LRATE_MB_INITIAL = LRATE_BIAS;
    this.LRATE_UF_INITIAL = LRATE_FEATURES;
    this.LRATE_MF_INITIAL = LRATE_FEATURES;
    this.LRATE_MW_INITIAL = LRATE_MW;
    this.K_UB = K_BIAS;
    this.K_MB = K_BIAS;
    this.K_UF = K_FEATURES;
    this.K_MF = K_FEATURES;
    this.K_MW = K_MW;
    // Must run after the *_INITIAL fields are set: it copies them into the current rates.
    initializeVars();
  }

  /**
   * Allocates all model parameters, resets learning rates to their initial values and seeds the
   * factor matrices with small random values in (-0.01, 0.01) from a fixed seed.
   */
  @SuppressWarnings("unchecked")
  private void initializeVars() {
    userFeatures = new double[NUM_USERS][NUM_FEATURES];
    movieFeatures = new double[NUM_MOVIES][NUM_FEATURES];
    userBias = new double[NUM_USERS];
    movieBias = new double[NUM_MOVIES];
    // mw is indexed by movie everywhere (mw[movie][k]); it used to be sized by NUM_USERS.
    mw = new double[NUM_MOVIES][NUM_FEATURES];
    sum_mw = new double[NUM_USERS][NUM_FEATURES];
    userRatingCount = new double[NUM_USERS];
    norm = new double[NUM_USERS];
    movieBiasBins = new double[NUM_MOVIES][NUM_MOVIE_BINS];
    userBiasAlpha = new double[NUM_USERS];
    dateMean = new double[NUM_USERS];
    userDateBias = new ArrayList[NUM_USERS];
    userDateBiasIndex = new ArrayList[NUM_USERS];
    movieBiasScale = new double[NUM_USERS];
    movieBiasDateScale = new ArrayList[NUM_USERS];
    movieBiasDateScaleIndex = new ArrayList[NUM_USERS];
    userFeaturesAlpha = new double[NUM_USERS][NUM_FEATURES];
    movieFrequencyBias = new double[NUM_MOVIES][NUM_A_TIERS];

    LRATE_UF = LRATE_UF_INITIAL;
    LRATE_MF = LRATE_MF_INITIAL;
    LRATE_MW = LRATE_MW_INITIAL;
    LRATE_UB = LRATE_UB_INITIAL;
    LRATE_MB = LRATE_MB_INITIAL;
    LRATE_UDB = LRATE_UDB_INITIAL;
    LRATE_UBA = LRATE_UBA_INITIAL;
    LRATE_MBB = LRATE_MBB_INITIAL;
    LRATE_MBS = LRATE_MBS_INITIAL;
    LRATE_MBDS = LRATE_MBDS_INITIAL;
    LRATE_UFA = LRATE_UFA_INITIAL;
    LRATE_MFB = LRATE_MFB_INITIAL;

    // Fixed seed so both training phases start from identical weights.
    rand = new Random(0);
    for (double[] row : userFeatures) {
      for (int k = 0; k < row.length; k++) {
        row[k] = (rand.nextDouble() - 0.5) / 50;
      }
    }
    for (double[] row : movieFeatures) {
      for (int k = 0; k < row.length; k++) {
        row[k] = (rand.nextDouble() - 0.5) / 50;
      }
    }

    // Per-day terms are sparse: each user keeps parallel (day, value) lists.
    for (int u = 0; u < NUM_USERS; u++) {
      userDateBias[u] = new ArrayList<Double>();
      userDateBiasIndex[u] = new ArrayList<Integer>();
      movieBiasDateScale[u] = new ArrayList<Double>();
      movieBiasDateScaleIndex[u] = new ArrayList<Integer>();
    }
    // The multiplicative movie-bias scale c_u starts at 1 (no scaling).
    Arrays.fill(movieBiasScale, 1.0);
  }

  /** Drops references to all model arrays so they can be collected before re-allocation. */
  private void setVarsToNull() {
    userFeatures = null;
    movieFeatures = null;
    userBias = null;
    movieBias = null;
    mw = null;
    sum_mw = null;
    userRatingCount = null;
    norm = null;
    movieBiasBins = null;
    userBiasAlpha = null;
    dateMean = null;
    userDateBias = null;
    userDateBiasIndex = null;
    movieBiasScale = null;
    movieBiasDateScale = null;
    movieBiasDateScaleIndex = null;
    userFeaturesAlpha = null;
    movieFrequencyBias = null;
  }

  /**
   * Runs the full two-phase training.
   *
   * <p>Phase 1 trains on sets 1-3 until an epoch improves probe RMSE by less than
   * {@link #MIN_ERROR_DIFF}, writes predictions for that model and remembers the epoch count.
   * Phase 2 re-initializes, trains on sets 1-4 for that many epochs plus
   * {@link #NUM_EPOCHS_SPAN}, then writes final predictions and serialized parameters.
   *
   * @throws IOException if any input, log or output file cannot be read or written
   * @throws NumberFormatException if an input line is malformed
   */
  public void train() throws NumberFormatException, IOException {
    System.out.println(timestampLine(String.format("Training %d features.",
        NUM_FEATURES)));
    readInput();
    // try-with-resources: the log is flushed and closed even if training throws.
    try (BufferedWriter logWriter = new BufferedWriter(new FileWriter(LOGFILE, true))) {
      logWriter.write("\n");

      // Phase 1: train with the training set only (no probe).
      precompute(NUM_TRAINING_POINTS);
      double previousRmse = calcProbeRmse();
      logRmse(logWriter, previousRmse, 0);
      int numEpochsToTrain = 0;
      for (int i = 1; ; i++) {
        double rmse = trainWithNumPoints(NUM_TRAINING_POINTS);
        logRmse(logWriter, rmse, i);
        decayLearningRates();
        if (previousRmse - rmse < MIN_ERROR_DIFF) {
          System.out
              .println(timestampLine("Probe error has started"
                  + " to go up significantly; memorizing number of epochs to train."));
          generateProbeOutput();
          numEpochsToTrain = i;
          break;
        }
        previousRmse = rmse;
      }

      // Phase 2: start over and train with training set AND probe.
      setVarsToNull();
      initializeVars();
      precompute(NUM_TRAINING_PROBE_POINTS);
      logEpoch(logWriter, 0);
      int totalEpochs = numEpochsToTrain + NUM_EPOCHS_SPAN;
      for (int i = 1; i <= totalEpochs; i++) {
        trainWithNumPoints(NUM_TRAINING_PROBE_POINTS);
        logEpoch(logWriter, i);
        decayLearningRates();
      }
      generateOutput();
      saveBestParams();
    }
    System.out.println("Done!");
  }

  /** Slows down the factor learning rates as training gets close to the answer. */
  private void decayLearningRates() {
    LRATE_UF *= .9;
    LRATE_MF *= .9;
    LRATE_MW *= .9;
  }

  /**
   * Runs one SGD epoch over the first {@code numPoints} loaded rows, then rebuilds
   * {@link #sum_mw} from the updated {@link #mw}.
   *
   * @param numPoints number of leading rows of the training arrays to train on
   * @return probe RMSE after the epoch (see {@link #calcProbeRmse()})
   * @throws IOException declared for compatibility; no I/O is performed here
   */
  public double trainWithNumPoints(int numPoints) throws IOException {
    double[] tmpSum = new double[NUM_FEATURES];
    Map<Integer, Integer> dateToFreq = new HashMap<Integer, Integer>();
    int prevUser = -1;

    for (int j = 0; j < numPoints; j++) {
      int user = users[j];
      short movie = movies[j];
      short date = dates[j];
      byte rating = ratings[j];
      int binNum = date / MOVIE_BIN_SIZE;
      double dateDev = date - dateMean[user];
      double timeDev = Math.signum(dateDev) * Math.pow(Math.abs(dateDev), BETA);

      if (user != prevUser) {
        // First row of a new user: per-day rating counts (for f_ui), then reset the
        // implicit-feedback gradient accumulator and rebuild this user's sum of mw vectors.
        dateToFreq = countDatesForUser(users, dates, j, numPoints, user);
        Arrays.fill(tmpSum, 0);
        double[] userSum = sum_mw[user];
        Arrays.fill(userSum, 0);
        for (int l = j; l < numPoints && users[l] == user; l++) {
          double[] w = mw[movies[l]];
          for (int k = 0; k < NUM_FEATURES; k++) {
            userSum[k] += w[k];
          }
        }
      }
      prevUser = user;

      double err = rating - predictRating(movie, user, date, dateToFreq);

      // Snapshot pre-update values so all gradients below refer to the same model state.
      double bi = movieBias[movie];
      double biBin = movieBiasBins[movie][binNum];
      double cu = movieBiasScale[user];
      int cutIndex = movieBiasDateScaleIndex[user].indexOf((int) date);
      double cut = cutIndex == -1 ? 0.0 : movieBiasDateScale[user].get(cutIndex);

      // Static user bias.
      userBias[user] += LRATE_UB * (err - K_UB * userBias[user]);
      // Long-term user drift, weighted by the date deviation.
      userBiasAlpha[user] += LRATE_UBA
          * (err * timeDev - K_UBA * userBiasAlpha[user]);
      // Per-day user bias; an entry is created (starting from 0) the first time a day is seen.
      int udbIndex = userDateBiasIndex[user].indexOf((int) date);
      double udb = udbIndex == -1 ? 0.0 : userDateBias[user].get(udbIndex);
      double udbNew = udb + LRATE_UDB * (err - K_UDB * udb);
      if (udbIndex == -1) {
        userDateBiasIndex[user].add((int) date);
        userDateBias[user].add(udbNew);
      } else {
        userDateBias[user].set(udbIndex, udbNew);
      }

      // Movie bias and its time bin, both scaled by (c_u + c_u,t) in the prediction.
      movieBias[movie] += LRATE_MB * (err * (cu + cut) - K_MB * bi);
      movieBiasBins[movie][binNum] += LRATE_MBB
          * (err * (cu + cut) - K_MBB * biBin);
      // Scale c_u is regularized toward 1 rather than 0.
      movieBiasScale[user] += LRATE_MBS
          * (err * (bi + biBin) - K_MBS * (cu - 1));
      // Per-day scale c_u,t; created lazily like the per-day user bias.
      double cutNew = cut + LRATE_MBDS * (err * (bi + biBin) - K_MBDS * cut);
      if (cutIndex == -1) {
        movieBiasDateScaleIndex[user].add((int) date);
        movieBiasDateScale[user].add(cutNew);
      } else {
        movieBiasDateScale[user].set(cutIndex, cutNew);
      }

      // Rating-frequency bias: the same regularized step as every other parameter (the
      // regularization term was previously not multiplied by the learning rate).
      int fui = frequencyTier(dateToFreq, date);
      movieFrequencyBias[movie][fui] = sgdStep(movieFrequencyBias[movie][fui], LRATE_MFB,
          err, K_MFB);

      // Factors: update with pre-update values, and accumulate the mw gradient for later.
      double[] userRow = userFeatures[user];
      double[] movieRow = movieFeatures[movie];
      double[] alphaRow = userFeaturesAlpha[user];
      double[] sumRow = sum_mw[user];
      double userNorm = norm[user];
      for (int k = 0; k < NUM_FEATURES; k++) {
        double uf = userRow[k];
        double mf = movieRow[k];
        double ufa = alphaRow[k];
        userRow[k] += LRATE_UF * (err * mf - K_UF * uf);
        movieRow[k] += LRATE_MF
            * (err * (uf + ufa * timeDev + userNorm * sumRow[k]) - K_MF * mf);
        alphaRow[k] += LRATE_UFA * (err * mf * timeDev - K_UFA * ufa);
        tmpSum[k] += err * userNorm * mf;
      }

      // After the user's last row, apply the accumulated gradient to every movie they rated.
      if (j + 1 == numPoints || users[j + 1] != user) {
        for (int l = j; l >= 0 && users[l] == user; l--) {
          double[] w = mw[movies[l]];
          for (int k = 0; k < NUM_FEATURES; k++) {
            w[k] += LRATE_MW * (tmpSum[k] - K_MW * w[k]);
          }
        }
      }
    }

    // mw changed during the epoch, so recompute every user's sum from scratch.
    for (double[] row : sum_mw) {
      Arrays.fill(row, 0);
    }
    for (int j = 0; j < numPoints; j++) {
      double[] userSum = sum_mw[users[j]];
      double[] w = mw[movies[j]];
      for (int k = 0; k < NUM_FEATURES; k++) {
        userSum[k] += w[k];
      }
    }
    return calcProbeRmse();
  }

  /**
   * Computes per-user statistics (rating counts, norms, mean rating date, first-row index) and
   * the global mean over the first {@code numPoints} rows.
   *
   * <p>When {@code numPoints == NUM_TRAINING_PROBE_POINTS} the data file is re-read so that
   * sets 1-4 occupy the training arrays in file order. The per-user accumulators are expected
   * to be freshly allocated by {@link #initializeVars()}.
   *
   * @param numPoints number of leading rows of the training arrays to use
   * @throws IOException if the data files cannot be read
   * @throws NumberFormatException if an input line is malformed
   */
  public void precompute(int numPoints) throws NumberFormatException,
      IOException {
    if (numPoints == NUM_TRAINING_PROBE_POINTS) {
      loadData(true);
    }

    long ratingSum = 0;
    for (int i = 0; i < numPoints; i++) {
      ratingSum += ratings[i];
    }
    GLOBAL_MEAN = ((double) ratingSum) / numPoints;

    // One pass: first-row index, rating count and date sum per user.
    int prevUser = -1;
    for (int i = 0; i < numPoints; i++) {
      int user = users[i];
      if (user != prevUser) {
        userIndex[user] = i;
      }
      prevUser = user;
      userRatingCount[user]++;
      dateMean[user] += dates[i];
    }

    for (int u = 0; u < NUM_USERS; u++) {
      double count = userRatingCount[u];
      if (count == 0) {
        norm[u] = 1;
      } else {
        norm[u] = 1 / Math.sqrt(count);
        dateMean[u] /= count;
      }
    }
    System.out.println(timestampLine("Finished precomputation.\n"));
  }

  /**
   * Predicts the rating of {@code movie} by {@code user} on day {@code date}.
   *
   * @param movie 0-based movie id
   * @param user 0-based user id
   * @param date 0-based day number
   * @param dateToFreq this user's day -> number-of-ratings map; days absent count as 1
   * @return unclipped predicted rating
   */
  public double predictRating(int movie, int user, int date,
      Map<Integer, Integer> dateToFreq) {
    int binNum = date / MOVIE_BIN_SIZE;
    double dateDev = date - dateMean[user];
    double timeDev = Math.signum(dateDev) * Math.pow(Math.abs(dateDev), BETA);
    double udb = valueOnDay(userDateBiasIndex[user], userDateBias[user], date);
    double cut = valueOnDay(movieBiasDateScaleIndex[user], movieBiasDateScale[user], date);
    int fui = frequencyTier(dateToFreq, date);

    double ratingSum = GLOBAL_MEAN;
    // User biases: static, drifting, per-day.
    ratingSum += userBias[user];
    ratingSum += userBiasAlpha[user] * timeDev;
    ratingSum += udb;
    // Movie biases: (static + time bin) scaled by (c_u + c_u,t), plus frequency tier.
    ratingSum += (movieBias[movie] + movieBiasBins[movie][binNum])
        * (movieBiasScale[user] + cut);
    ratingSum += movieFrequencyBias[movie][fui];
    // Dot product of the time-dependent user vector (with implicit feedback) and movie factors.
    for (int i = 0; i < NUM_FEATURES; i++) {
      ratingSum += (userFeatures[user][i] + userFeaturesAlpha[user][i]
          * timeDev + sum_mw[user][i] * norm[user])
          * movieFeatures[movie][i];
    }
    return ratingSum;
  }

  /** Returns the value stored for {@code date} in a sparse per-day list, or 0 if absent. */
  private static double valueOnDay(ArrayList<Integer> days, ArrayList<Double> values,
      int date) {
    int ind = days.indexOf(date);
    return ind == -1 ? 0.0 : values.get(ind);
  }

  /** Maps the user's rating count on {@code date} (1 if unknown) to floor(log_a(count)). */
  private int frequencyTier(Map<Integer, Integer> dateToFreq, int date) {
    Integer freq = dateToFreq.get(date);
    return (int) (Math.log(freq == null ? 1 : freq) / log_a);
  }

  /**
   * One regularized SGD step: {@code value + lrate * (gradient - reg * value)}.
   * Package-private for testing.
   */
  static double sgdStep(double value, double lrate, double gradient, double reg) {
    return value + lrate * (gradient - reg * value);
  }

  /**
   * Counts ratings per day for {@code user}, scanning the parallel arrays from {@code start}
   * while the user id matches and the index is below {@code end}. Package-private for testing.
   *
   * @return a new map from day to number of rows on that day (empty if no rows match)
   */
  static Map<Integer, Integer> countDatesForUser(int[] userIds, short[] dateIds, int start,
      int end, int user) {
    Map<Integer, Integer> dateToFreq = new HashMap<Integer, Integer>();
    for (int l = start; l < end && userIds[l] == user; l++) {
      incrementCount(dateToFreq, dateIds[l]);
    }
    return dateToFreq;
  }

  /**
   * Counts ratings per day for {@code user} in row data laid out as {@code {user, movie, date,
   * ...}}, scanning from {@code start} while the user id matches. Uses each row's own date.
   * Package-private for testing.
   *
   * @return a new map from day to number of rows on that day (empty if no rows match)
   */
  static Map<Integer, Integer> countDatesForUser(int[][] rows, int start, int user) {
    Map<Integer, Integer> dateToFreq = new HashMap<Integer, Integer>();
    for (int l = start; l < rows.length && rows[l][0] == user; l++) {
      incrementCount(dateToFreq, rows[l][2]);
    }
    return dateToFreq;
  }

  /** Adds one to the count stored under {@code key}, starting from 0. */
  private static void incrementCount(Map<Integer, Integer> counts, int key) {
    Integer current = counts.get(key);
    counts.put(key, current == null ? 1 : current + 1);
  }

  /** Prefixes {@code logline} with the current wall-clock time ("h:mm:ss a: "). */
  public String timestampLine(String logline) {
    String currentDate = new SimpleDateFormat("h:mm:ss a").format(new Date());
    return currentDate + ": " + logline;
  }

  /** Adds {@code addThis} to {@code n} and clips the result to the rating range [1, 5]. */
  @SuppressWarnings("unused")
  private double addAndClip(double n, double addThis) {
    n += addThis;
    if (n > 5) {
      return 5;
    } else if (n < 1) {
      return 1;
    }
    return n;
  }

  /**
   * Computes RMSE of the current model on {@link #probeData}.
   *
   * <p>Frequency tiers use per-day counts of each user's probe rows (each row's own date; the
   * previous version attributed every row to the first row's date).
   *
   * @throws IOException declared for compatibility; no I/O is performed here
   */
  private double calcProbeRmse() throws IOException {
    Map<Integer, Integer> dateToFreq = new HashMap<Integer, Integer>();
    int prevUser = -1;
    double sumSquaredError = 0;
    for (int j = 0; j < probeData.length; j++) {
      int[] row = probeData[j];
      int user = row[0];
      if (user != prevUser) {
        dateToFreq = countDatesForUser(probeData, j, user);
      }
      prevUser = user;
      double err = row[3] - predictRating(row[1], user, row[2], dateToFreq);
      sumSquaredError += err * err;
    }
    return Math.sqrt(sumSquaredError / probeData.length);
  }

  /** Prints and logs the probe RMSE for epoch {@code i}, with improvement over Cinematch. */
  private void logRmse(BufferedWriter logWriter, double rmse, int i)
      throws IOException {
    double predictedPercent = (1 - rmse / CINEMATCH_BASELINE) * 100;
    writeLogLine(logWriter, timestampLine(String.format(
        "epoch %d probe RMSE %.5f (%.2f%%) ", i, rmse, predictedPercent)));
  }

  /** Prints and logs that epoch {@code i} finished. */
  private void logEpoch(BufferedWriter logWriter, int i) throws IOException {
    writeLogLine(logWriter, timestampLine(String.format("epoch %d", i)));
  }

  /** Echoes {@code logline} to stdout and appends it, newline-terminated, to the log. */
  private static void writeLogLine(BufferedWriter logWriter, String logline)
      throws IOException {
    System.out.println(logline);
    logWriter.write(logline + "\n");
  }

  /**
   * Loads sets 1-3 into the training arrays, set 4 into {@link #probeData} and set 5 into
   * {@link #qualData}. (The probe is NOT appended to the training arrays here.)
   */
  private void readInput() throws NumberFormatException, IOException {
    loadData(false);
  }

  /**
   * Reads {@link #INPUT_DATA} into memory, routing each row by its set number.
   *
   * @param probeIsTraining when true, set-4 rows go into the training arrays and probe/qual
   *     arrays are left untouched; when false, sets 4 and 5 fill probeData and qualData
   */
  private void loadData(boolean probeIsTraining) throws NumberFormatException, IOException {
    byte[] dataIndices = readDataIndices();
    System.out.println(timestampLine("Loading data..."));
    try (BufferedReader br = openUtf8Reader(INPUT_DATA)) {
      String line;
      int lineNum = 0;
      int trainingDataIndex = 0, probeDataIndex = 0, qualDataIndex = 0;
      while ((line = br.readLine()) != null) {
        String[] parts = line.split(" ");
        // Ids and dates in the file are 1-based; store them 0-based.
        int user = Integer.parseInt(parts[0]) - 1;
        short movie = (short) (Short.parseShort(parts[1]) - 1);
        short date = (short) (Short.parseShort(parts[2]) - 1);
        byte rating = Byte.parseByte(parts[3]);
        byte set = dataIndices[lineNum];
        if (set == 1 || set == 2 || set == 3 || (probeIsTraining && set == 4)) {
          users[trainingDataIndex] = user;
          movies[trainingDataIndex] = movie;
          dates[trainingDataIndex] = date;
          ratings[trainingDataIndex] = rating;
          trainingDataIndex++;
        } else if (!probeIsTraining && set == 4) {
          probeData[probeDataIndex][0] = user;
          probeData[probeDataIndex][1] = movie;
          probeData[probeDataIndex][2] = date;
          probeData[probeDataIndex][3] = rating;
          probeDataIndex++;
        } else if (!probeIsTraining && set == 5) {
          qualData[qualDataIndex][0] = user;
          qualData[qualDataIndex][1] = movie;
          qualData[qualDataIndex][2] = date;
          qualDataIndex++;
        }
        lineNum++;
        if (lineNum % 10000000 == 0) {
          System.out.println(timestampLine(lineNum + " / " + NUM_POINTS));
        }
      }
    }
    System.out.println(timestampLine("Done loading data."));
  }

  /** Reads {@link #INPUT_INDEX}: one set number per data row, in file order. */
  private byte[] readDataIndices() throws NumberFormatException, IOException {
    System.out.println(timestampLine("Loading data index..."));
    byte[] dataIndices = new byte[NUM_POINTS];
    try (BufferedReader brIdx = openUtf8Reader(INPUT_INDEX)) {
      String line;
      int lineNum = 0;
      while ((line = brIdx.readLine()) != null) {
        dataIndices[lineNum++] = Byte.parseByte(line);
      }
    }
    return dataIndices;
  }

  /** Opens {@code path} as a buffered UTF-8 reader; the caller must close it. */
  private static BufferedReader openUtf8Reader(String path) throws IOException {
    InputStream in = new FileInputStream(path);
    return new BufferedReader(new InputStreamReader(in, Charset.forName("UTF-8")));
  }

  /** Serializes the learned parameter arrays, one Java-serialized file per array. */
  private void saveBestParams() throws IOException {
    serialize("userFeatures", userFeatures);
    serialize("movieFeatures", movieFeatures);
    serialize("userBias", userBias);
    serialize("movieBias", movieBias);
    serialize("mw", mw);
    serialize("sum_mw", sum_mw);
  }

  /** Writes {@code value} with Java serialization to {@code fileName}, closing both streams. */
  private static void serialize(String fileName, Object value) throws IOException {
    try (FileOutputStream fileOut = new FileOutputStream(fileName);
        ObjectOutputStream objOut = new ObjectOutputStream(fileOut)) {
      objOut.writeObject(value);
    }
  }

  /** Writes predictions of the phase-1 model for sets 1-3 and for the probe set. */
  private void generateProbeOutput() throws IOException {
    try (BufferedWriter out = new BufferedWriter(
        new FileWriter("TimeSVD_123_no_probe_training"))) {
      writeTrainingPredictions(out, NUM_TRAINING_POINTS);
    }
    try (BufferedWriter out = new BufferedWriter(
        new FileWriter("TimeSVD_4_no_probe_training"))) {
      writeRowPredictions(out, probeData);
    }
  }

  /** Writes predictions of the final model for sets 1-4 and for the qualifying set. */
  private void generateOutput() throws IOException {
    try (BufferedWriter out = new BufferedWriter(
        new FileWriter("TimeSVD_1234_with_probe_training"))) {
      writeTrainingPredictions(out, NUM_TRAINING_PROBE_POINTS);
    }
    try (BufferedWriter out = new BufferedWriter(
        new FileWriter("TimeSVD_5_with_probe_training"))) {
      writeRowPredictions(out, qualData);
    }
  }

  /**
   * Writes one prediction per row for the first {@code numPoints} training-array rows, with
   * frequency tiers from the user's training rows (previously read from qualData by mistake).
   */
  private void writeTrainingPredictions(BufferedWriter out, int numPoints)
      throws IOException {
    Map<Integer, Integer> dateToFreq = new HashMap<Integer, Integer>();
    int prevUser = -1;
    for (int i = 0; i < numPoints; i++) {
      int user = users[i];
      if (user != prevUser) {
        dateToFreq = countDatesForUser(users, dates, i, numPoints, user);
      }
      prevUser = user;
      writePrediction(out, user, movies[i],
          predictRating(movies[i], user, dates[i], dateToFreq));
    }
  }

  /**
   * Writes one prediction per {@code {user, movie, date, ...}} row, with frequency tiers from
   * the same row set (previously probe rows were counted against qualData by mistake).
   */
  private void writeRowPredictions(BufferedWriter out, int[][] rows) throws IOException {
    Map<Integer, Integer> dateToFreq = new HashMap<Integer, Integer>();
    int prevUser = -1;
    for (int i = 0; i < rows.length; i++) {
      int user = rows[i][0];
      if (user != prevUser) {
        dateToFreq = countDatesForUser(rows, i, user);
      }
      prevUser = user;
      writePrediction(out, user, rows[i][1],
          predictRating(rows[i][1], user, rows[i][2], dateToFreq));
    }
  }

  /** Writes "user movie rating" with a locale-independent decimal point. */
  private static void writePrediction(BufferedWriter out, int user, int movie,
      double predictedRating) throws IOException {
    out.write(String.format(Locale.ROOT, "%d %d %.4f\n", user, movie, predictedRating));
  }

  /**
   * Entry point. Accepts {@code numFeatures}, {@code numFeatures testParam}, or
   * {@code numFeatures LRATE_BIAS LRATE_FEATURES LRATE_MW K_BIAS K_FEATURES K_MW}.
   * Prints a usage line and exits with status 1 on any other argument count.
   */
  public static void main(String[] args) throws NumberFormatException,
      IOException {
    TimeSVD trainer;
    if (args.length == 1) {
      trainer = new TimeSVD(Integer.parseInt(args[0]));
    } else if (args.length == 2) {
      trainer = new TimeSVD(Integer.parseInt(args[0]),
          Integer.parseInt(args[1]));
    } else if (args.length == 7) {
      trainer = new TimeSVD(Integer.parseInt(args[0]),
          Double.parseDouble(args[1]), Double.parseDouble(args[2]),
          Double.parseDouble(args[3]), Double.parseDouble(args[4]),
          Double.parseDouble(args[5]), Double.parseDouble(args[6]));
    } else {
      System.err.println("Usage: TimeSVD numFeatures [testParam | LRATE_BIAS LRATE_FEATURES"
          + " LRATE_MW K_BIAS K_FEATURES K_MW]");
      System.exit(1);
      return;
    }
    trainer.train();
  }
}