/*
* Copyright (C) 2016 RankSys http://ranksys.org
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
package org.ranksys.fm.learner;
import es.uam.eps.ir.ranksys.fast.index.FastItemIndex;
import es.uam.eps.ir.ranksys.fast.index.FastUserIndex;
import es.uam.eps.ir.ranksys.fast.preference.FastPreferenceData;
import org.ranksys.fm.PreferenceFM;
import org.ranksys.javafm.FM;
import org.ranksys.javafm.data.FMData;
import org.ranksys.javafm.learner.FMLearner;

import java.util.Arrays;
import java.util.Objects;
import java.util.Random;
/**
 * Learner for PreferenceFMs.
 * <p>
 * Subclasses supply the FM learning algorithm ({@link #getLearner()}) and the
 * conversion from preference data to FM data ({@link #toFMData(FastPreferenceData)}).
 * This class builds the underlying {@link FM} and wraps it into a
 * {@link PreferenceFM} bound to the user and item indexes given at construction.
 *
 * @author Saúl Vargas (Saul@VargasSandoval.es)
 * @param <U> type of the users
 * @param <I> type of the items
 */
public abstract class PreferenceFMLearner<U, I> {

    private final FastUserIndex<U> users;
    private final FastItemIndex<I> items;

    /**
     * Constructor.
     *
     * @param users user index
     * @param items item index
     */
    public PreferenceFMLearner(FastUserIndex<U> users, FastItemIndex<I> items) {
        this.users = users;
        this.items = items;
    }

    /**
     * Returns the FM learning algorithm used to train the underlying FM.
     *
     * @return FM learner
     */
    protected abstract FMLearner<FMData> getLearner();

    /**
     * Converts preference data into the {@link FMData} consumed by the
     * learner returned by {@link #getLearner()}.
     *
     * @param preferences preference data
     * @return FM data
     */
    protected abstract FMData toFMData(FastPreferenceData<U, I> preferences);

    /**
     * Trains an already existing (and possibly pre-trained) FM.
     *
     * @param fm preference FM
     * @param train training data
     */
    public void learn(PreferenceFM<U, I> fm, FastPreferenceData<U, I> train) {
        getLearner().learn(fm.getFM(), toFMData(train));
    }

    /**
     * Trains an already existing (and possibly pre-trained) FM.
     *
     * @param fm preference FM
     * @param train training data
     * @param test test data (for displaying error rate in test)
     */
    public void learn(PreferenceFM<U, I> fm, FastPreferenceData<U, I> train, FastPreferenceData<U, I> test) {
        getLearner().learn(fm.getFM(), toFMData(train), toFMData(test));
    }

    /**
     * Creates and trains a preference FM, initialising its parameters with a
     * freshly (non-reproducibly) seeded random number generator.
     *
     * @param train training data
     * @param K number of factors in FM
     * @param sdev standard deviation for initialisation of parameters
     * @return a trained preference FM
     * @see #learn(FastPreferenceData, int, double, Random)
     */
    public PreferenceFM<U, I> learn(FastPreferenceData<U, I> train, int K, double sdev) {
        return learn(train, K, sdev, new Random());
    }

    /**
     * Creates and trains a preference FM, drawing the initial parameters from
     * the supplied random number generator. Passing a seeded {@link Random}
     * makes the parameter initialisation reproducible.
     *
     * @param train training data
     * @param K number of factors in FM
     * @param sdev standard deviation for initialisation of parameters
     * @param rnd random number generator used to initialise the parameters
     * @return a trained preference FM
     * @throws NullPointerException if {@code rnd} is null
     */
    public PreferenceFM<U, I> learn(FastPreferenceData<U, I> train, int K, double sdev, Random rnd) {
        Objects.requireNonNull(rnd, "rnd");
        FMData fmTrain = toFMData(train);
        FM fm = createFM(fmTrain, K, sdev, rnd);
        getLearner().learn(fm, fmTrain);
        return new PreferenceFM<>(users, items, fm);
    }

    /**
     * Creates and trains a preference FM, initialising its parameters with a
     * freshly (non-reproducibly) seeded random number generator.
     *
     * @param train training data
     * @param test test data (for displaying error rate in test)
     * @param K number of factors in FM
     * @param sdev standard deviation for initialisation of parameters
     * @return a trained preference FM
     * @see #learn(FastPreferenceData, FastPreferenceData, int, double, Random)
     */
    public PreferenceFM<U, I> learn(FastPreferenceData<U, I> train, FastPreferenceData<U, I> test, int K, double sdev) {
        return learn(train, test, K, sdev, new Random());
    }

    /**
     * Creates and trains a preference FM, drawing the initial parameters from
     * the supplied random number generator. Passing a seeded {@link Random}
     * makes the parameter initialisation reproducible.
     *
     * @param train training data
     * @param test test data (for displaying error rate in test)
     * @param K number of factors in FM
     * @param sdev standard deviation for initialisation of parameters
     * @param rnd random number generator used to initialise the parameters
     * @return a trained preference FM
     * @throws NullPointerException if {@code rnd} is null
     */
    public PreferenceFM<U, I> learn(FastPreferenceData<U, I> train, FastPreferenceData<U, I> test,
                                    int K, double sdev, Random rnd) {
        Objects.requireNonNull(rnd, "rnd");
        FMData fmTrain = toFMData(train);
        FMData fmTest = toFMData(test);
        FM fm = createFM(fmTrain, K, sdev, rnd);
        getLearner().learn(fm, fmTrain, fmTest);
        return new PreferenceFM<>(users, items, fm);
    }

    /**
     * Creates a new, untrained FM sized to the number of features of the
     * training data.
     *
     * @param fmTrain training data in FM format
     * @param K number of factors in FM
     * @param sdev standard deviation for initialisation of parameters
     * @param rnd random number generator used to initialise the parameters
     * @return a new FM
     */
    private static FM createFM(FMData fmTrain, int K, double sdev, Random rnd) {
        return new FM(fmTrain.numFeatures(), K, rnd, sdev);
    }

    /**
     * Creates an array with one entry per user plus one entry per item
     * ({@code users.numUsers() + items.numItems()} entries), all set to
     * {@code v}.
     * <p>
     * NOTE(review): presumably used by subclasses to build per-feature
     * hyperparameter vectors, assuming the FM feature space is exactly the
     * users followed by the items — confirm against the subclass'
     * {@link #toFMData(FastPreferenceData)}.
     *
     * @param v value for every entry
     * @return a new array filled with {@code v}
     */
    protected double[] vectorise(double v) {
        double[] a = new double[users.numUsers() + items.numItems()];
        Arrays.fill(a, v);
        return a;
    }
}