/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.jstorm.message.netty;
import backtype.storm.Config;
import backtype.storm.messaging.NettyMessage;
import backtype.storm.messaging.TaskMessage;
import backtype.storm.utils.Utils;
import com.alibaba.jstorm.client.ConfigExtension;
import com.alibaba.jstorm.daemon.worker.Flusher;
import com.alibaba.jstorm.utils.JStormUtils;
import com.alibaba.jstorm.utils.Pair;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.apache.commons.lang.builder.ToStringStyle;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Asynchronous netty client with optional per-target-task flow control (backpressure).
 *
 * <p>When backpressure is enabled, the remote worker sends {@code BACK_PRESSURE_REQUEST}
 * control messages to start/stop flow control for a given target task. While a target task
 * is throttled, outgoing messages for it are parked in per-(sourceTask, targetTask)
 * {@link MessageBuffer} caches instead of being written to the channel. A background
 * {@link Flusher} periodically drains the caches of un-throttled tasks together with the
 * shared pending buffer.
 *
 * <p>Thread-safety: flow-control flags are flipped by the netty I/O thread in
 * {@link #handleResponse} while sender threads read them, so the flag map is a
 * {@link ConcurrentHashMap}. Per-(source, target) caches are guarded by synchronizing on
 * the individual {@link MessageBuffer}; channel writes are serialized via {@code writeLock}.
 */
class NettyClientAsync extends NettyClient {
    private static final Logger LOG = LoggerFactory.getLogger(NettyClientAsync.class);
    public static final String PREFIX = "Netty-Client-";

    // Background flusher that periodically pushes cached/pending messages out.
    protected Flusher flusher;
    protected int flushCheckInterval;

    private boolean isBackpressureEnable;
    // Map<TargetTaskId, underFlowControl?> -- written by the netty I/O thread
    // (handleResponse) and read by sender threads, hence a concurrent map.
    private volatile Map<Integer, Boolean> targetTasksUnderFlowCtrl;
    private Map<Integer, Pair<Lock, Condition>> targetTasksToLocks;
    // Map<SourceTaskId, Map<TargetTaskId, Cache>> -- the nested-map structure is built
    // once in the constructor and never changed; only the MessageBuffer values mutate.
    private volatile Map<Integer, Map<Integer, MessageBuffer>> targetTasksCache;
    private int flowCtrlAwaitTime;
    private int cacheSize;

    /**
     * Periodic task that flushes all cached batches of un-throttled target tasks
     * together with whatever is pending in the shared message buffer.
     */
    private class NettyClientFlush extends Flusher {
        // Guards against overlapping runs if one flush takes longer than the interval.
        private AtomicBoolean isFlushing = new AtomicBoolean(false);

        public NettyClientFlush(long flushInterval) {
            flushIntervalMs = flushInterval;
        }

        public void run() {
            if (isFlushing.compareAndSet(false, true) && !isClosed()) {
                synchronized (writeLock) {
                    // Collect caches of tasks no longer under flow control first,
                    // then append whatever accumulated in the shared buffer.
                    MessageBatch cache = getPendingCaches();
                    Channel channel = waitForChannelReady();
                    if (channel != null) {
                        MessageBatch messageBatch = messageBuffer.drain();
                        if (messageBatch != null)
                            cache.add(messageBatch);
                        flushRequest(channel, cache);
                    }
                }
                isFlushing.set(false);
            }
        }
    }

    @SuppressWarnings("rawtypes")
    NettyClientAsync(Map conf, ChannelFactory factory, String host, int port, ReconnectRunnable reconnector, final Set<Integer> sourceTasks, final Set<Integer> targetTasks) {
        super(conf, factory, host, port, reconnector);
        clientChannelFactory = factory;

        initFlowCtrl(conf, sourceTasks, targetTasks);

        flushCheckInterval = Utils.getInt(conf.get(Config.STORM_NETTY_FLUSH_CHECK_INTERVAL_MS), 5);
        flusher = new NettyClientFlush(flushCheckInterval);
        flusher.start();

        start();
    }

    /**
     * Initializes flow-control state: one throttle flag and lock/condition pair per
     * target task, and one message cache per (sourceTask, targetTask) pair.
     */
    private void initFlowCtrl(Map conf, Set<Integer> sourceTasks, Set<Integer> targetTasks) {
        isBackpressureEnable = ConfigExtension.isBackpressureEnable(conf);
        flowCtrlAwaitTime = ConfigExtension.getNettyFlowCtrlWaitTime(conf);
        cacheSize = ConfigExtension.getNettyFlowCtrlCacheSize(conf) != null ? ConfigExtension.getNettyFlowCtrlCacheSize(conf) : messageBatchSize;

        // Concurrent map: mutated by the netty I/O thread while sender threads poll it
        // (a plain HashMap here is unsafe under that access pattern).
        targetTasksUnderFlowCtrl = new ConcurrentHashMap<>();
        targetTasksToLocks = new HashMap<>();
        targetTasksCache = new HashMap<>();
        for (Integer task : targetTasks) {
            targetTasksUnderFlowCtrl.put(task, false);
            Lock lock = new ReentrantLock();
            targetTasksToLocks.put(task, new Pair<>(lock, lock.newCondition()));
        }

        Set<Integer> tasks = new HashSet<Integer>(sourceTasks);
        tasks.add(0); // add task-0 as default source task
        for (Integer sourceTask : tasks) {
            Map<Integer, MessageBuffer> messageBuffers = new HashMap<>();
            for (Integer targetTask : targetTasks) {
                messageBuffers.put(targetTask, new MessageBuffer(cacheSize));
            }
            targetTasksCache.put(sourceTask, messageBuffers);
        }
    }

    /**
     * Sends a batch of messages, honoring per-target-task flow control.
     * Silently drops the request (with a warning) if the client is being closed.
     */
    @Override
    public void send(List<TaskMessage> messages) {
        // throw exception if the client is being closed
        if (isClosed()) {
            LOG.warn("Client is being closed, and does not take requests any more");
            return;
        }

        long start = enableNettyMetrics && sendTimer != null ? sendTimer.getTime() : 0L;
        try {
            for (TaskMessage message : messages) {
                waitforFlowCtrlAndSend(message);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            if (sendTimer != null && enableNettyMetrics) {
                sendTimer.updateTime(start);
            }
        }
    }

    /**
     * Sends a single message, honoring per-target-task flow control.
     * Silently drops the request (with a warning) if the client is being closed.
     */
    @Override
    public void send(TaskMessage message) {
        // throw exception if the client is being closed
        if (isClosed()) {
            LOG.warn("Client is being closed, and does not take requests any more");
        } else {
            long start = enableNettyMetrics && sendTimer != null ? sendTimer.getTime() : 0L;
            try {
                waitforFlowCtrlAndSend(message);
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                if (sendTimer != null && enableNettyMetrics) {
                    sendTimer.updateTime(start);
                }
            }
        }
    }

    /**
     * Writes a message straight to the channel, bypassing flow control and batching.
     * Blocks until the channel is writable (or the pending timeout discards the buffer).
     */
    @Override
    public void sendDirect(TaskMessage message) {
        synchronized (writeLock) {
            Channel channel = waitForChannelReady();
            if (channel != null)
                flushRequest(channel, message);
        }
    }

    /**
     * Appends {@code message} to the shared buffer and flushes a full batch when the
     * channel is writable; if the backlog exceeds {@code BATCH_THRESHOLD_WARN}, blocks
     * until the channel is ready and force-drains everything.
     */
    void pushBatch(NettyMessage message) {
        if (message == null || message.isEmpty()) {
            return;
        }

        synchronized (writeLock) {
            Channel channel = channelRef.get();
            if (channel == null) {
                // Channel not connected yet: park the message without triggering a flush.
                messageBuffer.add(message, false);
                LOG.debug("Pending requested message, the size is {}, because channel is not ready.", messageBuffer.size());
            } else {
                if (channel.isWritable()) {
                    MessageBatch messageBatch = messageBuffer.add(message);
                    if (messageBatch != null) {
                        // Buffer produced a full batch: ship it now.
                        flushRequest(channel, messageBatch);
                    }
                } else {
                    // Channel temporarily saturated: keep buffering.
                    messageBuffer.add(message, false);
                }
            }

            if (messageBuffer.size() >= BATCH_THRESHOLD_WARN) {
                // Backlog is getting large: block for a writable channel and drain.
                channel = waitForChannelReady();
                if (channel != null) {
                    MessageBatch messageBatch = messageBuffer.drain();
                    flushRequest(channel, messageBatch);
                }
            }
        }
    }

    /**
     * @return true when the pending time has exceeded the configured timeout
     *         ({@code timeoutMs == -1} means "never discard")
     */
    private boolean discardCheck(long pendingTime, long timeoutMs, int messageSize) {
        if (timeoutMs != -1 && pendingTime >= timeoutMs) {
            LOG.warn("Discard message due to pending message timeout({}ms), messageSize={}", timeoutMs, messageSize);
            return true;
        } else {
            return false;
        }
    }

    /**
     * Busy-waits (1ms steps) until the channel is connected and writable.
     *
     * @return the ready channel; null when the client closed while disconnected, or when
     *         the pending timeout elapsed (in which case the shared buffer is cleared)
     */
    public Channel waitForChannelReady() {
        Channel channel = channelRef.get();

        long pendingTime = 0;
        while ((channel == null && !isClosed()) || (channel != null && !channel.isWritable())) {
            JStormUtils.sleepMs(1);
            pendingTime++;
            if (discardCheck(pendingTime, timeoutMs, messageBuffer.size())) {
                messageBuffer.clear();
                return null;
            }

            if (pendingTime % 30000 == 0) {
                LOG.info("Pending total time={}, channel.isWritable={}, pendingNum={}, remoteAddress={}", pendingTime, channel != null ? channel.isWritable()
                        : null, pendings.get(), channel != null ? channel.getRemoteAddress() : null);
            }
            channel = channelRef.get();
        }
        return channel;
    }

    /**
     * Handles control messages from the remote worker. Only BACK_PRESSURE_REQUEST is
     * expected; its payload is one flag byte (1 = start flow control, 0 = stop) followed
     * by a big-endian 4-byte target task id.
     */
    @Override
    public void handleResponse(Channel channel, Object msg) {
        if (msg == null) {
            return;
        }

        TaskMessage message = (TaskMessage) msg;
        short type = message.get_type();
        if (type == TaskMessage.BACK_PRESSURE_REQUEST) {
            // Read the payload in place. (The previous allocate(Integer.SIZE + 1) sized
            // the scratch buffer in bits -- 33 bytes -- instead of Integer.BYTES + 1.)
            ByteBuffer buffer = ByteBuffer.wrap(message.message());
            boolean startFlowCtrl = buffer.get() == 1;
            int targetTaskId = buffer.getInt();
            if (startFlowCtrl) {
                addFlowControl(channel, targetTaskId);
            } else {
                removeFlowControl(targetTaskId);
            }
        } else {
            LOG.warn("Unexpected message (type={}) was received from task {}", type, message.task());
        }
    }

    @Override
    public String toString() {
        return ToStringBuilder.reflectionToString(this, ToStringStyle.SHORT_PREFIX_STYLE);
    }

    /** Marks {@code taskId} as throttled. ({@code channel} is kept for interface parity.) */
    private void addFlowControl(Channel channel, int taskId) {
        targetTasksUnderFlowCtrl.put(taskId, true);
    }

    /**
     * Clears the throttle flag for {@code taskId}.
     *
     * @return the lock/condition pair registered for the task, or null if unknown
     */
    private Pair<Lock, Condition> removeFlowControl(int taskId) {
        targetTasksUnderFlowCtrl.put(taskId, false);
        return targetTasksToLocks.get(taskId);
    }

    /** @return true if {@code taskId} is currently under flow control */
    private boolean isUnderFlowCtrl(int taskId) {
        // Unknown task ids (e.g. probes via available()) count as "not throttled"
        // instead of tripping an auto-unboxing NullPointerException.
        Boolean underCtrl = targetTasksUnderFlowCtrl.get(taskId);
        return underCtrl != null && underCtrl;
    }

    private MessageBuffer getCacheBuffer(int sourceTaskId, int targetTaskId) {
        return targetTasksCache.get(sourceTaskId).get(targetTaskId);
    }

    /** Drains and returns the cached batch for one (source, target) pair; may be null. */
    private MessageBatch flushCacheBatch(int sourceTaskId, int targetTaskId) {
        MessageBatch batch = null;
        MessageBuffer buffer = getCacheBuffer(sourceTaskId, targetTaskId);
        synchronized (buffer) {
            batch = buffer.drain();
        }
        return batch;
    }

    /**
     * Caches a message for a throttled target task.
     *
     * @return a full batch when the cache overflowed, otherwise null
     */
    private MessageBatch addMessageIntoCache(int sourceTaskId, int targetTaskId, TaskMessage message) {
        MessageBatch batch = null;
        MessageBuffer buffer = getCacheBuffer(sourceTaskId, targetTaskId);
        synchronized (buffer) {
            batch = buffer.add(message);
        }
        return batch;
    }

    /** Collects the cached batches of every target task that is no longer throttled. */
    private MessageBatch getPendingCaches() {
        MessageBatch ret = new MessageBatch(cacheSize);
        for (Entry<Integer, Map<Integer, MessageBuffer>> entry : targetTasksCache.entrySet()) {
            int sourceTaskId = entry.getKey();
            Map<Integer, MessageBuffer> messageBuffers = entry.getValue();
            for (Integer targetTaskId : messageBuffers.keySet()) {
                if (!isUnderFlowCtrl(targetTaskId)) {
                    MessageBatch batch = flushCacheBatch(sourceTaskId, targetTaskId);
                    if (batch != null)
                        ret.add(batch);
                }
            }
        }
        return ret;
    }

    /**
     * Sends {@code message}, respecting flow control of its target task: throttled
     * messages are cached, and when the cache overflows the caller busy-waits until the
     * throttle is released before flushing the overflow batch.
     */
    private void waitforFlowCtrlAndSend(TaskMessage message) {
        // If backpressure is disable, just send directly.
        if (!isBackpressureEnable) {
            pushBatch(message);
            return;
        }

        int sourceTaskId = message.sourceTask();
        int targetTaskId = message.task();
        if (isUnderFlowCtrl(targetTaskId)) {
            // If target task is under flow control
            MessageBatch flushCache = addMessageIntoCache(sourceTaskId, targetTaskId, message);
            if (flushCache != null) {
                // Cache is full. Busy-wait (1ms steps) until flow control is released,
                // then flush the overflow batch.
                long pendingTime = 0;
                while (isUnderFlowCtrl(targetTaskId)) {
                    JStormUtils.sleepMs(1);
                    pendingTime++;
                    if (pendingTime % 30000 == 0) {
                        LOG.info("Pending total time={} since target task-{} is under flow control ", pendingTime, targetTaskId);
                    }
                }
                pushBatch(flushCache);
            }
        } else {
            // Not throttled: flush any leftover cache for this pair first to keep
            // per-task message ordering intact.
            MessageBatch cache = flushCacheBatch(sourceTaskId, targetTaskId);
            if (cache != null) {
                cache.add(message);
                pushBatch(cache);
            } else {
                pushBatch(message);
            }
        }
    }

    /** Releases flow control for ALL target tasks (used when a connection drops). */
    private void releaseFlowCtrlsForRemoteAddr(String remoteAddr) {
        LOG.info("Release flow control for remoteAddr={}", remoteAddr);
        for (Entry<Integer, Boolean> entry : targetTasksUnderFlowCtrl.entrySet()) {
            entry.setValue(false);
        }
    }

    /**
     * Drops all flow-control state for the broken connection, then either schedules a
     * reconnect (if it was the active channel) or just closes the stale channel.
     */
    @Override
    public void disconnectChannel(Channel channel) {
        releaseFlowCtrlsForRemoteAddr(channel.getRemoteAddress().toString());
        if (isClosed()) {
            return;
        }

        if (channel == channelRef.get()) {
            setChannel(null);
            reconnect();
        } else {
            closeChannel(channel);
        }
    }

    /** A task is available only when the base client is ready AND it is not throttled. */
    @Override
    public boolean available(int taskId) {
        return super.available(taskId) && !isUnderFlowCtrl(taskId);
    }

    /** Stops the background flusher before shutting the client down. */
    @Override
    public void close() {
        flusher.close();
        super.close();
    }
}