/* * Seldon -- open source prediction engine * ======================================= * * Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/) * * ******************************************************************************************** * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * ******************************************************************************************** */ package io.seldon.clustering.recommender; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.apache.log4j.Logger; import org.apache.mahout.cf.taste.impl.common.FastByIDMap; /** * Small memory footprint user cluster store. * <ul> * <li> Stores clusters as 2 bytes, so cluster ids must be between 0 and 65536. * <li> Stores weight as a byte and assumes weight is between 0 and 1, so stores it as 1 of 256 ranges between 0 and 1 * </ul> * @author rummble * */ public class MemoryUserClusterStore implements UserClusterStore { private static Logger logger = Logger.getLogger( MemoryUserClusterStore.class.getName() ); private static final int CLUSTER_NUM_BYTES = 3; private static final int SHORT_RANGE = 32767; private static final double WEIGHT_INCR = 1f/255f; FastByIDMap<byte[]> store; // use Mahout memory efficient map (28 bytes per entry) String client; long timestamp = 0; Map<Integer,Integer> clusterGroups; boolean loaded = false; ConcurrentHashMap<Long,byte[]> transientClusters; public MemoryUserClusterStore(String client,int entries) { logger.info("MemoryUserClusterStore for "+client+" of size "+entries); this.store = new FastByIDMap<>(entries); this.client = client; this.clusterGroups = new ConcurrentHashMap<>(); this.transientClusters = new ConcurrentHashMap<>(); } /** * Store clusters for user. This method is not thread safe. * @param userId * @param clusters */ public void store(long userId,List<UserCluster> clusters) { byte[] vals = new byte[clusters.size()*CLUSTER_NUM_BYTES]; int count = 0; for(UserCluster cluster : clusters) { if (timestamp == 0) timestamp = cluster.timeStamp; int clusterId = cluster.getCluster(); double weight = cluster.getWeight(); if (clusterId > 65536) { final String message = "ClusterId is too big: " + clusterId + " for user " + userId+" for client "+client; final ClusterRecommenderException recommenderException = new ClusterRecommenderException(message); logger.error(message, recommenderException); throw recommenderException; } if (weight < 0 || weight > 1) { final String message = "Bad weight: " + weight + " for user " + userId; final ClusterRecommenderException recommenderException = new ClusterRecommenderException(message); logger.error(message, recommenderException); throw recommenderException; } clusterId = clusterId - SHORT_RANGE; byte cId1 = (byte)(clusterId & 0xff); byte cId2 = (byte) ((clusterId >> 8) & 0xff); byte w = (byte) (Math.round((weight/WEIGHT_INCR)) - 128); int index = count * 3; vals[index] = cId1; vals[index+1] = cId2; vals[index+2] = w; if (!clusterGroups.containsKey(clusterId)) clusterGroups.put(clusterId, cluster.getGroup()); count++; } if(loaded) this.transientClusters.put(userId, vals); else store.put(userId, vals); } @Override public List<UserCluster> getClusters(long userId) { List<UserCluster> clusters = new ArrayList<>(); byte[] b = store.get(userId); if (b == null) b = this.transientClusters.get(userId); if (b != null) { for(int i=0;i<b.length;i=i+CLUSTER_NUM_BYTES) { int clusterId = ((short)(((b[i+1] & 0xff) << 8) | (b[i] & 0xff))) + SHORT_RANGE; double weight = (b[i+2]+128)*WEIGHT_INCR; if (logger.isDebugEnabled()) logger.debug("ClusterId:"+clusterId+" for user "+userId); int group = 0; if (clusterGroups.containsKey(clusterId)) group = clusterGroups.get(clusterId); UserCluster cluster = new UserCluster(userId,clusterId,weight,timestamp,group); clusters.add(cluster); } } return clusters; } public static void main(String[] args) { MemoryUserClusterStore m = new MemoryUserClusterStore("",1); ArrayList<UserCluster> clusters = new ArrayList<>(); clusters.add(new UserCluster(1L,2,0.75,1,1)); clusters.add(new UserCluster(1L,14561,0.75,1,1)); clusters.add(new UserCluster(1L,62421,0.75,1,1)); System.out.println("Clusters created:"); for(UserCluster cluster : clusters) System.out.println(cluster.toString()); m.store(1L, clusters); List<UserCluster> clustersGot = m.getClusters(1L); System.out.println("Clusters returned:"); for(UserCluster cluster : clustersGot) System.out.println(cluster.toString()); } @Override public List<UserCluster> getClusters() { return null; } @Override public int getNumUsersWithClusters() { return store.size(); } @Override public long getCurrentTimestamp() { return timestamp; } @Override public boolean needsExternalCaching() { return false; } public boolean isLoaded() { return loaded; } public void setLoaded(boolean loaded) { this.loaded = loaded; } }