/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.cf.taste.impl.similarity.precompute;
import com.google.common.io.Closeables;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.cf.taste.similarity.precompute.BatchItemSimilarities;
import org.apache.mahout.cf.taste.similarity.precompute.SimilarItems;
import org.apache.mahout.cf.taste.similarity.precompute.SimilarItemsWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Precompute item similarities in parallel on a single machine. The recommender given to this class must use a
* DataModel that holds the interactions in memory (such as
* {@link org.apache.mahout.cf.taste.impl.model.GenericDataModel} or
* {@link org.apache.mahout.cf.taste.impl.model.file.FileDataModel}) as fast random access to the data is required
*/
public class MultithreadedBatchItemSimilarities extends BatchItemSimilarities {
private int batchSize;
private static final int DEFAULT_BATCH_SIZE = 100;
private static final Logger log = LoggerFactory.getLogger(MultithreadedBatchItemSimilarities.class);
/**
* @param recommender recommender to use
* @param similarItemsPerItem number of similar items to compute per item
*/
public MultithreadedBatchItemSimilarities(ItemBasedRecommender recommender, int similarItemsPerItem) {
this(recommender, similarItemsPerItem, DEFAULT_BATCH_SIZE);
}
/**
* @param recommender recommender to use
* @param similarItemsPerItem number of similar items to compute per item
* @param batchSize size of item batches sent to worker threads
*/
public MultithreadedBatchItemSimilarities(ItemBasedRecommender recommender, int similarItemsPerItem, int batchSize) {
super(recommender, similarItemsPerItem);
this.batchSize = batchSize;
}
@Override
public int computeItemSimilarities(int degreeOfParallelism, int maxDurationInHours, SimilarItemsWriter writer)
throws IOException {
ExecutorService executorService = Executors.newFixedThreadPool(degreeOfParallelism + 1);
Output output = null;
try {
writer.open();
DataModel dataModel = getRecommender().getDataModel();
BlockingQueue<long[]> itemsIDsInBatches = queueItemIDsInBatches(dataModel, batchSize, degreeOfParallelism);
BlockingQueue<List<SimilarItems>> results = new LinkedBlockingQueue<>();
AtomicInteger numActiveWorkers = new AtomicInteger(degreeOfParallelism);
for (int n = 0; n < degreeOfParallelism; n++) {
executorService.execute(new SimilarItemsWorker(n, itemsIDsInBatches, results, numActiveWorkers));
}
output = new Output(results, writer, numActiveWorkers);
executorService.execute(output);
} catch (Exception e) {
throw new IOException(e);
} finally {
executorService.shutdown();
try {
boolean succeeded = executorService.awaitTermination(maxDurationInHours, TimeUnit.HOURS);
if (!succeeded) {
throw new RuntimeException("Unable to complete the computation in " + maxDurationInHours + " hours!");
}
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
Closeables.close(writer, false);
}
return output.getNumSimilaritiesProcessed();
}
private static BlockingQueue<long[]> queueItemIDsInBatches(DataModel dataModel, int batchSize,
int degreeOfParallelism)
throws TasteException {
LongPrimitiveIterator itemIDs = dataModel.getItemIDs();
int numItems = dataModel.getNumItems();
BlockingQueue<long[]> itemIDBatches = new LinkedBlockingQueue<>((numItems / batchSize) + 1);
long[] batch = new long[batchSize];
int pos = 0;
while (itemIDs.hasNext()) {
batch[pos] = itemIDs.nextLong();
pos++;
if (pos == batchSize) {
itemIDBatches.add(batch.clone());
pos = 0;
}
}
if (pos > 0) {
long[] lastBatch = new long[pos];
System.arraycopy(batch, 0, lastBatch, 0, pos);
itemIDBatches.add(lastBatch);
}
if (itemIDBatches.size() < degreeOfParallelism) {
throw new IllegalStateException("Degree of parallelism [" + degreeOfParallelism + "] " +
" is larger than number of batches [" + itemIDBatches.size() +"].");
}
log.info("Queued {} items in {} batches", numItems, itemIDBatches.size());
return itemIDBatches;
}
private static class Output implements Runnable {
private final BlockingQueue<List<SimilarItems>> results;
private final SimilarItemsWriter writer;
private final AtomicInteger numActiveWorkers;
private int numSimilaritiesProcessed = 0;
Output(BlockingQueue<List<SimilarItems>> results, SimilarItemsWriter writer, AtomicInteger numActiveWorkers) {
this.results = results;
this.writer = writer;
this.numActiveWorkers = numActiveWorkers;
}
private int getNumSimilaritiesProcessed() {
return numSimilaritiesProcessed;
}
@Override
public void run() {
while (numActiveWorkers.get() != 0 || !results.isEmpty()) {
try {
List<SimilarItems> similarItemsOfABatch = results.poll(10, TimeUnit.MILLISECONDS);
if (similarItemsOfABatch != null) {
for (SimilarItems similarItems : similarItemsOfABatch) {
writer.add(similarItems);
numSimilaritiesProcessed += similarItems.numSimilarItems();
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}
private class SimilarItemsWorker implements Runnable {
private final int number;
private final BlockingQueue<long[]> itemIDBatches;
private final BlockingQueue<List<SimilarItems>> results;
private final AtomicInteger numActiveWorkers;
SimilarItemsWorker(int number, BlockingQueue<long[]> itemIDBatches, BlockingQueue<List<SimilarItems>> results,
AtomicInteger numActiveWorkers) {
this.number = number;
this.itemIDBatches = itemIDBatches;
this.results = results;
this.numActiveWorkers = numActiveWorkers;
}
@Override
public void run() {
int numBatchesProcessed = 0;
while (!itemIDBatches.isEmpty()) {
try {
long[] itemIDBatch = itemIDBatches.take();
List<SimilarItems> similarItemsOfBatch = new ArrayList<>(itemIDBatch.length);
for (long itemID : itemIDBatch) {
List<RecommendedItem> similarItems = getRecommender().mostSimilarItems(itemID, getSimilarItemsPerItem());
similarItemsOfBatch.add(new SimilarItems(itemID, similarItems));
}
results.offer(similarItemsOfBatch);
if (++numBatchesProcessed % 5 == 0) {
log.info("worker {} processed {} batches", number, numBatchesProcessed);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
log.info("worker {} processed {} batches. done.", number, numBatchesProcessed);
numActiveWorkers.decrementAndGet();
}
}
}