/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.codelibs.elasticsearch.taste.recommender.svd;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import org.codelibs.elasticsearch.taste.common.FastIDSet;
import org.codelibs.elasticsearch.taste.common.RefreshHelper;
import org.codelibs.elasticsearch.taste.common.Refreshable;
import org.codelibs.elasticsearch.taste.exception.TasteException;
import org.codelibs.elasticsearch.taste.model.DataModel;
import org.codelibs.elasticsearch.taste.model.PreferenceArray;
import org.codelibs.elasticsearch.taste.recommender.AbstractRecommender;
import org.codelibs.elasticsearch.taste.recommender.CandidateItemsStrategy;
import org.codelibs.elasticsearch.taste.recommender.IDRescorer;
import org.codelibs.elasticsearch.taste.recommender.RecommendedItem;
import org.codelibs.elasticsearch.taste.recommender.TopItems;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Preconditions;
/**
 * A {@link org.codelibs.elasticsearch.taste.recommender.Recommender} that uses matrix factorization (a projection of
 * users and items onto a feature space).
 *
 * <p>A preference is estimated as the dot product of the user's and the item's feature vectors taken from the current
 * {@link Factorization}. The factorization is loaded from the configured {@link PersistenceStrategy} when it yields
 * one; otherwise it is computed with the {@link Factorizer} at construction time. It is recomputed whenever
 * {@link #refresh(Collection) refresh} triggers retraining.
 */
public final class SVDRecommender extends AbstractRecommender {

    /** The factorization currently used for estimation; replaced on every training run. */
    private Factorization factorization;

    /** Produces a new factorization on each training run. */
    private final Factorizer factorizer;

    /** Supplies a cached factorization at construction and is handed every newly trained one. */
    private final PersistenceStrategy persistenceStrategy;

    /**
     * Coordinates refreshing of the registered dependencies (data model, factorizer, candidate items strategy) and
     * retraining.
     */
    private final RefreshHelper refreshHelper;

    private static final Logger log = LoggerFactory
            .getLogger(SVDRecommender.class);

    /**
     * Creates an SVDRecommender with the default candidate items strategy and the default persistence strategy
     * ({@link NoPersistenceStrategy}).
     *
     * @param dataModel the data model to recommend from
     * @param factorizer the factorizer used to compute the factorization
     * @throws NullPointerException if {@code factorizer} is null
     * @throws TasteException if the persistence strategy fails to load or persist the factorization
     */
    public SVDRecommender(final DataModel dataModel, final Factorizer factorizer) {
        this(dataModel, factorizer, getDefaultCandidateItemsStrategy(),
                getDefaultPersistenceStrategy());
    }

    /**
     * Creates an SVDRecommender with a custom candidate items strategy and the default persistence strategy
     * ({@link NoPersistenceStrategy}).
     *
     * @param dataModel the data model to recommend from
     * @param factorizer the factorizer used to compute the factorization
     * @param candidateItemsStrategy strategy selecting which items are considered for recommendation
     * @throws NullPointerException if {@code factorizer} is null
     * @throws TasteException if the persistence strategy fails to load or persist the factorization
     */
    public SVDRecommender(final DataModel dataModel,
            final Factorizer factorizer,
            final CandidateItemsStrategy candidateItemsStrategy) {
        this(dataModel, factorizer, candidateItemsStrategy,
                getDefaultPersistenceStrategy());
    }

    /**
     * Creates an SVDRecommender using a persistent store to cache factorizations. A factorization is loaded from the
     * store if present; otherwise a new factorization is computed and handed to the store.
     *
     * <p>The {@link #refresh(java.util.Collection) refresh} method recomputes the factorization and hands the new one
     * to the store again.
     *
     * @param dataModel the data model to recommend from
     * @param factorizer the factorizer used to compute the factorization
     * @param persistenceStrategy the store used to load and persist factorizations
     * @throws NullPointerException if {@code factorizer} or {@code persistenceStrategy} is null
     * @throws TasteException if the persistence strategy fails to load or persist the factorization (the underlying
     *         {@link IOException} is attached as the cause)
     */
    public SVDRecommender(final DataModel dataModel,
            final Factorizer factorizer,
            final PersistenceStrategy persistenceStrategy) {
        this(dataModel, factorizer, getDefaultCandidateItemsStrategy(),
                persistenceStrategy);
    }

