/**
* Copyright 2012 plista GmbH (http://www.plista.com/)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/
package org.plista.kornakapi.core.training;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
import org.plista.kornakapi.core.config.ItembasedRecommenderConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * A multi-threaded trainer for item kNN recommenders: precomputes the most similar items of every
 * item with a {@link GenericItemBasedRecommender} and writes them to a file as
 * {@code itemID,similarItemID,similarity} lines.
 */
public class MultithreadedItembasedInMemoryTrainer extends AbstractTrainer {

  private final ItembasedRecommenderConfig conf;

  private static final Logger log = LoggerFactory.getLogger(MultithreadedItembasedInMemoryTrainer.class);

  public MultithreadedItembasedInMemoryTrainer(ItembasedRecommenderConfig conf) {
    super(conf);
    this.conf = conf;
  }

  /**
   * Computes the {@code conf.getSimilarItemsPerItem()} most similar items for every item in
   * {@code inmemoryData} and writes them to {@code targetFile}.
   *
   * <p>The item similarity is instantiated reflectively from {@code conf.getSimilarityClass()}, which
   * must have a public constructor taking a {@link DataModel}. Item IDs are split into batches of 100
   * and processed by {@code numProcessors} worker threads; one extra thread serializes their output
   * into the file.
   *
   * @param targetFile file the similarities are written to (overwritten if it exists)
   * @param inmemoryData data model to compute item similarities from
   * @param numProcessors number of worker threads computing similarities
   * @throws IOException if the similarity cannot be created, the file cannot be written, a worker or
   *     the writer fails, the computation does not finish within 6 hours, or the calling thread is
   *     interrupted (the interrupt flag is restored in that case)
   */
  @Override
  protected void doTrain(File targetFile, DataModel inmemoryData, int numProcessors) throws IOException {
    BufferedWriter writer = null;
    ExecutorService executorService = Executors.newFixedThreadPool(numProcessors + 1);
    try {
      ItemSimilarity similarity = (ItemSimilarity) Class.forName(conf.getSimilarityClass())
          .getConstructor(DataModel.class).newInstance(inmemoryData);
      ItemBasedRecommender trainer = new GenericItemBasedRecommender(inmemoryData, similarity);

      writer = new BufferedWriter(new FileWriter(targetFile));

      int batchSize = 100;
      int numItems = inmemoryData.getNumItems();
      List<long[]> itemIDBatches = queueItemIDsInBatches(inmemoryData.getItemIDs(), numItems, batchSize);
      log.info("Queued {} items in {} batches", numItems, itemIDBatches.size());

      BlockingQueue<long[]> itemsIDsToProcess = new LinkedBlockingQueue<long[]>(itemIDBatches);
      BlockingQueue<String> output = new LinkedBlockingQueue<String>();
      AtomicInteger numActiveWorkers = new AtomicInteger(numProcessors);

      // keep the futures so that a failing task is reported to the caller instead of only to the
      // pool thread's uncaught-exception handler
      List<Future<?>> tasks = Lists.newArrayListWithCapacity(numProcessors + 1);
      for (int n = 0; n < numProcessors; n++) {
        tasks.add(executorService.submit(new SimilarItemsWorker(n, itemsIDsToProcess, output, trainer,
            conf.getSimilarItemsPerItem(), numActiveWorkers)));
      }
      tasks.add(executorService.submit(new OutputWriter(output, writer, numActiveWorkers)));

      awaitCompletion(executorService, tasks);

      // close explicitly so that a failing final flush surfaces as an IOException; closeQuietly in the
      // finally block would swallow it
      writer.close();
    } catch (IOException e) {
      throw e;
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IOException("Interrupted while computing item similarities", e);
    } catch (Exception e) {
      throw new IOException(e);
    } finally {
      // no-op after a clean run; on failure it interrupts tasks that are still running
      executorService.shutdownNow();
      Closeables.closeQuietly(writer);
    }
  }

  /**
   * Shuts down {@code executorService} and waits up to 6 hours for all of its tasks to finish. It then
   * rethrows the failure of the first task (in submission order) that threw.
   *
   * @param executorService executor the tasks were submitted to
   * @param tasks futures of all submitted tasks
   * @throws IOException if the tasks did not finish in time or one of them threw an exception
   * @throws InterruptedException if the calling thread is interrupted while waiting
   */
  private static void awaitCompletion(ExecutorService executorService, List<Future<?>> tasks)
      throws IOException, InterruptedException {
    executorService.shutdown();
    if (!executorService.awaitTermination(6, TimeUnit.HOURS)) {
      throw new IOException("Item similarity computation did not finish within 6 hours");
    }
    for (Future<?> task : tasks) {
      try {
        // the executor has terminated, so get() returns immediately
        task.get();
      } catch (ExecutionException e) {
        throw new IOException("Item similarity computation failed", e.getCause());
      }
    }
  }

