package edu.brown.markov.containers; import java.io.BufferedReader; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.SortedSet; import java.util.TreeSet; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.apache.log4j.Logger; import org.json.JSONException; import org.json.JSONObject; import org.json.JSONStringer; import org.voltdb.CatalogContext; import org.voltdb.catalog.Database; import org.voltdb.catalog.Procedure; import org.voltdb.utils.Pair; import edu.brown.catalog.CatalogUtil; import edu.brown.hashing.AbstractHasher; import edu.brown.logging.LoggerUtil; import edu.brown.logging.LoggerUtil.LoggerBoolean; import edu.brown.markov.MarkovGraph; import edu.brown.markov.MarkovUtil; import edu.brown.statistics.ObjectHistogram; import edu.brown.utils.ClassUtil; import edu.brown.utils.CollectionUtil; import edu.brown.utils.FileUtil; import edu.brown.utils.PartitionEstimator; import edu.brown.utils.ThreadUtil; import edu.brown.workload.TransactionTrace; import edu.brown.workload.Workload; public abstract class MarkovGraphsContainerUtil { private static final Logger LOG = Logger.getLogger(MarkovGraphsContainerUtil.class); private static final LoggerBoolean debug = new LoggerBoolean(); private static final LoggerBoolean trace = new LoggerBoolean(); static { LoggerUtil.attachObserver(LOG, debug, trace); } // ---------------------------------------------------------------------------- // INSTANTATION METHODS // ---------------------------------------------------------------------------- /** * Create the MarkovGraphsContainers for the given workload * @param catalog_db * @param workload * @param p_estimator * @param containerClass * @return * @throws Exception */ public static <T extends MarkovGraphsContainer> Map<Integer, MarkovGraphsContainer> createMarkovGraphsContainers(final CatalogContext catalogContext, final Workload workload, final PartitionEstimator p_estimator, final Class<T> containerClass) throws Exception { final Map<Integer, MarkovGraphsContainer> markovs_map = new ConcurrentHashMap<Integer, MarkovGraphsContainer>(); return createMarkovGraphsContainers(catalogContext, workload, p_estimator, containerClass, markovs_map); } /** * Create the MarkovGraphsContainers for the given workload. * The markovs_map could contain an existing collection MarkovGraphsContainers * @param catalog_db * @param workload * @param p_estimator * @param containerClass * @param markovs_map * @return * @throws Exception */ @SuppressWarnings("unchecked") public static <T extends MarkovGraphsContainer> Map<Integer, MarkovGraphsContainer> createMarkovGraphsContainers(final CatalogContext catalogContext, final Workload workload, final PartitionEstimator p_estimator, final Class<T> containerClass, final Map<Integer, MarkovGraphsContainer> markovs_map) throws Exception { final String className = containerClass.getSimpleName(); final Database catalog_db = catalogContext.database; final List<Runnable> runnables = new ArrayList<Runnable>(); final Set<Procedure> procedures = workload.getProcedures(catalog_db); final ObjectHistogram<Procedure> proc_h = new ObjectHistogram<Procedure>(); final int num_transactions = workload.getTransactionCount(); final int marker = Math.max(1, (int)(num_transactions * 0.10)); final AtomicInteger finished_ctr = new AtomicInteger(0); final AtomicInteger txn_ctr = new AtomicInteger(0); final int num_threads = ThreadUtil.getMaxGlobalThreads(); final Constructor<T> constructor = ClassUtil.getConstructor(containerClass, new Class<?>[]{Collection.class}); final boolean is_global = containerClass.equals(GlobalMarkovGraphsContainer.class); final List<Thread> processing_threads = new ArrayList<Thread>(); final LinkedBlockingDeque<Pair<Integer, TransactionTrace>> queues[] = (LinkedBlockingDeque<Pair<Integer, TransactionTrace>>[])new LinkedBlockingDeque<?>[num_threads]; for (int i = 0; i < num_threads; i++) { queues[i] = new LinkedBlockingDeque<Pair<Integer, TransactionTrace>>(); } // FOR // QUEUING THREAD final AtomicBoolean queued_all = new AtomicBoolean(false); runnables.add(new Runnable() { @Override public void run() { List<TransactionTrace> all_txns = new ArrayList<TransactionTrace>(workload.getTransactions()); Collections.shuffle(all_txns); int ctr = 0; for (TransactionTrace txn_trace : all_txns) { // Make sure it goes to the right base partition Integer partition = null; try { partition = p_estimator.getBasePartition(txn_trace); } catch (Exception ex) { throw new RuntimeException(ex); } assert(partition != null) : "Failed to get base partition for " + txn_trace + "\n" + txn_trace.debug(catalog_db); queues[ctr % num_threads].add(Pair.of(partition, txn_trace)); if (++ctr % marker == 0 && debug.val) LOG.debug(String.format("Queued %d/%d transactions", ctr, num_transactions)); } // FOR queued_all.set(true); // Poke all our threads just in case they finished for (Thread t : processing_threads) { if (t != null) t.interrupt(); } // FOR } }); // PROCESSING THREADS for (int i = 0; i < num_threads; i++) { final int thread_id = i; runnables.add(new Runnable() { @Override public void run() { Thread self = Thread.currentThread(); processing_threads.add(self); MarkovGraphsContainer markovs = null; Pair<Integer, TransactionTrace> pair = null; while (true) { try { if (queued_all.get()) { pair = queues[thread_id].poll(); } else { pair = queues[thread_id].take(); // Steal work if (pair == null) { for (int i = 0; i < num_threads; i++) { if (i == thread_id) continue; pair = queues[i].take(); if (pair != null) break; } // FOR } } } catch (InterruptedException ex) { continue; } if (pair == null) break; int partition = pair.getFirst(); TransactionTrace txn_trace = pair.getSecond(); Procedure catalog_proc = txn_trace.getCatalogItem(catalog_db); long txn_id = txn_trace.getTransactionId(); try { int map_id = (is_global ? MarkovUtil.GLOBAL_MARKOV_CONTAINER_ID : partition); Object params[] = txn_trace.getParams(); markovs = markovs_map.get(map_id); if (markovs == null) { synchronized (markovs_map) { markovs = markovs_map.get(map_id); if (markovs == null) { markovs = constructor.newInstance(new Object[]{procedures}); markovs.setHasher(p_estimator.getHasher()); markovs_map.put(map_id, markovs); } } // SYNCH } MarkovGraph markov = markovs.getFromParams(txn_id, map_id, params, catalog_proc); synchronized (markov) { markov.processTransaction(txn_trace, p_estimator); } // SYNCH } catch (Exception ex) { LOG.fatal("Failed to process " + txn_trace, ex); throw new RuntimeException(ex); } proc_h.put(catalog_proc); int global_ctr = txn_ctr.incrementAndGet(); if (debug.val && global_ctr % marker == 0) { LOG.debug(String.format("Processed %d/%d transactions", global_ctr, num_transactions)); } } // FOR LOG.info(String.format("Processing thread finished creating %s [%d/%d]", className, finished_ctr.incrementAndGet(), num_threads)); } }); } // FOR LOG.info(String.format("Generating %s for %d partitions using %d threads", className, catalogContext.numberOfPartitions, num_threads)); ThreadUtil.runGlobalPool(runnables); proc_h.setDebugLabels(CatalogUtil.getDisplayNameMapping(proc_h.values())); LOG.info("Procedure Histogram:\n" + proc_h); MarkovGraphsContainerUtil.calculateProbabilities(catalogContext, markovs_map); return (markovs_map); } /** * Construct all of the Markov graphs for a workload+catalog split by the txn's base partition * @param catalog_db * @param workload * @param p_estimator * @return */ public static MarkovGraphsContainer createBasePartitionMarkovGraphsContainer(final CatalogContext catalogContext, final Workload workload, final PartitionEstimator p_estimator) { assert(workload != null); assert(p_estimator != null); Map<Integer, MarkovGraphsContainer> markovs_map = null; try { markovs_map = createMarkovGraphsContainers(catalogContext, workload, p_estimator, MarkovGraphsContainer.class); } catch (Exception ex) { throw new RuntimeException(ex); } assert(markovs_map != null); // Combine in a single Container final MarkovGraphsContainer combined = new MarkovGraphsContainer(); for (Integer p : markovs_map.keySet()) { combined.copy(markovs_map.get(p)); } // FOR return (combined); } /** * Combine multiple MarkovGraphsContainer files into a single file * @param markovs * @param output_path */ public static void combine(Map<Integer, File> markovs, File file, Database catalog_db) { // Sort the list of partitions so we always iterate over them in the same order SortedSet<Integer> sorted = new TreeSet<Integer>(markovs.keySet()); // We want all the procedures Collection<Procedure> procedures = CollectionUtil.addAll(new HashSet<Procedure>(), catalog_db.getProcedures()); try { FileOutputStream out = new FileOutputStream(file); // First construct an index that allows us to quickly find the partitions that we want JSONStringer stringer = (JSONStringer)(new JSONStringer().object()); int offset = 1; for (Integer partition : sorted) { stringer.key(partition.toString()).value(offset++); } // FOR out.write((stringer.endObject().toString() + "\n").getBytes()); // Now Loop through each file individually so that we only have to load one into memory at a time for (Integer partition : sorted) { File in = markovs.get(partition); try { JSONObject json_object = new JSONObject(FileUtil.readFile(in)); MarkovGraphsContainer markov = MarkovGraphsContainerUtil.createMarkovGraphsContainer(json_object, procedures, catalog_db); markov.load(in, catalog_db); stringer = (JSONStringer)new JSONStringer().object(); stringer.key(partition.toString()).object(); markov.toJSON(stringer); stringer.endObject().endObject(); out.write((stringer.toString() + "\n").getBytes()); } catch (Exception ex) { throw new Exception(String.format("Failed to copy MarkovGraphsContainer for partition %d from '%s'", partition, in), ex); } } // FOR out.close(); } catch (Exception ex) { String msg = String.format("Failed to combine multiple %s into file '%s'", MarkovGraphsContainer.class.getSimpleName(), file.getAbsolutePath()); LOG.error(msg, ex); throw new RuntimeException(msg, ex); } LOG.info(String.format("Combined %d %s into file '%s'", markovs.size(), MarkovGraphsContainer.class.getSimpleName(), file.getAbsolutePath())); } /** * * @param json_object * @param procedures * @param catalog_db * @return * @throws JSONException */ public static MarkovGraphsContainer createMarkovGraphsContainer(JSONObject json_object, Collection<Procedure> procedures, Database catalog_db) throws JSONException { // We should be able to get the classname of the container from JSON String className = MarkovGraphsContainer.class.getCanonicalName(); if (json_object.has(MarkovGraphsContainer.Members.CLASSNAME.name())) { className = json_object.getString(MarkovGraphsContainer.Members.CLASSNAME.name()); } MarkovGraphsContainer markovs = ClassUtil.newInstance(className, new Object[]{procedures}, new Class<?>[]{Collection.class}); assert(markovs != null); if (debug.val) LOG.debug(String.format("Instantiated new %s object", markovs.getClass().getSimpleName())); markovs.fromJSON(json_object, catalog_db); return (markovs); } // ---------------------------------------------------------------------------- // SAVE TO FILE // ---------------------------------------------------------------------------- /** * For the given MarkovGraphContainer, serialize them out to a file * @param markovs * @param output_path * @throws Exception */ public static void save(Map<Integer, ? extends MarkovGraphsContainer> markovs, File output_path) { final String className = CollectionUtil.first(markovs.values()).getClass().getSimpleName(); // Sort the list of partitions so we always iterate over them in the same order SortedSet<Integer> sorted = new TreeSet<Integer>(markovs.keySet()); int graphs_ctr = 0; try { FileOutputStream out = new FileOutputStream(output_path); // First construct an index that allows us to quickly find the partitions that we want JSONStringer stringer = (JSONStringer)(new JSONStringer().object()); int offset = 1; for (Integer partition : sorted) { stringer.key(Integer.toString(partition)).value(offset++); } // FOR out.write((stringer.endObject().toString() + "\n").getBytes()); // Now roll through each id and create a single JSONObject on each line for (Integer partition : sorted) { MarkovGraphsContainer markov = markovs.get(partition); assert(markov != null) : "Null MarkovGraphsContainer for partition #" + partition; graphs_ctr += markov.totalSize(); stringer = (JSONStringer)new JSONStringer().object(); stringer.key(partition.toString()).object(); markov.toJSON(stringer); stringer.endObject().endObject(); out.write((stringer.toString() + "\n").getBytes()); } // FOR out.close(); } catch (Exception ex) { LOG.error("Failed to serialize the " + className + " file '" + output_path + "'", ex); throw new RuntimeException(ex); } LOG.info(String.format("Wrote out %d graphs in %s to '%s'", graphs_ctr, className, output_path)); } // ---------------------------------------------------------------------------- // LOAD METHODS // ---------------------------------------------------------------------------- public static Map<Integer, MarkovGraphsContainer> load(CatalogContext catalogContext, File input_path) throws Exception { return (MarkovGraphsContainerUtil.load(catalogContext, input_path, null, null)); } public static Map<Integer, MarkovGraphsContainer> loadIds(CatalogContext catalogContext, File input_path, Collection<Integer> ids) throws Exception { return (MarkovGraphsContainerUtil.load(catalogContext, input_path, null, ids)); } public static Map<Integer, MarkovGraphsContainer> loadProcedures(CatalogContext catalogContext, File input_path, Collection<Procedure> procedures) throws Exception { return (MarkovGraphsContainerUtil.load(catalogContext, input_path, procedures, null)); } /** * * @param catalog_db * @param input_path * @param ids * @return * @throws Exception */ public static Map<Integer, MarkovGraphsContainer> load(final CatalogContext catalogContext, final File file, final Collection<Procedure> procedures, final Collection<Integer> ids) throws Exception { final Map<Integer, MarkovGraphsContainer> ret = new HashMap<Integer, MarkovGraphsContainer>(); LOG.info(String.format("Loading in MarkovGraphContainers from '%s' [procedures=%s, ids=%s]", file.getName(), (procedures == null ? "*ALL*" : CatalogUtil.debug(procedures)), (ids == null ? "*ALL*" : ids))); try { // File Format: One PartitionId per line, each with its own MarkovGraphsContainer BufferedReader in = FileUtil.getReader(file); // Line# -> Partition# final Map<Integer, Integer> line_xref = new HashMap<Integer, Integer>(); int line_ctr = 0; while (in.ready()) { final String line = in.readLine(); // If this is the first line, then it is our index if (line_ctr == 0) { // Construct our line->partition mapping JSONObject json_object = new JSONObject(line); for (String key : CollectionUtil.iterable(json_object.keys())) { Integer partition = Integer.valueOf(key); // We want the MarkovGraphContainer pointed to by this line if // (1) This partition is the same as our GLOBAL_MARKOV_CONTAINER_ID, which means that // there isn't going to be partition-specific graphs. There should only be one entry // in this file and we're always going to want to load it // (2) They didn't pass us any ids, so we'll take everything we see // (3) They did pass us ids, so check whether its included in the set if (partition.equals(MarkovUtil.GLOBAL_MARKOV_CONTAINER_ID) || ids == null || ids.contains(partition)) { Integer offset = json_object.getInt(key); line_xref.put(offset, partition); } } // FOR if (debug.val) LOG.debug(String.format("Loading %d MarkovGraphsContainers", line_xref.size())); // Otherwise check whether this is a line number that we care about } else if (line_xref.containsKey(Integer.valueOf(line_ctr))) { Integer partition = line_xref.remove(Integer.valueOf(line_ctr)); JSONObject json_object = new JSONObject(line).getJSONObject(partition.toString()); MarkovGraphsContainer markovs = createMarkovGraphsContainer(json_object, procedures, catalogContext.database); if (debug.val) LOG.debug(String.format("Storing %s for partition %d", markovs.getClass().getSimpleName(), partition)); ret.put(partition, markovs); if (line_xref.isEmpty()) break; } line_ctr++; } // WHILE if (line_ctr == 0) throw new IOException("The MarkovGraphsContainer file '" + file + "' is empty"); } catch (Exception ex) { LOG.error("Failed to deserialize the MarkovGraphsContainer from file '" + file + "'", ex); throw new IOException(ex); } if (debug.val) LOG.debug("The loading of the MarkovGraphsContainer is complete"); return (ret); } // ---------------------------------------------------------------------------- // UTILITY METHODS // ---------------------------------------------------------------------------- /** * Utility method to calculate the probabilities at all of the MarkovGraphsContainers * @param markovs */ public static void calculateProbabilities(CatalogContext catalogContext, Map<Integer, ? extends MarkovGraphsContainer> markovs) { if (debug.val) LOG.debug(String.format("Calculating probabilities for %d ids", markovs.size())); for (MarkovGraphsContainer m : markovs.values()) { m.calculateProbabilities(catalogContext.getAllPartitionIds()); } // FOR return; } /** * Utility method * @param markovs * @param hasher */ public static void setHasher(Map<Integer, ? extends MarkovGraphsContainer> markovs, AbstractHasher hasher) { if (debug.val) LOG.debug(String.format("Setting hasher for for %d ids", markovs.size())); for (MarkovGraphsContainer m : markovs.values()) { m.setHasher(hasher); } // FOR return; } }