    /**
     * Creates an SVDRecommender using a persistent store to cache factorizations. A factorization is loaded from the
     * store if present; otherwise a new factorization is computed and handed to the store.
     *
     * <p>The {@link #refresh(java.util.Collection) refresh} method recomputes the factorization and hands the new one
     * to the store again.
     *
     * @param dataModel the data model to recommend from
     * @param factorizer the factorizer used to compute the factorization
     * @param candidateItemsStrategy strategy selecting which items are considered for recommendation
     * @param persistenceStrategy the store used to load and persist factorizations
     * @throws NullPointerException if {@code factorizer} or {@code persistenceStrategy} is null
     * @throws TasteException if the persistence strategy fails to load or persist the factorization (the underlying
     *         {@link IOException} is attached as the cause)
     */
    public SVDRecommender(final DataModel dataModel,
            final Factorizer factorizer,
            final CandidateItemsStrategy candidateItemsStrategy,
            final PersistenceStrategy persistenceStrategy) {
        super(dataModel, candidateItemsStrategy);
        this.factorizer = Preconditions.checkNotNull(factorizer,
                "factorizer must not be null");
        this.persistenceStrategy = Preconditions.checkNotNull(
                persistenceStrategy, "persistenceStrategy must not be null");
        try {
            factorization = persistenceStrategy.load();
        } catch (final IOException e) {
            throw new TasteException("Error loading factorization", e);
        }
        // No cached factorization available: compute (and persist) one now.
        if (factorization == null) {
            train();
        }
        // On refresh, the helper handles the dependencies registered below and then runs this callback to retrain.
        refreshHelper = new RefreshHelper(() -> {
            train();
            return null;
        });
        refreshHelper.addDependency(getDataModel());
        refreshHelper.addDependency(factorizer);
        refreshHelper.addDependency(candidateItemsStrategy);
    }

    /**
     * Returns the persistence strategy used when none is supplied.
     *
     * @return a new {@link NoPersistenceStrategy}
     */
    static PersistenceStrategy getDefaultPersistenceStrategy() {
        return new NoPersistenceStrategy();
    }

    /**
     * Computes a fresh factorization, makes it the current one and hands it to the persistence strategy.
     *
     * @throws TasteException if persisting the factorization fails with an {@link IOException}
     */
    private void train() {
        factorization = factorizer.factorize();
        try {
            persistenceStrategy.maybePersist(factorization);
        } catch (final IOException e) {
            throw new TasteException("Error persisting factorization", e);
        }
    }

    /**
     * Recommends up to {@code howMany} items for the given user.
     *
     * <p>Candidates come from {@code getAllOtherItems} for the user's preferences. Each candidate is scored with
     * {@link #estimatePreference(long, long)}, and the highest-scoring ones are selected by {@link TopItems}, which also
     * receives the {@code rescorer}.
     *
     * @param userID the user to recommend for
     * @param howMany maximum number of recommendations; must be at least 1
     * @param rescorer rescorer passed through to {@link TopItems#getTopItems}
     * @return the recommended items as produced by {@link TopItems#getTopItems}
     * @throws IllegalArgumentException if {@code howMany < 1}
     */
    @Override
    public List<RecommendedItem> recommend(final long userID,
            final int howMany, final IDRescorer rescorer) {
        Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
        log.debug("Recommending items for user ID '{}'", userID);
        final PreferenceArray preferencesFromUser = getDataModel()
                .getPreferencesFromUser(userID);
        final FastIDSet possibleItemIDs = getAllOtherItems(userID,
                preferencesFromUser);
        final List<RecommendedItem> topItems = TopItems.getTopItems(howMany,
                possibleItemIDs.iterator(), rescorer, new Estimator(userID));
        log.debug("Recommendations are: {}", topItems);
        return topItems;
    }

    /**
     * Estimates a preference as the dot product of the user's and the item's feature vectors in the current
     * factorization.
     *
     * <p>The sum runs over the length of the user's feature vector. NOTE(review): the item vector is assumed to have
     * the same length (i.e. both vectors come from the same factorization run); confirm against the factorizers in use.
     *
     * @param userID the user whose feature vector is used
     * @param itemID the item whose feature vector is used
     * @return the dot product, narrowed to {@code float}
     */
    @Override
    public float estimatePreference(final long userID, final long itemID) {
        final double[] userFeatures = factorization.getUserFeatures(userID);
        final double[] itemFeatures = factorization.getItemFeatures(itemID);
        double estimate = 0;
        for (int feature = 0; feature < userFeatures.length; feature++) {
            estimate += userFeatures[feature] * itemFeatures[feature];
        }
        return (float) estimate;
    }

    /** Adapts {@link #estimatePreference(long, long)} to {@link TopItems.Estimator} for one fixed user. */
    private final class Estimator implements TopItems.Estimator<Long> {
        private final long theUserID;

        private Estimator(final long theUserID) {
            this.theUserID = theUserID;
        }

        @Override
        public double estimate(final Long itemID) {
            return estimatePreference(theUserID, itemID);
        }
    }

    /**
     * Refreshes this recommender via its {@link RefreshHelper}.
     *
     * <p>The helper tracks the data model, the factorizer and the candidate items strategy as dependencies. Its
     * callback retrains the factorization and hands the result to the persistence strategy.
     *
     * @param alreadyRefreshed components already refreshed in this refresh cycle
     */
    @Override
    public void refresh(final Collection<Refreshable> alreadyRefreshed) {
        refreshHelper.refresh(alreadyRefreshed);
    }
}