package org.plista.kornakapi.core.cluster;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.dbcp.BasicDataSource;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.WeightedVector;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.WeightedThing;
import org.plista.kornakapi.core.config.StorageConfiguration;
import org.plista.kornakapi.core.storage.MySqlKMeansDataFilter;
import org.plista.kornakapi.core.storage.MySqlKMeansDataFilter.StreamingKMeansDataObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Holds the state of a streaming k-means item clustering: the current centroids, statistics derived
 * from them, and the user/item data used to turn an item id into a vector.
 *
 * <p>An item is represented as a sparse 0/1 vector with one position per user in {@code userids}
 * (in {@code userids.toArray()} order); a position is 1 when the item id is contained in that
 * user's item-id set from {@code userItemIds}.
 *
 * <p>NOTE(review): instances hold mutable state and an unsynchronized {@link HashMap} cache;
 * confirm callers do not share an instance across threads.
 */
public class StreamingKMeansClassifierModel {
  private static final Logger log = LoggerFactory.getLogger(StreamingKMeansClassifierModel.class);

  /** Initial capacity hint for the non-zero entries of each item vector. */
  private static final int INITIAL_VECTOR_CAPACITY = 300;

  /** Current centroids; {@code null} until {@link #updateCentroids} is called. */
  private UpdatableSearcher centroids = null;
  /** Largest centroid weight of the current centroids. */
  private double maxWeight = 0;
  /** Mean normalized volume of the current centroids. */
  private double meanVolume = 0;
  /** Item ids per user id. */
  private HashMap<Long, FastIDSet> userItemIds = null;
  /** Users spanning the vector space; their order defines the vector positions. */
  private FastIDSet userids = null;
  /** Cardinality of the item vectors. */
  private int dim = 0;
  /** Cache of item id to closest centroid; cleared whenever the centroids change. */
  private final HashMap<Long, WeightedThing<Vector>> itemID2Centroid =
      new HashMap<Long, WeightedThing<Vector>>();
  private final StorageConfiguration conf;
  /** All items known to the model. */
  private FastIDSet allItems;
  private final String label;
  private final BasicDataSource dataSource;

  /**
   * Creates an empty model; data is loaded later via {@link #getData()} or {@link #setData}.
   *
   * @param conf storage configuration passed to the data extractor
   * @param label label passed to the data extractor
   * @param dataSource data source passed to the data extractor
   */
  public StreamingKMeansClassifierModel(StorageConfiguration conf, String label,
      BasicDataSource dataSource) {
    this.conf = conf;
    this.label = label;
    this.dataSource = dataSource;
  }

  /**
   * Replaces the model's user/item data (user-to-items map, user set, dimension and item set).
   *
   * @param data the data to adopt
   */
  public void setData(StreamingKMeansDataObject data) {
    this.userItemIds = data.getUserItemIDs();
    this.userids = data.getUserIDs();
    this.dim = data.getDim();
    this.allItems = data.getAllItems();
  }

  /**
   * Installs a newly calculated set of centroids and recomputes the statistics derived from them.
   *
   * <p>{@link #getMaxWeight()} becomes the largest centroid weight in {@code centroids}, and
   * {@link #getMeanVolume()} becomes the mean over all centroids of
   * {@code (weight / maxWeight) * numNonZeroElements}. Both are recomputed from scratch on every
   * call. The item-to-centroid cache is cleared because its entries refer to the old centroids.
   *
   * @param centroids the new centroids; every element must be a {@link Centroid}
   */
  public void updateCentroids(UpdatableSearcher centroids) {
    this.centroids = centroids;
    if (log.isInfoEnabled()) {
      log.info("Computed " + centroids.size() + " clusters \n");
    }

    // First pass: the largest weight, used to normalize every centroid weight.
    double newMaxWeight = 0;
    Iterator<Vector> iter = centroids.iterator();
    while (iter.hasNext()) {
      Centroid cent = (Centroid) iter.next();
      newMaxWeight = Math.max(newMaxWeight, cent.getWeight());
    }
    this.maxWeight = newMaxWeight;

    // Second pass: accumulate the normalized volume of each centroid.
    double volumeSum = 0;
    int count = 0;
    iter = centroids.iterator();
    while (iter.hasNext()) {
      Centroid cent = (Centroid) iter.next();
      // Guard against 0/0 when every weight is zero.
      double normWeight = newMaxWeight > 0 ? cent.getWeight() / newMaxWeight : 0;
      double volume = normWeight * cent.getNumNonZeroElements();
      volumeSum += volume;
      count++;
      if (log.isInfoEnabled()) {
        log.info("NormWeight= [{}], l2norm= [{}], Number of Users= [{}] Volume= [{}]",
            new Object[] {normWeight, cent.norm(2), cent.getNumNonZeroElements(), volume});
      }
    }
    // An empty centroid set has mean volume 0 rather than NaN.
    this.meanVolume = count > 0 ? volumeSum / count : 0;
    this.itemID2Centroid.clear();
  }

  /** @return the current centroids, or {@code null} if none have been set */
  public UpdatableSearcher getCentroids() {
    return this.centroids;
  }

  /** @return the largest weight among the current centroids (0 if there are none) */
  public double getMaxWeight() {
    return this.maxWeight;
  }

  /** @return the mean normalized volume of the current centroids (0 if there are none) */
  public double getMeanVolume() {
    return this.meanVolume;
  }

