/**
* Copyright 2012 plista GmbH (http://www.plista.com/)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/
package org.plista.kornakapi.core.recommender;
import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.mahout.cf.taste.common.NoSuchItemException;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A matrix factorization that supports folding in new users.
 *
 * <p>Wraps a {@link Factorization} and eagerly precomputes the fold-in matrix {@code Y (Y^T Y)^-1},
 * where {@code Y} is the item-features matrix. A feature vector for a user who is not part of the
 * factorization can then be derived from the IDs of the items that user interacted with.
 */
public class FoldingFactorization {

  private static final Logger log = LoggerFactory.getLogger(FoldingFactorization.class);

  private final Factorization factorization;

  /** {@code Y (Y^T Y)^-1}, indexed by item index then feature; empty if there are no items. */
  private final double[][] userFoldInMatrix;

  /**
   * Creates a folding factorization and computes the user fold-in matrix.
   *
   * @param factorization the underlying factorization whose item features are used for folding in
   */
  public FoldingFactorization(Factorization factorization) {
    this.factorization = factorization;
    userFoldInMatrix = computeUserFoldInMatrix(factorization.allItemFeatures());
  }

  /** Returns the wrapped factorization. */
  public Factorization factorization() {
    return factorization;
  }

  /**
   * Computes {@code Y (Y^T Y)^-1} for the item-features matrix {@code Y}.
   * See http://www.slideshare.net/fullscreen/srowen/matrix-factorization/16 for details.
   *
   * @param itemFeatures the item-features matrix, one row per item index
   * @return the fold-in matrix, or an empty matrix if there are no items
   */
  private double[][] computeUserFoldInMatrix(double[][] itemFeatures) {
    // if there are no items, we cannot fold in anything
    if (itemFeatures.length == 0) {
      return new double[0][0];
    }
    log.info("Computing fold-in matrix from a {} x {} item features matrix", factorization.numItems(),
        factorization.numFeatures());
    RealMatrix y = new Array2DRowRealMatrix(itemFeatures);
    RealMatrix yTy = y.transpose().multiply(y);
    // NOTE(review): the LU solver is expected to throw if Y^T Y is singular — confirm callers accept that.
    RealMatrix yTyInverse = new LUDecompositionImpl(yTy).getSolver().getInverse();
    return y.multiply(yTyInverse).getData();
  }

  /**
   * Folds in a new user by summing the fold-in matrix rows of the items the user interacted with.
   * Unknown item IDs are logged and skipped.
   *
   * @param itemIDs IDs of the items the user interacted with
   * @return the user's feature vector, of length {@code factorization().numFeatures()}
   * @throws NoSuchItemException if none of the given items is known to the factorization
   */
  public double[] foldInUser(long[] itemIDs) throws NoSuchItemException {
    return sumRowsOfKnownItems(itemIDs, userFoldInMatrix);
  }

  /**
   * Folds in an anonymous user. It sums the raw item feature vectors of the items the user
   * interacted with, then divides the result by its L1 norm. Unknown item IDs are logged and skipped.
   *
   * @param itemIDs IDs of the items the user interacted with
   * @return the L1-normalised user feature vector. If the summed vector is all zeros, it is
   *     returned unnormalised instead of as NaNs.
   * @throws NoSuchItemException if none of the given items is known to the factorization
   */
  public double[] foldInAnonymousUser(long[] itemIDs) throws NoSuchItemException {
    double[] userFeatures = sumRowsOfKnownItems(itemIDs, factorization.allItemFeatures());
    RealVector userFeaturesAsVector = new ArrayRealVector(userFeatures);
    double l1Norm = userFeaturesAsVector.getL1Norm();
    if (l1Norm == 0.0) {
      // dividing by a zero norm would turn every component into NaN
      return userFeatures;
    }
    return userFeaturesAsVector.mapDivide(l1Norm).getData();
  }

  /**
   * Sums the rows of {@code rows} that belong to the known items among {@code itemIDs}.
   * Unknown item IDs are logged at INFO level and skipped.
   *
   * @param itemIDs IDs of the items to look up
   * @param rows matrix indexed by item index, with at least {@code numFeatures()} columns
   * @return the element-wise sum of the selected rows, of length {@code numFeatures()}
   * @throws NoSuchItemException if no item in {@code itemIDs} is known (including an empty array)
   */
  private double[] sumRowsOfKnownItems(long[] itemIDs, double[][] rows) throws NoSuchItemException {
    int numFeatures = factorization.numFeatures();
    double[] sum = new double[numFeatures];
    boolean anyKnown = false;
    for (long itemID : itemIDs) {
      int itemIndex;
      try {
        itemIndex = factorization.itemIndex(itemID);
      } catch (NoSuchItemException e) {
        log.info("Item unknown: {}", itemID);
        continue;
      }
      anyKnown = true;
      double[] row = rows[itemIndex];
      for (int feature = 0; feature < numFeatures; feature++) {
        sum[feature] += row[feature];
      }
    }
    // the exception must not depend on the logging level, and must also cover the all-unknown case
    if (!anyKnown) {
      throw new NoSuchItemException("At least one item must be known");
    }
    return sum;
  }
}