package com.lucidworks.storm.solr; import backtype.storm.generated.GlobalStreamId; import backtype.storm.grouping.CustomStreamGrouping; import backtype.storm.task.WorkerTopologyContext; import org.apache.commons.math3.distribution.UniformIntegerDistribution; import org.apache.solr.common.cloud.*; import org.apache.solr.common.util.Hash; import java.io.Serializable; import java.util.*; public class HashRangeGrouping implements CustomStreamGrouping, Serializable { private transient List<Integer> targetTasks; private transient List<DocRouter.Range> ranges; protected Map stormConf; protected int numShards; protected UniformIntegerDistribution random; protected int tasksPerShard; public HashRangeGrouping(Map stormConf, int numShards) { this.stormConf = stormConf; this.numShards = numShards; } public int getNumShards() { return numShards; } public void prepare(WorkerTopologyContext context, GlobalStreamId stream, List<Integer> targetTasks) { this.targetTasks = targetTasks; int numTasks = targetTasks.size(); if (numTasks % numShards != 0) throw new IllegalArgumentException("Number of tasks ("+numTasks+") should be a multiple of the number of shards ("+numShards+")!"); this.tasksPerShard = numTasks/numShards; this.random = new UniformIntegerDistribution(0, tasksPerShard-1); CompositeIdRouter docRouter = new CompositeIdRouter(); this.ranges = docRouter.partitionRange(numShards, docRouter.fullRange()); } public List<Integer> chooseTasks(int taskId, List<Object> values) { if (values == null || values.size() < 1) return Collections.singletonList(targetTasks.get(0)); String docId = (String) values.get(0); if (docId == null) return Collections.singletonList(targetTasks.get(0)); final int hash = Hash.murmurhash3_x86_32(docId, 0, docId.length(), 0); int rangeIndex = 0; for (int r=0; r < ranges.size(); r++) { if (ranges.get(r).includes(hash)) { rangeIndex = r; break; } } int selectedTask = (tasksPerShard > 1) ? rangeIndex + (random.sample() * tasksPerShard) : rangeIndex; return Collections.singletonList(targetTasks.get(selectedTask)); } }