package com.lucidworks.storm.solr;

import java.io.Serializable;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import backtype.storm.generated.GlobalStreamId;
import backtype.storm.grouping.CustomStreamGrouping;
import backtype.storm.task.WorkerTopologyContext;

import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.common.cloud.DocCollection;
import org.apache.solr.common.cloud.DocRouter;
import org.apache.solr.common.cloud.DocRouter;
import org.apache.solr.common.cloud.ImplicitDocRouter;
import org.apache.solr.common.cloud.Slice;

import com.lucidworks.storm.StreamingApp;

/**
 * Custom stream grouping that routes each tuple to a bolt task serving the
 * SolrCloud shard its document id hashes to, so that every bolt task only
 * ever indexes documents destined for a single shard.
 */
public class ShardGrouping implements CustomStreamGrouping, Serializable {

  private transient List<Integer> targetTasks;
  private transient CloudSolrClient cloudSolrClient;
  private transient DocCollection docCollection;
  private transient Map<String, Integer> shardIndexCache;

  protected Map stormConf;
  protected String collection;
  protected Integer numShards;
  protected UniformIntegerDistribution random;
  protected int tasksPerShard;

  public ShardGrouping(Map stormConf, String collection) {
    this.stormConf = stormConf;
    this.collection = collection;
  }

  public void setCloudSolrClient(CloudSolrClient client) {
    cloudSolrClient = client;
  }

  public int getNumShards() {
    if (numShards == null)
      numShards = initShardInfo(); // auto-boxing; avoids the deprecated Integer constructor
    return numShards.intValue();
  }

  public void prepare(WorkerTopologyContext context, GlobalStreamId stream, List<Integer> targetTasks) {
    this.targetTasks = targetTasks;
    int numTasks = targetTasks.size();
    int numShards = initShardInfo(); // setup for doing shard-to-task mapping
    if (numTasks % numShards != 0)
      throw new IllegalArgumentException("Number of tasks (" + numTasks +
          ") should be a multiple of the number of shards (" + numShards + ")!");

    this.tasksPerShard = numTasks / numShards;
    // commons-math3 UniformIntegerDistribution bounds are inclusive: samples in [0, tasksPerShard-1]
    this.random = new UniformIntegerDistribution(0, tasksPerShard - 1);
  }

  public List<Integer> chooseTasks(int taskId, List<Object> values) {
    // fall back to the first task when there is no document id to route on
    if (values == null || values.isEmpty())
      return Collections.singletonList(targetTasks.get(0));

    String docId = (String) values.get(0);
    if (docId == null)
      return Collections.singletonList(targetTasks.get(0));

    // ask the collection's router which shard this document id belongs to
    Slice slice = docCollection.getRouter().getTargetSlice(docId, null, null, null, docCollection);

    // map this doc onto one of the tasks assigned to that shard; each shard
    // owns a contiguous block of tasksPerShard entries in targetTasks, so the
    // blocks are disjoint and every task is reachable
    int shardIndex = shardIndexCache.get(slice.getName());
    int selectedTask = (tasksPerShard > 1)
        ? shardIndex * tasksPerShard + random.sample()
        : shardIndex;
    return Collections.singletonList(targetTasks.get(selectedTask));
  }

  protected int initShardInfo() {
    if (cloudSolrClient == null) {
      // look up the Solr client from the Spring context for this topology
      cloudSolrClient = (CloudSolrClient) StreamingApp.spring(stormConf).getBean("cloudSolrClient");
      cloudSolrClient.connect();
    }

    this.docCollection = cloudSolrClient.getZkStateReader().getClusterState().getCollection(collection);

    // do basic sanity checks once
    DocRouter docRouter = docCollection.getRouter();
    if (docRouter instanceof ImplicitDocRouter)
      throw new IllegalStateException("Implicit document routing is not supported by this grouping!");

    Collection<Slice> shards = docCollection.getSlices();
    if (shards == null || shards.isEmpty())
      throw new IllegalStateException("Collection '" + collection + "' does not have any shards!");

    // cache a stable index for each shard name; used to compute task offsets in chooseTasks
    shardIndexCache = new HashMap<String, Integer>(20);
    int s = 0;
    for (Slice next : shards)
      shardIndexCache.put(next.getName(), s++);

    return shards.size();
  }
}
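/*
 * Minimal wiring sketch (not part of this class), assuming Storm 0.x
 * (backtype.storm) and a spout that emits the Solr document id as the first
 * tuple field. "docSpout", "solrBolt", MyDocIdSpout, and MySolrIndexBolt are
 * hypothetical placeholders; the only hard requirement, enforced in prepare(),
 * is that the bolt parallelism is a multiple of the collection's shard count:
 *
 *   TopologyBuilder builder = new TopologyBuilder();
 *   builder.setSpout("docSpout", new MyDocIdSpout(), 1);
 *   builder.setBolt("solrBolt", new MySolrIndexBolt(), 4) // e.g. 2 shards x 2 tasks per shard
 *          .customGrouping("docSpout", new ShardGrouping(stormConf, "collection1"));
 */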