/** * Copyright (C) 2010-2013 Alibaba Group Holding Limited * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.alibaba.rocketmq.client.impl.consumer; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import com.alibaba.rocketmq.client.consumer.AllocateMessageQueueStrategy; import com.alibaba.rocketmq.client.impl.FindBrokerResult; import com.alibaba.rocketmq.client.impl.factory.MQClientFactory; import com.alibaba.rocketmq.client.log.ClientLogger; import com.alibaba.rocketmq.common.MixAll; import com.alibaba.rocketmq.common.message.MessageQueue; import com.alibaba.rocketmq.common.protocol.body.LockBatchRequestBody; import com.alibaba.rocketmq.common.protocol.body.UnlockBatchRequestBody; import com.alibaba.rocketmq.common.protocol.heartbeat.MessageModel; import com.alibaba.rocketmq.common.protocol.heartbeat.SubscriptionData; /** * Rebalance的具体实现 * * @author shijia.wxr<vintage.wang@gmail.com> * @since 2013-6-22 */ public abstract class RebalanceImpl { protected static final Logger log = ClientLogger.getLog(); // 分配好的队列,消息存储也在这里 protected final ConcurrentHashMap<MessageQueue, ProcessQueue> processQueueTable = new ConcurrentHashMap<MessageQueue, ProcessQueue>(64); // 可以订阅的所有队列(定时从Name Server更新最新版本) protected final ConcurrentHashMap<String/* topic */, Set<MessageQueue>> topicSubscribeInfoTable = new ConcurrentHashMap<String, Set<MessageQueue>>(); // 订阅关系,用户配置的原始数据 protected final ConcurrentHashMap<String /* topic */, SubscriptionData> subscriptionInner = new ConcurrentHashMap<String, SubscriptionData>(); protected String consumerGroup; protected MessageModel messageModel; protected AllocateMessageQueueStrategy allocateMessageQueueStrategy; protected MQClientFactory mQClientFactory; public RebalanceImpl(String consumerGroup, MessageModel messageModel, AllocateMessageQueueStrategy allocateMessageQueueStrategy, MQClientFactory mQClientFactory) { this.consumerGroup = consumerGroup; this.messageModel = messageModel; this.allocateMessageQueueStrategy = allocateMessageQueueStrategy; this.mQClientFactory = mQClientFactory; } public void unlock(final MessageQueue mq, final boolean oneway) { FindBrokerResult findBrokerResult = this.mQClientFactory.findBrokerAddressInSubscribe(mq.getBrokerName(), MixAll.MASTER_ID, true); if (findBrokerResult != null) { UnlockBatchRequestBody requestBody = new UnlockBatchRequestBody(); requestBody.setConsumerGroup(this.consumerGroup); requestBody.setClientId(this.mQClientFactory.getClientId()); requestBody.getMqSet().add(mq); try { this.mQClientFactory.getMQClientAPIImpl().unlockBatchMQ(findBrokerResult.getBrokerAddr(), requestBody, 1000, oneway); log.warn("unlock messageQueue. group:{}, clientId:{}, mq:{}",// this.consumerGroup, // this.mQClientFactory.getClientId(), // mq); } catch (Exception e) { log.error("unlockBatchMQ exception, " + mq, e); } } } public void unlockAll(final boolean oneway) { HashMap<String, Set<MessageQueue>> brokerMqs = this.buildProcessQueueTableByBrokerName(); for (final Map.Entry<String, Set<MessageQueue>> entry : brokerMqs.entrySet()) { final String brokerName = entry.getKey(); final Set<MessageQueue> mqs = entry.getValue(); if (mqs.isEmpty()) continue; FindBrokerResult findBrokerResult = this.mQClientFactory.findBrokerAddressInSubscribe(brokerName, MixAll.MASTER_ID, true); if (findBrokerResult != null) { UnlockBatchRequestBody requestBody = new UnlockBatchRequestBody(); requestBody.setConsumerGroup(this.consumerGroup); requestBody.setClientId(this.mQClientFactory.getClientId()); requestBody.setMqSet(mqs); try { this.mQClientFactory.getMQClientAPIImpl().unlockBatchMQ(findBrokerResult.getBrokerAddr(), requestBody, 1000, oneway); for (MessageQueue mq : mqs) { ProcessQueue processQueue = this.processQueueTable.get(mq); if (processQueue != null) { processQueue.setLocked(false); log.info("the message queue unlock OK, Group: {} {}", this.consumerGroup, mq); } } } catch (Exception e) { log.error("unlockBatchMQ exception, " + mqs, e); } } } } private HashMap<String/* brokerName */, Set<MessageQueue>> buildProcessQueueTableByBrokerName() { HashMap<String, Set<MessageQueue>> result = new HashMap<String, Set<MessageQueue>>(); for (MessageQueue mq : this.processQueueTable.keySet()) { Set<MessageQueue> mqs = result.get(mq.getBrokerName()); if (null == mqs) { mqs = new HashSet<MessageQueue>(); result.put(mq.getBrokerName(), mqs); } mqs.add(mq); } return result; } public boolean lock(final MessageQueue mq) { FindBrokerResult findBrokerResult = this.mQClientFactory.findBrokerAddressInSubscribe(mq.getBrokerName(), MixAll.MASTER_ID, true); if (findBrokerResult != null) { LockBatchRequestBody requestBody = new LockBatchRequestBody(); requestBody.setConsumerGroup(this.consumerGroup); requestBody.setClientId(this.mQClientFactory.getClientId()); requestBody.getMqSet().add(mq); try { Set<MessageQueue> lockedMq = this.mQClientFactory.getMQClientAPIImpl().lockBatchMQ( findBrokerResult.getBrokerAddr(), requestBody, 1000); for (MessageQueue mmqq : lockedMq) { ProcessQueue processQueue = this.processQueueTable.get(mmqq); if (processQueue != null) { processQueue.setLocked(true); processQueue.setLastLockTimestamp(System.currentTimeMillis()); } } boolean lockOK = lockedMq.contains(mq); log.info("the message queue lock {}, {} {}",// (lockOK ? "OK" : "Failed"), // this.consumerGroup, // mq); return lockOK; } catch (Exception e) { log.error("lockBatchMQ exception, " + mq, e); } } return false; } public void lockAll() { HashMap<String, Set<MessageQueue>> brokerMqs = this.buildProcessQueueTableByBrokerName(); Iterator<Entry<String, Set<MessageQueue>>> it = brokerMqs.entrySet().iterator(); while (it.hasNext()) { Entry<String, Set<MessageQueue>> entry = it.next(); final String brokerName = entry.getKey(); final Set<MessageQueue> mqs = entry.getValue(); if (mqs.isEmpty()) continue; FindBrokerResult findBrokerResult = this.mQClientFactory.findBrokerAddressInSubscribe(brokerName, MixAll.MASTER_ID, true); if (findBrokerResult != null) { LockBatchRequestBody requestBody = new LockBatchRequestBody(); requestBody.setConsumerGroup(this.consumerGroup); requestBody.setClientId(this.mQClientFactory.getClientId()); requestBody.setMqSet(mqs); try { Set<MessageQueue> lockOKMQSet = this.mQClientFactory.getMQClientAPIImpl().lockBatchMQ( findBrokerResult.getBrokerAddr(), requestBody, 1000); // 锁定成功的队列 for (MessageQueue mq : lockOKMQSet) { ProcessQueue processQueue = this.processQueueTable.get(mq); if (processQueue != null) { if (!processQueue.isLocked()) { log.info("the message queue locked OK, Group: {} {}", this.consumerGroup, mq); } processQueue.setLocked(true); processQueue.setLastLockTimestamp(System.currentTimeMillis()); } } // 锁定失败的队列 for (MessageQueue mq : mqs) { if (!lockOKMQSet.contains(mq)) { ProcessQueue processQueue = this.processQueueTable.get(mq); if (processQueue != null) { processQueue.setLocked(false); log.warn("the message queue locked Failed, Group: {} {}", this.consumerGroup, mq); } } } } catch (Exception e) { log.error("lockBatchMQ exception, " + mqs, e); } } } } public void doRebalance() { Map<String, SubscriptionData> subTable = this.getSubscriptionInner(); if (subTable != null) { /* * chen.si 针对每个topic,检查当前consumer group中的 活动consumer list是否有变化,如果变化了,需要重新调整 分区 */ for (final Map.Entry<String, SubscriptionData> entry : subTable.entrySet()) { final String topic = entry.getKey(); try { this.rebalanceByTopic(topic); } catch (Exception e) { if (!topic.startsWith(MixAll.RETRY_GROUP_TOPIC_PREFIX)) { log.warn("rebalanceByTopic Exception", e); } } } } this.truncateMessageQueueNotMyTopic(); } private void rebalanceByTopic(final String topic) { switch (messageModel) { case BROADCASTING: { Set<MessageQueue> mqSet = this.topicSubscribeInfoTable.get(topic); if (mqSet != null) { boolean changed = this.updateProcessQueueTableInRebalance(topic, mqSet); if (changed) { this.messageQueueChanged(topic, mqSet, mqSet); log.info("messageQueueChanged {} {} {} {}",// consumerGroup,// topic,// mqSet,// mqSet); } } else { log.warn("doRebalance, {}, but the topic[{}] not exist.", consumerGroup, topic); } break; } case CLUSTERING: { /* * chen.si topic下的所有的分区队列 */ Set<MessageQueue> mqSet = this.topicSubscribeInfoTable.get(topic); /* * chen.si 一个consumer,会有一个唯一的clientId。这里通过clientId来标识 同一个consumer group下的当前活动的所有consumer。 */ List<String> cidAll = this.mQClientFactory.findConsumerIdList(topic, consumerGroup); if (null == mqSet) { if (!topic.startsWith(MixAll.RETRY_GROUP_TOPIC_PREFIX)) { log.warn("doRebalance, {}, but the topic[{}] not exist.", consumerGroup, topic); } } if (null == cidAll) { log.warn("doRebalance, {} {}, get consumer id list failed", consumerGroup, topic); } if (mqSet != null && cidAll != null) { List<MessageQueue> mqAll = new ArrayList<MessageQueue>(); mqAll.addAll(mqSet); // 排序 Collections.sort(mqAll); Collections.sort(cidAll); AllocateMessageQueueStrategy strategy = this.allocateMessageQueueStrategy; // 执行分配算法 List<MessageQueue> allocateResult = null; try { allocateResult = strategy.allocate(this.mQClientFactory.getClientId(), mqAll, cidAll); } catch (Throwable e) { log.error("AllocateMessageQueueStrategy.allocate Exception", e); } Set<MessageQueue> allocateResultSet = new HashSet<MessageQueue>(); if (allocateResult != null) { allocateResultSet.addAll(allocateResult); } // 更新本地队列 boolean changed = this.updateProcessQueueTableInRebalance(topic, allocateResultSet); if (changed) { log.info("rebalanced result changed. mqSet={}, ConsumerId={}, mqSize={}, cidSize={}", allocateResult, this.mQClientFactory.getClientId(), mqAll.size(), cidAll.size()); /* * chen.si 当前consumer负责的分区队列有变化,需要通知consumer。 当前consumer需要重新调整fetch */ this.messageQueueChanged(topic, mqSet, allocateResultSet); log.info("messageQueueChanged {} {} {} {}",// consumerGroup,// topic,// mqSet,// allocateResultSet); log.info("messageQueueChanged consumerIdList: {}",// cidAll); } } break; } default: break; } } public abstract void messageQueueChanged(final String topic, final Set<MessageQueue> mqAll, final Set<MessageQueue> mqDivided); private boolean updateProcessQueueTableInRebalance(final String topic, final Set<MessageQueue> mqSet) { boolean changed = false; // 将多余的队列删除 Iterator<Entry<MessageQueue, ProcessQueue>> it = this.processQueueTable.entrySet().iterator(); while (it.hasNext()) { Entry<MessageQueue, ProcessQueue> next = it.next(); MessageQueue mq = next.getKey(); ProcessQueue pq = next.getValue(); if (mq.getTopic().equals(topic)) { if (!mqSet.contains(mq)) { changed = true; it.remove(); pq.setDroped(true); log.info("doRebalance, {}, remove unnecessary mq, {}", consumerGroup, mq); this.removeUnnecessaryMessageQueue(mq, pq); } } } // 增加新增的队列 List<PullRequest> pullRequestList = new ArrayList<PullRequest>(); for (MessageQueue mq : mqSet) { if (!this.processQueueTable.containsKey(mq)) { PullRequest pullRequest = new PullRequest(); pullRequest.setConsumerGroup(consumerGroup); pullRequest.setMessageQueue(mq); pullRequest.setProcessQueue(new ProcessQueue()); // 这个需要根据策略来设置 long nextOffset = this.computePullFromWhere(mq); if (nextOffset >= 0) { pullRequest.setNextOffset(nextOffset); pullRequestList.add(pullRequest); changed = true; this.processQueueTable.put(mq, pullRequest.getProcessQueue()); log.info("doRebalance, {}, add a new mq, {}", consumerGroup, mq); } else { // 等待此次Rebalance做重试 log.warn("doRebalance, {}, add new mq failed, {}", consumerGroup, mq); } } } this.dispatchPullRequest(pullRequestList); return changed; } public abstract void removeUnnecessaryMessageQueue(final MessageQueue mq, final ProcessQueue pq); public abstract void dispatchPullRequest(final List<PullRequest> pullRequestList); public abstract long computePullFromWhere(final MessageQueue mq); private void truncateMessageQueueNotMyTopic() { Map<String, SubscriptionData> subTable = this.getSubscriptionInner(); for (MessageQueue mq : this.processQueueTable.keySet()) { if (!subTable.containsKey(mq.getTopic())) { ProcessQueue pq = this.processQueueTable.remove(mq); if (pq != null) { pq.setDroped(true); log.info("doRebalance, {}, truncateMessageQueueNotMyTopic remove unnecessary mq, {}", consumerGroup, mq); } } } } public ConcurrentHashMap<String, SubscriptionData> getSubscriptionInner() { return subscriptionInner; } public ConcurrentHashMap<MessageQueue, ProcessQueue> getProcessQueueTable() { return processQueueTable; } public ConcurrentHashMap<String, Set<MessageQueue>> getTopicSubscribeInfoTable() { return topicSubscribeInfoTable; } public String getConsumerGroup() { return consumerGroup; } public void setConsumerGroup(String consumerGroup) { this.consumerGroup = consumerGroup; } public MessageModel getMessageModel() { return messageModel; } public void setMessageModel(MessageModel messageModel) { this.messageModel = messageModel; } public AllocateMessageQueueStrategy getAllocateMessageQueueStrategy() { return allocateMessageQueueStrategy; } public void setAllocateMessageQueueStrategy(AllocateMessageQueueStrategy allocateMessageQueueStrategy) { this.allocateMessageQueueStrategy = allocateMessageQueueStrategy; } public MQClientFactory getmQClientFactory() { return mQClientFactory; } public void setmQClientFactory(MQClientFactory mQClientFactory) { this.mQClientFactory = mQClientFactory; } }