/* * Seldon -- open source prediction engine * ======================================= * * Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/) * * ******************************************************************************************** * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * ******************************************************************************************** */ package io.seldon.clustering.recommender.jdo; import io.seldon.db.jdbc.JDBCConnectionFactory; import io.seldon.general.Action; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; import java.util.Map; import java.util.TreeMap; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import org.apache.log4j.Logger; public class AsyncClusterCountStore implements Runnable { public static class ClusterCount { public int clusterId; public long itemId; public double weight; public ClusterCount(int clusterId, long itemId, double weight) { super(); this.clusterId = clusterId; this.itemId = itemId; this.weight = weight; } } private static Logger logger = Logger.getLogger(AsyncClusterCountStore.class.getName()); private String client; private int timeout; private LinkedBlockingQueue<ClusterCount> queue; private int batchSize; // batch size for sql statements private int maxDBRetries = 1; // max # of times to try sql statement on exception boolean keepRunning; double decay = 3600; Connection connection = null; PreparedStatement countPreparedStatement; private int countsAdded = 0; // actions added so far to sql statement private int countsAddedTotal = 0; // total actions added including counts for same cluster and item (will be different than counts added for cached version that uses db time) long lastSqlRunTime = 0; int badActions = 0; boolean useDBTime = true; TreeMap<Integer,TreeMap<Long,Double>> clusterCounts; public AsyncClusterCountStore(String client, int qTimeoutSecs, int batchSize, int maxQSize,int maxDBRetries,double decay,boolean useDBTime) { this.client = client; this.batchSize = batchSize; this.maxDBRetries = maxDBRetries; this.queue = new LinkedBlockingQueue<>(maxQSize); this.timeout = qTimeoutSecs; this.decay = decay; clusterCounts = new TreeMap<>(); this.useDBTime = useDBTime; logger.info("Async cluster count created for client "+client+" qTimeout:"+qTimeoutSecs+" batchSize:"+batchSize+" maxQSize:"+maxQSize+" maxDBRetries:"+maxDBRetries+" decay:"+decay+" use DB Time:"+useDBTime); } public void run() { keepRunning = true; this.lastSqlRunTime = System.currentTimeMillis(); while (true) { try { ClusterCount count = queue.poll(timeout, TimeUnit.SECONDS); if (count != null) { if (useDBTime) addCount(count); else addSQL(count); } long timeSinceLastSQLRun = (System.currentTimeMillis() - this.lastSqlRunTime)/1000; boolean runSQL = false; if ((count == null && countsAdded > 0)) { runSQL = true; logger.info("Run sql as timeout on poll and actionsAdded > 0"); } else if (countsAdded >= batchSize) { runSQL = true; logger.info("Run sql as batch size exceeded"); } else if (timeSinceLastSQLRun > timeout && countsAdded > 0) { runSQL = true; logger.info("Run sql as time between sql runs exceeded"); } if (runSQL) runSQL(); if (!keepRunning && count == null) { logger.warn("Asked to stop as keepRunning is false"); return; } } catch (InterruptedException e) { logger.error("Received interrupted exception - will stop",e); return; } catch (Exception e) { logger.error("Caught exception while running ", e); resetState(); logger.warn("\\-> Reset buffers."); } catch (Throwable t) { logger.error("Caught throwable while running ", t); resetState(); logger.warn("\\-> Reset buffers."); } } } private void resetState() { clusterCounts = new TreeMap<>(); clearSQLState(); countsAdded = 0; countsAddedTotal = 0; this.lastSqlRunTime = System.currentTimeMillis(); } private void clearSQLState() { try { if (connection != null) { try{connection.close();} catch( SQLException exception ) { logger.error("Unable to close connection",exception); } } if (countPreparedStatement != null) { try{countPreparedStatement.close();} catch( SQLException exception ) { logger.error("Unable to close action perpared statment",exception); } } } finally { connection = null; countPreparedStatement = null; } } private void executeBatch() throws SQLException { if (countsAdded > 0) { countPreparedStatement.executeBatch(); countPreparedStatement.close(); countsAdded = 0; connection.commit(); } } private void rollBack() { try { connection.rollback(); } catch( SQLException re ) { logger.error("Can't roll back transaction",re); } } private void runSQL() throws SQLException { int sqlAdded = 0; int localActionsAdded = this.countsAdded; if (useDBTime) { addSQLs(); sqlAdded = this.countsAdded; localActionsAdded = this.countsAddedTotal; } else { sqlAdded = this.countsAdded; localActionsAdded = this.countsAdded; } long t1 = System.currentTimeMillis(); boolean success = false; for (int i = 0; i < this.maxDBRetries; i++) { try { executeBatch(); success = true; break; } catch (SQLException e) { logger.error("Failed to run update ",e); } } if (!success) { rollBack(); localActionsAdded = 0; } resetState(); long t2 = System.currentTimeMillis(); //log q size float compression = localActionsAdded > 0 ? (1.0f-(sqlAdded/(float)localActionsAdded)) : 0.0f; logger.info("Asyn count for "+client+" at size:"+queue.size()+" actions added "+localActionsAdded+" unique sql inserts "+sqlAdded + " compression " + compression +" time to process:"+(t2-t1)); } /** * Allowed operations to fill in nulls in Action * @param action */ private void repairAction(Action action) { if (action.getTimes() == null) action.setTimes(1); } private void getConnectionIfNeeded() throws SQLException { if (connection == null) { connection = JDBCConnectionFactory.get().getConnection(client); connection.setAutoCommit( false ); } } private void addCount(ClusterCount count) { TreeMap<Long,Double> clusterMap = clusterCounts.get(count.clusterId); if (clusterMap == null) { clusterMap = new TreeMap<>(); clusterMap.put(count.itemId, count.weight); clusterCounts.put(count.clusterId, clusterMap); countsAdded++; } else { Double presValue = clusterMap.get(count.itemId); if (presValue == null) { clusterMap.put(count.itemId, count.weight); countsAdded++; } else clusterMap.put(count.itemId, presValue + count.weight); } this.countsAddedTotal++; } private synchronized int addSQLs() throws SQLException { getConnectionIfNeeded(); // Add action batch if (countPreparedStatement == null) countPreparedStatement = connection.prepareStatement("insert into cluster_counts values (?,?,?,unix_timestamp()) on duplicate key update count=?+exp(-(greatest(unix_timestamp()-t,0)/?))*count,t=unix_timestamp()"); int added = 0; for(Map.Entry<Integer,TreeMap<Long,Double>> m : clusterCounts.entrySet()) { for(Map.Entry<Long, Double> e : m.getValue().entrySet()) { countPreparedStatement.setInt(1, m.getKey()); countPreparedStatement.setLong(2, e.getKey()); countPreparedStatement.setDouble(3, e.getValue()); countPreparedStatement.setDouble(4, e.getValue()); countPreparedStatement.setDouble(5, decay); countPreparedStatement.addBatch(); added++; } } logger.info("Added "+added+" sql inserts to run "); clusterCounts = new TreeMap<>(); return added; } private synchronized void addActionBatch(ClusterCount count) throws SQLException { long time = System.currentTimeMillis(); countPreparedStatement.setInt(1, count.clusterId); countPreparedStatement.setLong(2, count.itemId); countPreparedStatement.setDouble(3, count.weight); countPreparedStatement.setLong(4, time); countPreparedStatement.setDouble(5, count.weight); countPreparedStatement.setLong(6, time); countPreparedStatement.setDouble(7, decay); countPreparedStatement.setLong(8, time); } private void addSQL(ClusterCount count) throws SQLException { getConnectionIfNeeded(); // Add action batch if (countPreparedStatement == null) countPreparedStatement = connection.prepareStatement("insert into cluster_counts values (?,?,?,?) on duplicate key update count=?+exp(-(greatest(?-t,0)/?))*count,t=?"); addActionBatch(count); countPreparedStatement.addBatch(); countsAdded++; } public void put(ClusterCount count) { queue.add(count); } public int getQSize() { return queue.size(); } public String getClient() { return client; } public void setClient(String client) { this.client = client; } public int getTimeout() { return timeout; } public void setTimeout(int timeout) { this.timeout = timeout; } public int getActionsAdded() { return countsAdded; } public int getBatchSize() { return batchSize; } public void setBatchSize(int batchSize) { this.batchSize = batchSize; } public int getMaxDBRetries() { return maxDBRetries; } public void setMaxDBRetries(int maxDBRetries) { this.maxDBRetries = maxDBRetries; } public boolean isKeepRunning() { return keepRunning; } public void setKeepRunning(boolean keepRunning) { this.keepRunning = keepRunning; } public int getBadActions() { return badActions; } public double getDecay() { return decay; } public synchronized void setDecay(double decay) { this.decay = decay; } }