  /**
   * Splits the IDs produced by {@code itemIDs} into arrays of at most {@code batchSize} IDs, preserving
   * iteration order. Only the last batch may be shorter. Returns an empty list when the iterator is
   * empty.
   *
   * @param itemIDs iterator over the IDs to batch; it is consumed
   * @param numItems expected number of IDs, used only to presize the result list
   * @param batchSize maximum number of IDs per batch
   * @return the batches, each a fresh array
   */
  private static List<long[]> queueItemIDsInBatches(LongPrimitiveIterator itemIDs, int numItems,
      int batchSize) {
    List<long[]> itemIDBatches = Lists.newArrayListWithCapacity(numItems / batchSize + 1);
    long[] batch = new long[batchSize];
    int pos = 0;
    while (itemIDs.hasNext()) {
      batch[pos++] = itemIDs.nextLong();
      if (pos == batchSize) {
        // the buffer is reused for the next batch, so queue a copy
        itemIDBatches.add(batch.clone());
        pos = 0;
      }
    }
    // exactly the first pos entries belong to the unfinished last batch; the rest are stale
    if (pos > 0) {
      itemIDBatches.add(Arrays.copyOf(batch, pos));
    }
    return itemIDBatches;
  }

  /**
   * Writes the CSV chunks produced by the {@link SimilarItemsWorker}s to {@code writer} until every
   * worker has finished, then writes whatever output is still queued. It does not close the writer.
   */
  static class OutputWriter implements Runnable {

    private final BlockingQueue<String> output;
    private final BufferedWriter writer;
    private final AtomicInteger numActiveWorkers;

    OutputWriter(BlockingQueue<String> output, BufferedWriter writer, AtomicInteger numActiveWorkers) {
      this.output = output;
      this.writer = writer;
      this.numActiveWorkers = numActiveWorkers;
    }

    @Override
    public void run() {
      try {
        while (numActiveWorkers.get() != 0) {
          String lines = output.poll(10, TimeUnit.MILLISECONDS);
          if (lines != null) {
            writer.write(lines);
          }
        }
        // A worker may publish its last chunk after our final poll but before the counter reaches 0.
        // Workers offer before they decrement, so everything they produced is queued by now.
        String lines;
        while ((lines = output.poll()) != null) {
          writer.write(lines);
        }
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  }

  /**
   * Takes batches of item IDs from a shared queue until the queue is empty. For each ID it computes the
   * {@code howMany} most similar items, and it publishes one CSV chunk per batch to {@code output}.
   * It always decrements {@code numActiveWorkers} when it stops, whether it completes or fails.
   */
  static class SimilarItemsWorker implements Runnable {

    private final int number;
    private final BlockingQueue<long[]> itemIDBatches;
    private final BlockingQueue<String> output;
    private final ItemBasedRecommender trainer;
    private final int howMany;
    private final AtomicInteger numActiveWorkers;

    SimilarItemsWorker(int number, BlockingQueue<long[]> itemIDBatches, BlockingQueue<String> output,
        ItemBasedRecommender trainer, int howMany, AtomicInteger numActiveWorkers) {
      this.number = number;
      this.itemIDBatches = itemIDBatches;
      this.output = output;
      this.trainer = trainer;
      this.howMany = howMany;
      this.numActiveWorkers = numActiveWorkers;
    }

    @Override
    public void run() {
      int numBatchesProcessed = 0;
      try {
        long[] itemIDBatch;
        // poll() instead of isEmpty() + take(): two workers could both see a single remaining batch,
        // and the one losing the race would block in take() forever
        while (!Thread.currentThread().isInterrupted() && (itemIDBatch = itemIDBatches.poll()) != null) {
          output.offer(toCsv(itemIDBatch));
          if (++numBatchesProcessed % 5 == 0) {
            log.info("worker {} processed {} batches", number, numBatchesProcessed);
          }
        }
        log.info("worker {} processed {} batches. done.", number, numBatchesProcessed);
      } catch (Exception e) {
        throw new RuntimeException(e);
      } finally {
        // must run even on failure, otherwise OutputWriter would wait for this worker forever
        numActiveWorkers.decrementAndGet();
      }
    }

    /** Renders the most similar items of every ID in the batch as itemID,similarItemID,similarity lines. */
    private String toCsv(long[] itemIDBatch) throws Exception {
      StringBuilder lines = new StringBuilder();
      for (long itemID : itemIDBatch) {
        for (RecommendedItem similarItem : trainer.mostSimilarItems(itemID, howMany)) {
          lines.append(itemID).append(',').append(similarItem.getItemID())
              .append(',').append(similarItem.getValue()).append('\n');
        }
      }
      return lines.toString();
    }
  }
}