package happy.research.cf;
import happy.coding.io.FileIO;
import happy.coding.io.Logs;
import happy.coding.io.Strings;
import happy.coding.io.net.Gmailer;
import happy.coding.math.Randoms;
import happy.coding.math.Stats;
import happy.coding.system.Dates;
import happy.coding.system.Debug;
import happy.coding.system.Systems;
import happy.research.cf.ConfigParams.DatasetMode;
import happy.research.cf.ConfigParams.ValidateMethod;
import happy.research.utils.SimUtils;
import happy.research.utils.SimUtils.SimMethod;
import java.io.File;
import java.nio.file.FileSystems;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
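/**
 * Skeleton shared by the collaborative filtering (CF) methods in this package: {@link #execute()} initializes the
 * method, runs it either once or as a batch over automated parameter settings, and reports the measured
 * performance. Subclasses supply the dataset loading, the test-set preparation and the recommendation algorithm.
 *
 * @author guoguibing
 */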
public abstract class AbstractCF
{
// Identifier of the running method; used in the name of the collected result file and in the
// notification subject, and checked for the "Merge" prefix when ranking results are printed.
protected static String methodId = null;
// Path (with trailing separator) and name of the trust data directory currently being processed;
// both are set per directory by the batch run over trust sets.
protected static String current_trust_dir = null;
protected static String current_trust_name = null;
// Compared against params.numRunMethod when results are collected; it is not modified in this class
// (presumably maintained by the launcher - confirm).
protected static int numRunMethod = 0;
// Shared run configuration. It is only read and mutated here, never assigned; presumably init() sets it - confirm.
protected static ConfigParams params;
// Human-readable setting descriptions; logged and then cleared by the next printPerformance call.
protected static List<String> printSettings = new ArrayList<>();
// Setting values appended as extra comma-separated columns to the next result line, then cleared.
protected static List<String> methodSettings = new ArrayList<>();
// @Train {user, {item - rating}}
protected static Map<String, Map<String, Rating>> userRatingsMap;
// @Train {item, {user - rating}}
protected static Map<String, Map<String, Rating>> itemRatingsMap;
// @Test {user, {item - rating}}; rebuilt from testRatings before every run
protected static Map<String, Map<String, Rating>> testUserRatingsMap;
// @Test {item, {user - rating}}; rebuilt from testRatings before every run
protected static Map<String, Map<String, Rating>> testItemRatingsMap;
// @Test {test ratings}; not assigned in this class (presumably filled by prepTestRatings() - confirm)
protected static List<Rating> testRatings;
// @Train {user, {trust user - trust score}}
// NOTE(review): TN/DN presumably stand for trusted/distrusted neighbours - confirm against the subclasses.
protected static Map<String, Map<String, Double>> userTNsMap, userDNsMap, userTrustorsMap;
/**
 * Prepares the method for a run. Called once at the start of {@link #execute()} and again by the batch runner
 * whenever it has switched the train/test sets, the virtual users or the trust directory.
 */
protected abstract void init();
/**
 * Loads the data set; called first in every single run.
 *
 * @throws Exception if loading fails
 */
protected abstract void loadDataset() throws Exception;
/**
 * Prepares the test ratings; called after {@link #loadDataset()} and before the test ratings are indexed by user
 * and by item. Presumably this is where {@link #testRatings} gets populated - confirm in the subclasses.
 */
protected abstract void prepTestRatings();
/**
 * Executes the recommendation algorithm.
 *
 * @return the performance to report; {@code null} makes {@link #printPerformance(Performance)} print nothing
 * @throws Exception if the algorithm fails
 */
protected abstract Performance runRecAlgorithm() throws Exception;
/**
 * Entry point of a CF method: initializes it and then performs either one single run or a whole batch of runs.
 * <p>
 * A batch is forced when the virtual-ratings automation or the automated trust sets are switched on. Results are
 * collected here only after a batch; after a single run they are left to be collected once all single-run methods
 * have finished.
 *
 * @throws Exception if initialization-dependent steps, a run or the result collection fails
 */
public void execute() throws Exception
{
    init();

    // either automation switch implies a batch run
    if (VirRatingsCF.auto || params.auto_trust_sets) params.BATCH_RUN = true;

    if (!params.BATCH_RUN)
    {
        // collect results when all single run methods finished
        singleExecute();
        return;
    }

    batchExecute();
    collectResults();
}
/**
 * Performs one complete run with the current settings: load the data set, prepare and index the test ratings,
 * execute the recommendation algorithm and print its performance.
 *
 * @throws Exception if loading the data set or running the algorithm fails
 */
private void singleExecute() throws Exception
{
    // data first, then the test ratings in their different views
    loadDataset();
    prepTestRatings();
    formatTestRatings();

    // run the algorithm and report what it achieved
    printPerformance(runRecAlgorithm());
}
/**
 * Rebuilds the two test indexes from {@link #testRatings}: ratings grouped by user (keyed by item) and ratings
 * grouped by item (keyed by user). Both maps are replaced by fresh instances on every call.
 */
private void formatTestRatings()
{
    Logs.debug("Format test ratings ...");

    testUserRatingsMap = new HashMap<>();
    testItemRatingsMap = new HashMap<>();

    for (Rating rating : testRatings)
    {
        String userId = rating.getUserId();
        String itemId = rating.getItemId();

        // user view: {item - rating}
        Map<String, Rating> ratingsOfUser = testUserRatingsMap.get(userId);
        if (ratingsOfUser == null)
        {
            ratingsOfUser = new HashMap<>();
            testUserRatingsMap.put(userId, ratingsOfUser);
        }
        ratingsOfUser.put(itemId, rating);

        // item view: {user - rating}
        Map<String, Rating> ratingsOfItem = testItemRatingsMap.get(itemId);
        if (ratingsOfItem == null)
        {
            ratingsOfItem = new HashMap<>();
            testItemRatingsMap.put(itemId, ratingsOfItem);
        }
        ratingsOfItem.put(userId, rating);
    }

    Logs.debug("# test users = {}, items = {}, ratings = {}", new Object[] { testUserRatingsMap.size(),
            testItemRatingsMap.size(), testRatings.size() });
    Logs.debug("Done!");
}
/**
 * Runs the method repeatedly, once per value of whichever automated setting is switched on in {@link #params}.
 * <p>
 * Exactly one family of runs is chosen, tested in this order: all data set views, cross validation, then a sweep
 * over the similarity, trust or confidence threshold, kNN, top-N or x.sigma, then the virtual-ratings users and
 * finally the trust data sets. Nothing is run when none of them applies.
 * <p>
 * Before each run a description of the varied setting is queued either in {@link #printSettings} (logged) or in
 * {@link #methodSettings} (appended to the result line); both are consumed by
 * {@link #printPerformance(Performance)}.
 *
 * @throws Exception if any single run fails
 * @throws IllegalStateException if the trust directory cannot be listed when running over the trust data sets
 */
private void batchExecute() throws Exception
{
    if (params.AUTO_VIEWS)
    {
        batchOverViews();
    } else if (params.VALIDATE_METHOD == ValidateMethod.cross_validation)
    {
        batchCrossValidation();
    } else if (params.AUTO_SIMILARITY)
    {
        for (int i = 0; i < 10; i++)
        {
            params.SIMILARITY_THRESHOLD = i * 0.1;
            // unlike the other sweeps, this value goes into the result line rather than into the log
            methodSettings.add("" + params.SIMILARITY_THRESHOLD);
            singleExecute();
        }
    } else if (params.AUTO_TRUST)
    {
        sweepTrustThreshold();
    } else if (params.AUTO_CONFIDENCE)
    {
        sweepConfidenceThreshold();
    } else if (params.AUTO_KNN)
    {
        for (int i = 1; i < 21; i++)
        {
            params.kNN = i * 5;
            printSettings.add("kNN = " + params.kNN);
            singleExecute();
        }
    } else if (params.AUTO_TOPN)
    {
        for (int i = 1; i < 11; i++)
        {
            params.TOP_N = i * 10;
            printSettings.add("Top.N = " + params.TOP_N);
            singleExecute();
        }
    } else if (params.AUTO_SIGMA)
    {
        for (int i = 0; i <= 40; i++)
        {
            params.X_SIGMA = i * 0.1;
            printSettings.add("x.sigma = " + (float) params.X_SIGMA);
            singleExecute();
        }
    } else if (VirRatingsCF.auto)
    {
        batchOverVirtualUsers();
    } else if (params.auto_trust_sets)
    {
        batchOverTrustSets();
    }
}

/** Returns {@code path} with its first digit replaced by the given fold number. */
private static String withFold(String path, int fold)
{
    return path.replaceFirst("\\d", String.valueOf(fold));
}

/** Runs once for each trust threshold 0.0, 0.1, ..., 0.9. */
private void sweepTrustThreshold() throws Exception
{
    for (int i = 0; i < 10; i++)
    {
        params.TRUST_THRESHOLD = i * 0.1;
        printSettings.add("Trust.threshold = " + (float) params.TRUST_THRESHOLD);
        singleExecute();
    }
}

/** Runs once for each confidence threshold 0.0, 0.1, ..., 0.9. */
private void sweepConfidenceThreshold() throws Exception
{
    for (int i = 0; i < 10; i++)
    {
        params.CONFIDENCE_THRESHOLD = i * 0.1;
        printSettings.add("Confidence.threshold = " + (float) params.CONFIDENCE_THRESHOLD);
        singleExecute();
    }
}

/** Repeats the selected sweep (or a plain single run) for every data set view. */
private void batchOverViews() throws Exception
{
    DatasetMode[] views = { DatasetMode.all, DatasetMode.coldUsers, DatasetMode.heavyUsers,
            DatasetMode.opinUsers, DatasetMode.blackSheep, DatasetMode.contrItems, DatasetMode.nicheItems };

    for (DatasetMode view : views)
    {
        params.DATASET_MODE = view;

        if (params.AUTO_CV && (params.VALIDATE_METHOD == ValidateMethod.cross_validation))
        {
            String train = params.TRAIN_SET;
            String test = params.TEST_SET;
            for (int fold = 1; fold < 6; fold++)
            {
                train = withFold(train, fold);
                test = withFold(test, fold);
                params.TRAIN_SET = train;
                params.TEST_SET = test;
                printSettings.add("Train.Sets = [" + train + "], Test.Sets = [" + test + "]");
                // the sets changed, so the method is re-initialized before running
                init();
                singleExecute();
            }
        } else if (params.AUTO_SIMILARITY)
        {
            for (int i = 0; i < 10; i++)
            {
                params.SIMILARITY_THRESHOLD = i * 0.1;
                printSettings.add("Similarity.threshold = " + (float) params.SIMILARITY_THRESHOLD);
                singleExecute();
            }
        } else if (params.AUTO_TRUST)
        {
            sweepTrustThreshold();
        } else if (params.AUTO_CONFIDENCE)
        {
            sweepConfidenceThreshold();
        } else if (params.AUTO_KNN)
        {
            for (int i = 1; i < 11; i++)
            {
                params.kNN = i * 5;
                printSettings.add("kNN = " + params.kNN);
                singleExecute();
            }
        } else if (params.AUTO_TOPN)
        {
            for (int i = 1; i < 11; i++)
            {
                params.TOP_N = i * 5;
                printSettings.add("Top-N = " + params.TOP_N);
                singleExecute();
            }
        } else
        {
            singleExecute();
        }
    }
}

/**
 * Cross-validation batch: an optional outer sweep (similarity threshold, significance threshold or top-N)
 * around either a kNN sweep over folds 1-5, an alpha/beta run over folds 1-5, or a plain single run.
 */
private void batchCrossValidation() throws Exception
{
    // the fold-adjusted names are carried from one run to the next, exactly one digit is rewritten each time
    String train = params.TRAIN_SET;
    String test = params.TEST_SET;

    int num = 1;
    if (params.AUTO_SIMILARITY) num = 10;
    else if (params.AUTO_SIGNIFICANCE) num = 21;
    else if (params.AUTO_TOPN) num = 5;

    for (int k = 0; k < num; k++)
    {
        // outer sweep step; none of these flags is set when num == 1
        if (params.AUTO_SIMILARITY)
        {
            params.SIMILARITY_THRESHOLD = k * 0.1;
            printSettings.add("Similarity.threshold = " + (float) params.SIMILARITY_THRESHOLD);
        } else if (params.AUTO_SIGNIFICANCE)
        {
            params.SIGNIFICANCE_THRESHOLD = k * 0.005;
            printSettings.add("Significance.threshold = " + (float) params.SIGNIFICANCE_THRESHOLD);
        } else if (params.AUTO_TOPN)
        {
            // N = 2, 5, 10, 15, 20
            params.TOP_N = (k == 0) ? 2 : k * 5;
            printSettings.add("Top.N where N = " + params.TOP_N);
        }

        if (params.AUTO_KNN)
        {
            if (params.SIMILARITY_METHOD == SimMethod.BS)
            {
                SimUtils.alpha = params.readDouble("bs.alpha");
                SimUtils.beta = params.readDouble("bs.beta");
                printSettings.add("alpha = " + SimUtils.alpha + ", beta = " + SimUtils.beta);
            }
            for (int i = 1; i < 11; i++)
            {
                params.kNN = i * 5;
                Logs.debug("KNN = {}", params.kNN);
                for (int fold = 1; fold < 6; fold++)
                {
                    train = withFold(train, fold);
                    test = withFold(test, fold);
                    params.TRAIN_SET = train;
                    params.TEST_SET = test;
                    if (Debug.ON) methodSettings.add(params.readParam("itrust.probe.method"));
                    methodSettings.add("" + params.kNN);
                    init();
                    singleExecute();
                }
            }
        } else if (params.AUTO_CV)
        {
            // constant-first comparison: an absent parameter counts as "off" instead of raising an NPE
            boolean gridSearch = "on".equalsIgnoreCase(params.readParam("bs.params.batch"));
            double alpha = params.readDouble("bs.alpha");
            double beta = params.readDouble("bs.beta");

            if (gridSearch)
            {
                int alphaStart = params.readInt("bs.alpha.start");
                int betaStart = params.readInt("bs.beta.start");
                for (int m = alphaStart; m < 11; m++)
                {
                    SimUtils.alpha = m * 0.1;
                    // the beta start offset only applies to the first alpha value
                    int n = (m == alphaStart) ? betaStart : 0;
                    for (; n < 11; n++)
                    {
                        SimUtils.beta = n * 0.1;
                        if (SimUtils.alpha + SimUtils.beta > 1.0) break;
                        for (int fold = 1; fold < 6; fold++)
                        {
                            train = withFold(train, fold);
                            test = withFold(test, fold);
                            params.TRAIN_SET = train;
                            params.TEST_SET = test;
                            methodSettings.add("" + SimUtils.alpha);
                            methodSettings.add("" + SimUtils.beta);
                            init();
                            singleExecute();
                        }
                    }
                }
            } else
            {
                SimUtils.alpha = alpha;
                SimUtils.beta = beta;
                for (int fold = 1; fold < 6; fold++)
                {
                    train = withFold(train, fold);
                    test = withFold(test, fold);
                    params.TRAIN_SET = train;
                    params.TEST_SET = test;
                    if (params.SIMILARITY_METHOD == SimMethod.BS)
                    {
                        methodSettings.add("" + SimUtils.alpha);
                        methodSettings.add("" + SimUtils.beta);
                    }
                    if (params.kNN > 0) methodSettings.add("" + params.kNN);
                    init();
                    singleExecute();
                }
            }
        } else
        {
            singleExecute();
        }
    }
}

/** Runs ten times, handing the first 0, 3, 6, ..., 27 randomly drawn user ids to VirRatingsCF. */
private void batchOverVirtualUsers() throws Exception
{
    int num = 10;
    int[] users = Randoms.indexs(27, 1357, 1384);
    Logs.debug(Strings.toString(users));

    for (int i = 0; i < num; i++)
    {
        if (i > 0)
        {
            List<Integer> ids = new ArrayList<>();
            for (int j = 0; j < i * 3; j++)
                ids.add(users[j]);
            VirRatingsCF.userIds = ids;
        }
        // report the real size; the label used to carry a literal "0" in front of it (e.g. "03")
        int size = (VirRatingsCF.userIds == null) ? 0 : VirRatingsCF.userIds.size();
        printSettings.add("[" + i + "] users ids size = " + size);
        init();
        singleExecute();
    }
}

/** Runs once for every entry found in the "Trust" directory of the data set. */
private void batchOverTrustSets() throws Exception
{
    String dirPath = Dataset.DIRECTORY + "Trust";
    File[] dirs = new File(dirPath).listFiles();
    // listFiles() yields null (not an empty array) when the path is missing, not a directory or unreadable
    if (dirs == null) throw new IllegalStateException("Cannot list trust data sets in directory: " + dirPath);

    for (File dir : dirs)
    {
        String name = dir.getName();
        methodSettings.add(name);
        current_trust_name = name;
        current_trust_dir = dir.getPath() + Systems.FILE_SEPARATOR;
        init();
        singleExecute();
    }
}
/**
 * Reports the outcome of one run: logs the queued setting descriptions, logs the measures and finally logs one
 * comma-separated result line.
 * <p>
 * The result line starts with the data set label, the method and the data set mode, followed by the queued
 * method settings. When {@code params.TOP_N} is not positive the predictive measures are reported, otherwise the
 * ranking (and, if enabled by the Debug switches, diversity) measures. Under cross validation the test set name
 * is appended. Both setting queues are emptied.
 *
 * @param pf the performance of the finished run; nothing is printed when it is {@code null}
 */
@SuppressWarnings("unchecked")
protected void printPerformance(Performance pf)
{
    if (pf == null) return;

    final String sixDecimals = "%1.6f";
    final String fourDigits = "%4d";
    final String twoDecimals = "%2.2f";

    StringBuilder line = new StringBuilder(Dataset.LABEL + "," + pf.getMethod() + "," + params.DATASET_MODE); // + "," + params.kNN;

    if (!printSettings.isEmpty())
    {
        for (String setting : printSettings)
            Logs.info(setting);
        Logs.info(null);
        printSettings.clear();
    }

    if (!methodSettings.isEmpty())
    {
        for (String value : methodSettings)
            line.append(",").append(value);
        methodSettings.clear();
    }

    int topN = params.TOP_N;
    if (topN > 0)
    {
        boolean rankByPrediction = !methodId.startsWith("Merge");

        if (Debug.ON)
        {
            rankByPrediction = true;

            /* ranking performance */
            Measures ranking = pf.ranking(testUserRatingsMap, rankByPrediction);
            String template = "Measures@%d: Precision = " + sixDecimals + ", Recall = " + sixDecimals + ", F1 = "
                    + sixDecimals + ", MAP = " + sixDecimals + ", MRR = " + sixDecimals + ", NDCG = "
                    + sixDecimals;
            int cutoff = 10;
            Logs.info(String.format(template, cutoff, ranking.getPrecision(cutoff), ranking.getRecall(cutoff),
                    ranking.getF1(cutoff), ranking.getMAP(cutoff), ranking.getMRR(cutoff),
                    ranking.getNDCG(cutoff)));
            Logs.debug(null);

            // one row of measures per cutoff, below the header part of the result line
            line.append("\n");
            for (int n : Performance.cutoffs)
            {
                line.append((float) ranking.getPrecision(n)).append(",");
                line.append((float) ranking.getRecall(n)).append(",");
                line.append((float) ranking.getF1(n)).append(",");
                line.append((float) ranking.getMAP(n)).append(",");
                line.append((float) ranking.getMRR(n)).append(",");
                line.append((float) ranking.getNDCG(n)).append("\n");
            }
        }

        if (Debug.OFF)
        {
            /* diversity performance */
            Measures diversity = pf.diversity(new Map[] { userRatingsMap, itemRatingsMap, testItemRatingsMap },
                    topN, rankByPrediction);
            double F1 = Stats.hMean(diversity.getUD(), diversity.getSD());
            Logs.debug("UD = {}, SD = {}, F1 = {}", new Object[] { diversity.getUD(), diversity.getSD(), F1 });
            Logs.debug(null);
            line.append("," + topN + "," + diversity.getUD() + "," + diversity.getSD() + "," + F1);
        }
    } else
    {
        /* predictive performance */
        Measures prediction = pf.prediction(testUserRatingsMap);

        String ratingTemplate = "MAE = " + sixDecimals + ", RC[" + fourDigits + "/" + fourDigits + "] = "
                + twoDecimals + "%%, RMSE = " + sixDecimals;
        Logs.debug(String.format(ratingTemplate, prediction.getMAE(), prediction.getCoveredRatings(),
                prediction.getTotalRatings(), prediction.getRC() * 100, prediction.getRMSE()));

        String userTemplate = "MAUE = " + sixDecimals + ", UC[" + fourDigits + "/" + fourDigits + "] = "
                + twoDecimals + "%%";
        Logs.debug(String.format(userTemplate, prediction.getMAUE(), prediction.getCoveredUsers(),
                prediction.getTotalUsers(), prediction.getUC() * 100));
        Logs.debug(null);

        line.append("," + prediction.getRMSE() + "," + prediction.getMAE() + "," + prediction.getRC());

        if (Debug.OFF)
        {
            // harmonic mean of the normalized accuracy and the rating coverage
            double nMAE = 1 - prediction.getMAE() / (Dataset.maxScale - Dataset.minScale);
            double F1 = Stats.hMean(nMAE, prediction.getRC());
            line.append("," + F1);
        }
    }

    if (params.VALIDATE_METHOD == ValidateMethod.cross_validation) line.append("," + params.TEST_SET);

    Logs.info(line.toString());
}
/**
 * Archives the current {@code results.txt} and optionally mails it.
 * <p>
 * The file is copied into the results directory of the data set under the name
 * {@code <methodId> [<mode>]@<timestamp>.txt}. When this is a batch run and further methods are still to be run
 * ({@code numRunMethod < params.numRunMethod}), {@code results.txt} is emptied afterwards. If e-mail notification
 * is enabled, the archived results are sent as message text together with the path of the archived copy.
 *
 * @throws Exception if the results cannot be copied or read, or the notification cannot be sent
 */
public static void collectResults() throws Exception
{
    String mode = params.BATCH_RUN ? DatasetMode.batch.label : params.DATASET_MODE.label;
    String program = methodId + " [" + mode + "]";

    /* Collect result files to specific directory */
    Path source = FileSystems.getDefault().getPath("results.txt");
    Path target = FileSystems.getDefault().getPath(
            FileIO.makeDirectory(AbstractCF.params.RESULTS_DIRECTORY + Dataset.LABEL) + program + "@"
                    + Dates.now() + ".txt");
    Files.copy(source, target);

    boolean moreMethodsToRun = params.BATCH_RUN && params.numRunMethod > 1
            && numRunMethod < params.numRunMethod;
    if (moreMethodsToRun) FileIO.empty(source.toString());

    /* Send email to notify results */
    if (params.EMAIL_NOTIFICATION)
    {
        // Read the archived copy, not the source: the source may just have been emptied above,
        // which would leave the notification without any text.
        String text = FileIO.readAsString(target.toString());

        Gmailer notifier = new Gmailer();
        notifier.getProps().setProperty("mail.to", "gguo1@e.ntu.edu.sg");
        notifier.getProps()
                .setProperty("mail.subject", Dataset.LABEL + ": " + program + " From " + Systems.getIP());
        notifier.send(text, target.toString());
    }
}
}