  /**
   * Builds the vector of an item: position {@code i} is 1 when the {@code i}-th user of
   * {@code userids} has the item in its item-id set. Users without an entry in the user-to-items
   * map contribute nothing.
   *
   * @param itemId the item to vectorize
   * @return a {@link SequentialAccessSparseVector} of cardinality {@code dim}
   * @throws IOException ("Item unknown") if no user has the item
   */
  public SequentialAccessSparseVector createVector(long itemId) throws IOException {
    SequentialAccessSparseVector itemVector =
        new SequentialAccessSparseVector(dim, INITIAL_VECTOR_CAPACITY);
    int position = 0;
    boolean isRated = false;
    for (long userId : userids.toArray()) {
      FastIDSet itemIds = userItemIds.get(userId);
      // The map may lack users of the (older) user set, e.g. after getNewData().
      if (itemIds != null && itemIds.contains(itemId)) {
        itemVector.set(position, 1);
        isRated = true;
      }
      position++;
    }
    if (!isRated) {
      throw new IOException("Item unknown");
    }
    return itemVector;
  }

  /**
   * Returns the centroid closest to the given item, as found by
   * {@link UpdatableSearcher#searchFirst} on the item's vector. The result is cached until the
   * centroids are next replaced.
   *
   * @param itemID the item to classify
   * @return the {@link WeightedThing} returned by {@code searchFirst}
   * @throws IOException if the item cannot be vectorized (see {@link #createVector})
   * @throws IllegalStateException if no centroids have been set yet
   */
  public WeightedThing<Vector> getClossestCentroid(long itemID) throws IOException {
    WeightedThing<Vector> cached = itemID2Centroid.get(itemID);
    if (cached != null) {
      return cached;
    }
    if (centroids == null) {
      throw new IllegalStateException("No centroids available; call updateCentroids() first");
    }
    WeightedThing<Vector> cent = centroids.searchFirst(createVector(itemID), false);
    itemID2Centroid.put(itemID, cent);
    return cent;
  }

  /**
   * Fetches data via {@code MySqlKMeansDataFilter#getNewData} for the model's current users and
   * returns vectors for items not yet known to the model, in the OLD user coordinate system.
   *
   * <p>The user set and dimension are left unchanged. The user-to-items map is replaced by the
   * fetched one, and the known-item set is replaced by the fetched item set. New items that no user
   * of the old user set has cannot be represented. They are skipped, and because the item set is
   * replaced, they are not retried on later calls.
   *
   * @return one unit-weight {@link Centroid} per new item that could be vectorized
   * @throws IllegalStateException if no base data has been loaded yet
   */
  public List<Centroid> getNewData() {
    if (userids == null || allItems == null) {
      throw new IllegalStateException("No base data loaded; call getData() or setData() first");
    }
    MySqlKMeansDataFilter extractor = new MySqlKMeansDataFilter(conf, label, dataSource);
    StreamingKMeansDataObject data;
    try {
      data = extractor.getNewData(userids, dim);
    } finally {
      closeQuietly(extractor);
    }
    this.userItemIds = data.getUserItemIDs();

    // New items may exist only in the new user space, but the (unchanged) old coordinate system
    // is used here, so some new items may not be representable and are skipped.
    FastIDSet fetchedItems = data.getAllItems();
    ArrayList<Centroid> itemVectors = new ArrayList<Centroid>();
    int index = 0;
    for (long itemID : fetchedItems.toArray()) {
      if (allItems.contains(itemID)) {
        continue;
      }
      try {
        itemVectors.add(new Centroid(new WeightedVector(createVector(itemID), 1, index)));
        index++;
      } catch (IOException e) {
        if (log.isInfoEnabled()) {
          log.info("Skipping item [{}]: {}", itemID, e.getMessage());
        }
      }
    }
    if (log.isInfoEnabled()) {
      log.info("Adding [{}] new Items", itemVectors.size());
    }
    this.allItems = fetchedItems;
    return itemVectors;
  }

  /**
   * Loads the full data set via {@code MySqlKMeansDataFilter#getData}, adopts it as the model's
   * data (new user coordinate system, see {@link #setData}) and returns one vector per item.
   *
   * @return one unit-weight {@link Centroid} per item; items that cannot be vectorized are logged
   *     and skipped
   */
  public ArrayList<Centroid> getData() {
    MySqlKMeansDataFilter extractor = new MySqlKMeansDataFilter(conf, label, dataSource);
    StreamingKMeansDataObject data;
    try {
      data = extractor.getData();
    } finally {
      closeQuietly(extractor);
    }
    this.setData(data);

    ArrayList<Centroid> itemVectors = new ArrayList<Centroid>(allItems.size());
    int n = 0;
    for (long itemId : allItems.toArray()) {
      try {
        itemVectors.add(new Centroid(new WeightedVector(createVector(itemId), 1, n)));
        n++;
      } catch (IOException e) {
        log.warn("Could not create vector for item [{}]: {}", itemId, e.getMessage());
      }
    }
    if (log.isInfoEnabled()) {
      log.info("Done!");
    }
    return itemVectors;
  }

  /** Closes the extractor; a failure to close is logged rather than propagated. */
  private static void closeQuietly(MySqlKMeansDataFilter extractor) {
    try {
      extractor.close();
    } catch (IOException e) {
      log.warn("Could not close data extractor", e);
    }
  }
}