/* * Seldon -- open source prediction engine * ======================================= * * Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/) * * ******************************************************************************************** * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * ******************************************************************************************** */ package io.seldon.clustering.recommender.jdo; import io.seldon.clustering.recommender.ClusterCountNoImplementationException; import io.seldon.clustering.recommender.ClusterCountStore; import io.seldon.db.jdo.ClientPersistable; import io.seldon.db.jdo.DatabaseException; import io.seldon.db.jdo.Transaction; import io.seldon.db.jdo.TransactionPeer; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Map; import java.util.Set; import javax.jdo.PersistenceManager; import javax.jdo.Query; import org.apache.commons.lang.StringUtils; import org.apache.log4j.Logger; public class JdoClusterCountStore extends ClientPersistable implements ClusterCountStore { private static Logger logger = Logger.getLogger(JdoClusterCountStore.class.getName()); double alpha = 3600; AsyncClusterCountFactory clusterCountFactory; public JdoClusterCountStore(String client,AsyncClusterCountFactory clusterCountFactory) { super(client); this.clusterCountFactory = clusterCountFactory; } @Override public void add(final int clusterId, final long itemId, final double weight, long clusterTimestamp) { AsyncClusterCountStore asyncStore = clusterCountFactory.get(this.clientName); if (asyncStore != null) { asyncStore.put(new AsyncClusterCountStore.ClusterCount(clusterId,itemId,weight)); } else { final PersistenceManager pm = getPM(); try { TransactionPeer.runTransaction(new Transaction(pm) { public void process() { Query query = pm.newQuery("javax.jdo.query.SQL", "insert into cluster_counts values (?,?,?,unix_timestamp()) on duplicate key update count=?+exp(-(unix_timestamp()-t)/?)*count,t=unix_timestamp();"); ArrayList<Object> args = new ArrayList<>(); args.add(clusterId); args.add(itemId); args.add(weight); args.add(weight); args.add(alpha); query.executeWithArray(args.toArray()); } }); } catch (DatabaseException e) { logger.error("Failed to Add count", e); } } } /** * timestamp is ignore. */ @Override public void add(final int clusterId, final long itemId,final double weight,long timestamp,final long time) { AsyncClusterCountStore asyncStore = clusterCountFactory.get(this.clientName); if (asyncStore != null) { asyncStore.put(new AsyncClusterCountStore.ClusterCount(clusterId,itemId,weight)); } else { final PersistenceManager pm = getPM(); try { TransactionPeer.runTransaction(new Transaction(pm) { public void process() { Query query = pm.newQuery( "javax.jdo.query.SQL", "insert into cluster_counts values (?,?,?,unix_timestamp()) on duplicate key update count=?+exp(-(greatest(unix_timestamp()-t,0)/?))*count,t=unix_timestamp();"); ArrayList<Object> args = new ArrayList<>(); args.add(clusterId); args.add(itemId); args.add(weight); args.add(weight); args.add(alpha); query.executeWithArray(args.toArray()); }}); } catch (DatabaseException e) { logger.error("Failed to Add count", e); } } } /** * timestamp and time is ignore for db counts - the db value for these is used. They are assumed to be up-todate with clusters. */ @Override public double getCount(int clusterId, long itemId,long timestamp) { final PersistenceManager pm = getPM(); Query query = pm.newQuery( "javax.jdo.query.SQL", "select count from cluster_counts where id=? and item_id=?" ); query.setResultClass(Double.class); query.setUnique(true); Double count = (Double) query.execute(clusterId, itemId); if (count != null) return count; else return 0D; } @Override public void setAlpha(double alpha) { this.alpha = alpha; } @Override public boolean needsExternalCaching() { return true; } /** * timestamp and time is ignore for db counts - the db value for these is used. They are assumed to be up-to-date with clusters. */ @Override public Map<Long, Double> getTopCounts(int clusterId, long timestamp, int limit, double decay) { final PersistenceManager pm = getPM(); Map<Long,Double> map = new HashMap<>(); Query query = pm.newQuery( "javax.jdo.query.SQL", "select item_id,exp(-(greatest(unix_timestamp()-t,0)/?))*count as decayedCount from cluster_counts where id=? order by decayedCount desc limit "+limit ); Collection<Object[]> res = (Collection<Object[]>) query.execute(decay,clusterId); for(Object[] r : res) { Long itemId = (Long) r[0]; Double count = (Double) r[1]; map.put(itemId, count); } return map; } //TODO - need to use decay/alpha //ignore time use db time @Override public Map<Long, Double> getTopCounts(int limit, double decay) throws ClusterCountNoImplementationException { final PersistenceManager pm = getPM(); Map<Long,Double> map = new HashMap<>(); Query query = pm.newQuery( "javax.jdo.query.SQL", "select item_id,sum(exp(-(greatest(unix_timestamp()-t,0)/?))*count) as decayedSumCount from cluster_counts group by item_id order by decayedSumCount desc limit "+limit ); Collection<Object[]> res = (Collection<Object[]>) query.execute(decay); for(Object[] r : res) { Long itemId = (Long) r[0]; Double count = (Double) r[1]; map.put(itemId, count); } return map; } @Override public Map<Long, Double> getTopCountsByDimension(int clusterId, Set<Integer> dimensions, long timestamp, int limit, double decay) throws ClusterCountNoImplementationException { final PersistenceManager pm = getPM(); Map<Long,Double> map = new HashMap<>(); String dimensionsStr = StringUtils.join(dimensions, ","); Query query = pm.newQuery( "javax.jdo.query.SQL", "select item_id,exp(-(greatest(unix_timestamp()-t,0)/?))*count as decayedCount from cluster_counts natural join item_map_enum natural join dimension where id = ? and dim_id in ("+dimensionsStr+") order by decayedCount desc limit "+limit ); Collection<Object[]> res = (Collection<Object[]>) query.execute(decay,clusterId); for(Object[] r : res) { Long itemId = (Long) r[0]; Double count = (Double) r[1]; map.put(itemId, count); } return map; } @Override public Map<Long, Double> getTopCountsByTwoDimensions(int clusterId, Set<Integer> dimensions, int dimension2, long timestamp, int limit, double decay) throws ClusterCountNoImplementationException { final PersistenceManager pm = getPM(); Map<Long,Double> map = new HashMap<>(); String dimensionsStr = StringUtils.join(dimensions, ","); Query query = pm.newQuery( "javax.jdo.query.SQL", "select c.item_id,exp(-(greatest(unix_timestamp()-t,0)/?))*count as decayedCount from cluster_counts c natural join item_map_enum natural join dimension d1 join item_map_enum ime2 on (c.item_id=ime2.item_id) join dimension d2 on (d2.attr_id=ime2.attr_id and ime2.value_id=d2.value_id) where id = ? and d1.dim_id in ("+dimensionsStr+") and d2.dim_id=? order by decayedCount desc limit "+limit); Collection<Object[]> res = (Collection<Object[]>) query.execute(decay,clusterId,dimension2); for(Object[] r : res) { Long itemId = (Long) r[0]; Double count = (Double) r[1]; map.put(itemId, count); } return map; } @Override public Map<Long, Double> getTopSignificantCountsByDimension(int clusterId, Set<Integer> dimensions, long timestamp, int limit, double decay) throws ClusterCountNoImplementationException { final PersistenceManager pm = getPM(); Map<Long,Double> map = new HashMap<>(); String dimensionsStr = StringUtils.join(dimensions, ","); Query query = pm.newQuery( "javax.jdo.query.SQL", "select item_id,r.v*r.count as score from (select item_id,(count/sl-s/sg)/greatest(count/sl,s/sg) as v,count from (select exp(-(greatest(unix_timestamp()-c.t,0)/?))*c.count as count,cit.total as s,sl,cct.total as sg,c.item_id from cluster_counts c join (select sum(exp(-(greatest(unix_timestamp()-c.t,0)/?))) sl from cluster_counts c where id=?) t1 join cluster_counts_total cct join cluster_counts_item_total cit on (c.item_id=cit.item_id) where id=?) r1) r natural join item_map_enum natural join dimension where dim_id in ("+dimensionsStr+") order by score desc limit "+limit ); ArrayList<Object> args = new ArrayList<>(); args.add(decay); args.add(decay); args.add(clusterId); args.add(clusterId); Collection<Object[]> res = (Collection<Object[]>) query.executeWithArray(args.toArray()); for(Object[] r : res) { Long itemId = (Long) r[0]; Double count = (Double) r[1]; map.put(itemId, count); } return map; } @Override public Map<Long, Double> getTopCountsByDimension(Set<Integer> dimensions, int limit, double decay) throws ClusterCountNoImplementationException { final PersistenceManager pm = getPM(); Map<Long,Double> map = new HashMap<>(); String dimensionsStr = StringUtils.join(dimensions, ","); Query query = pm.newQuery( "javax.jdo.query.SQL", "select item_id,sum(exp(-(greatest(unix_timestamp()-t,0)/?))*count) as decayedSumCount from cluster_counts natural join item_map_enum natural join dimension where dim_id in ("+dimensionsStr+") group by item_id order by decayedSumCount desc limit "+limit ); Collection<Object[]> res = (Collection<Object[]>) query.execute(decay); for(Object[] r : res) { Long itemId = (Long) r[0]; Double count = (Double) r[1]; map.put(itemId, count); } return map; } @Override public Map<Long, Double> getTopCountsByTwoDimensions(Set<Integer> dimensions, int dimension2, int limit, double decay) throws ClusterCountNoImplementationException { final PersistenceManager pm = getPM(); Map<Long,Double> map = new HashMap<>(); String dimensionsStr = StringUtils.join(dimensions, ","); Query query = pm.newQuery( "javax.jdo.query.SQL", "select c.item_id,sum(exp(-(greatest(unix_timestamp()-t,0)/?))*count) as decayedCount from cluster_counts c natural join item_map_enum ime1 join dimension d1 on (d1.attr_id=ime1.attr_id and ime1.value_id=d1.value_id) join item_map_enum ime2 on (c.item_id=ime2.item_id) join dimension d2 on (d2.attr_id=ime2.attr_id and ime2.value_id=d2.value_id) where d1.dim_id in ("+dimensionsStr+") and d2.dim_id = ? group by item_id order by decayedcount desc limit "+limit ); ArrayList<Object> args = new ArrayList<>(); args.add(decay); args.add(dimension2); Collection<Object[]> res = (Collection<Object[]>) query.executeWithArray(args.toArray()); for(Object[] r : res) { Long itemId = (Long) r[0]; Double count = (Double) r[1]; map.put(itemId, count); } return map; } @Override public Map<Long, Double> getTopCountsByTagAndDimension(String tag, int tagAttrId, Set<Integer> dimensions, int limit, double decay) throws ClusterCountNoImplementationException { final PersistenceManager pm = getPM(); Map<Long,Double> map = new HashMap<>(); String dimensionsStr = StringUtils.join(dimensions, ","); Query query = pm.newQuery( "javax.jdo.query.SQL", "select cluster_counts.item_id,sum(exp(-(greatest(unix_timestamp()-t,0)/?))*count) as decayedSumCount from cluster_counts natural join item_map_enum natural join dimension join item_map_varchar on (cluster_counts.item_id=item_map_varchar.item_id and item_map_varchar.attr_id=?) where dim_id in ("+dimensionsStr+") and value regexp \"(^|,)[ ]*"+tag+"[ ]*(,|$)\" group by item_id order by decayedSumCount desc limit "+limit ); Collection<Object[]> res = (Collection<Object[]>) query.execute(decay,tagAttrId); for(Object[] r : res) { Long itemId = (Long) r[0]; Double count = (Double) r[1]; map.put(itemId, count); } logger.info("getTopCountsByTagAndTwoDimension "+tag+" tagATtrId "+tagAttrId+" dimension "+dimensionsStr+" decay "+decay+ " limit "+limit+ " results "+map.size()); return map; } @Override public Map<Long, Double> getTopCountsByTag(String tag, int tagAttrId, int limit, double decay) throws ClusterCountNoImplementationException { final PersistenceManager pm = getPM(); Map<Long,Double> map = new HashMap<>(); Query query = pm.newQuery( "javax.jdo.query.SQL", "select cluster_counts.item_id,sum(exp(-(greatest(unix_timestamp()-t,0)/?))*count) as decayedSumCount from cluster_counts join item_map_varchar on (cluster_counts.item_id=item_map_varchar.item_id and item_map_varchar.attr_id=?) where value regexp \"(^|,)[ ]*"+tag+"[ ]*(,|$)\" group by item_id order by decayedSumCount desc limit "+limit ); Collection<Object[]> res = (Collection<Object[]>) query.execute(decay,tagAttrId); for(Object[] r : res) { Long itemId = (Long) r[0]; Double count = (Double) r[1]; map.put(itemId, count); } return map; } @Override public Map<Long, Double> getTopCountsByTagAndTwoDimensions(String tag, int tagAttrId, Set<Integer> dimensions, int dimension2, int limit, double decay) throws ClusterCountNoImplementationException { final PersistenceManager pm = getPM(); Map<Long,Double> map = new HashMap<>(); String dimensionsStr = StringUtils.join(dimensions, ","); Query query = pm.newQuery( "javax.jdo.query.SQL", "select c.item_id,sum(exp(-(greatest(unix_timestamp()-t,0)/?))*count) as decayedCount from cluster_counts c natural join item_map_enum ime1 join dimension d1 on (d1.attr_id=ime1.attr_id and ime1.value_id=d1.value_id) join item_map_enum ime2 on (c.item_id=ime2.item_id) join dimension d2 on (d2.attr_id=ime2.attr_id and ime2.value_id=d2.value_id) join item_map_varchar on (c.item_id=item_map_varchar.item_id and item_map_varchar.attr_id=?) where d1.dim_id in ("+dimensionsStr+") and d2.dim_id = ? and value regexp \"(^|,)[ ]*"+tag+"[ ]*(,|$)\" group by item_id order by decayedcount desc limit "+limit ); ArrayList<Object> args = new ArrayList<>(); args.add(decay); args.add(tagAttrId); args.add(dimension2); Collection<Object[]> res = (Collection<Object[]>) query.executeWithArray(args.toArray()); for(Object[] r : res) { Long itemId = (Long) r[0]; Double count = (Double) r[1]; map.put(itemId, count); } logger.info("getTopCountsByTagAndTwoDimensions "+tag+" tagATtrId "+tagAttrId+" dimension "+dimensionsStr+" dimension2 "+dimension2+" decay "+decay+ " limit "+limit+ " results "+map.size()); return map; } }