package com.alibaba.jstorm.transactional.spout;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.alibaba.jstorm.transactional.BatchGroupId;
import com.alibaba.jstorm.transactional.BatchSnapshot;
import com.alibaba.jstorm.transactional.TransactionCommon;
import com.alibaba.jstorm.utils.JStormUtils;
import backtype.storm.spout.SpoutOutputCollectorCb;
import backtype.storm.task.ICollectorCallback;
import backtype.storm.tuple.Values;
/**
 * Wraps a {@link SpoutOutputCollectorCb} for a {@link TransactionSpout} and adds batch bookkeeping:
 * <ul>
 * <li>every tuple emitted through {@code emit}/{@code emitDirect} is prefixed with the
 * {@link BatchGroupId} of the current batch and forwarded to the delegate with a {@code null}
 * message id;</li>
 * <li>the number of tuples sent to each downstream task in the current batch is counted;</li>
 * <li>{@link #flushBarrier()} sends each tracked task a {@link BatchSnapshot} carrying its count on
 * {@link TransactionCommon#BARRIER_STREAM_ID}, resets the counts and advances to the next batch.</li>
 * </ul>
 * Emit calls hold the read lock; initialization and barrier flushes hold the write lock, so a
 * barrier is never sent while an emit call is in progress.
 */
public class TransactionSpoutOutputCollector extends SpoutOutputCollectorCb {
    public static Logger LOG = LoggerFactory.getLogger(TransactionSpoutOutputCollector.class);

    SpoutOutputCollectorCb delegate;
    private final TransactionSpout spout;

    private int groupId;
    private long currBatchId;

    // Tuples sent to each downstream task (keyed by task id) since the last barrier.
    // NOTE(review): this is a plain HashMap updated from emit callbacks while only the shared read
    // lock is held; if emit can be called from several threads concurrently, updates can race.
    // Confirm there is a single emitting thread.
    private Map<Integer, Integer> msgCount;
    private final BatchInfo currBatchInfo;

    // Read lock: emitting tuples of the current batch. Write lock: changing batch state.
    private final ReadWriteLock lock;

    // Counting callback shared by all emits; it keeps no state of its own.
    private final CollectorCallback countingCallback = new CollectorCallback();

    /** Id of a batch plus an end position that this class only ever resets to {@code null}. */
    public static class BatchInfo {
        public long batchId;
        // Presumably the source position at which the batch ended - TODO confirm with callers.
        public Object endPos;

        public BatchInfo() {
        }

        /** Copy constructor. */
        public BatchInfo(BatchInfo info) {
            this.batchId = info.batchId;
            this.endPos = info.endPos;
        }

        /** Starts tracking {@code batchId} and clears the end position. */
        public void init(long batchId) {
            this.batchId = batchId;
            this.endPos = null;
        }
    }

    /**
     * Increments {@link #msgCount} for every task a tuple was sent to. A task that was not among
     * the targets given to {@link #init} starts from zero, so it also receives a barrier later.
     */
    private class CollectorCallback implements ICollectorCallback {
        @Override
        public void execute(String stream, List<Integer> outTasks, List values) {
            if (outTasks == null) {
                return;
            }
            for (Integer task : outTasks) {
                Integer count = msgCount.get(task);
                msgCount.put(task, count == null ? 1 : count + 1);
            }
        }
    }

    /** Runs two callbacks in order; keeps counting active when the caller supplies a callback. */
    private static class ChainedCallback implements ICollectorCallback {
        private final ICollectorCallback first;
        private final ICollectorCallback second;

        ChainedCallback(ICollectorCallback first, ICollectorCallback second) {
            this.first = first;
            this.second = second;
        }

        @Override
        public void execute(String stream, List<Integer> outTasks, List values) {
            first.execute(stream, outTasks, values);
            second.execute(stream, outTasks, values);
        }
    }

    public TransactionSpoutOutputCollector(SpoutOutputCollectorCb delegate, TransactionSpout spout) {
        this.delegate = delegate;
        this.lock = new ReentrantReadWriteLock();
        this.currBatchInfo = new BatchInfo();
        this.spout = spout;
    }

    /**
     * Resets the collector to the given group/batch and starts counting for {@code targetTasks}.
     *
     * @param id          group id and batch id to start from
     * @param targetTasks downstream tasks whose per-batch message counts are tracked
     */
    public void init(BatchGroupId id, Set<Integer> targetTasks) {
        lock.writeLock().lock();
        try {
            setGroupId(id.groupId);
            setCurrBatchId(id.batchId);
            initMsgCount(targetTasks);
            currBatchInfo.init(id.batchId);
        } finally {
            lock.writeLock().unlock();
        }
    }

    /** Replaces the per-task counters with a fresh map holding zero for each target task. */
    public void initMsgCount(Set<Integer> targetTasks) {
        this.msgCount = new HashMap<Integer, Integer>();
        for (Integer task : targetTasks) {
            this.msgCount.put(task, 0);
        }
    }

    public void setGroupId(int groupId) {
        this.groupId = groupId;
    }

    public int getGroupId() {
        return groupId;
    }

    public void setCurrBatchId(long batchId) {
        this.currBatchId = batchId;
    }

    public long getCurrBatchId() {
        return currBatchId;
    }

    /** Blocks, polling every millisecond, until the spout reports itself active. */
    public void waitActive() {
        while (!spout.isActive()) {
            JStormUtils.sleepMs(1);
        }
    }

    @Override
    public List<Integer> emit(String streamId, List<Object> tuple, Object messageId) {
        return emit(streamId, tuple, messageId, null);
    }

    @Override
    public void emitDirect(int taskId, String streamId, List<Object> tuple, Object messageId) {
        emitDirect(taskId, streamId, tuple, messageId, null);
    }

    /**
     * Emits {@code tuple}, prefixed with the current {@link BatchGroupId}, through the delegate.
     *
     * @param messageId ignored; tuples are forwarded with a {@code null} message id
     * @param callback  optional caller callback; it runs after the per-task counting callback
     * @return whatever the delegate's {@code emit} returns
     */
    @Override
    public List<Integer> emit(String streamId, List<Object> tuple, Object messageId, ICollectorCallback callback) {
        lock.readLock().lock();
        try {
            return delegate.emit(streamId, tagWithBatchGroupId(tuple), null, withMessageCounting(callback));
        } finally {
            lock.readLock().unlock();
        }
    }

    /**
     * Emits {@code tuple}, prefixed with the current {@link BatchGroupId}, directly to {@code taskId}.
     *
     * @param messageId ignored; tuples are forwarded with a {@code null} message id
     * @param callback  optional caller callback; it runs after the per-task counting callback
     */
    @Override
    public void emitDirect(int taskId, String streamId, List<Object> tuple, Object messageId, ICollectorCallback callback) {
        lock.readLock().lock();
        try {
            delegate.emitDirect(taskId, streamId, tagWithBatchGroupId(tuple), null, withMessageCounting(callback));
        } finally {
            lock.readLock().unlock();
        }
    }

    /** Returns a new list: the current BatchGroupId followed by the tuple's values. Caller holds the lock. */
    private List<Object> tagWithBatchGroupId(List<Object> tuple) {
        List<Object> tupleWithId = new ArrayList<Object>(tuple.size() + 1);
        tupleWithId.add(new BatchGroupId(groupId, currBatchId));
        tupleWithId.addAll(tuple);
        return tupleWithId;
    }

    /** Returns a callback that always updates the per-task counts, then runs {@code callback} if given. */
    private ICollectorCallback withMessageCounting(ICollectorCallback callback) {
        if (callback == null) {
            return countingCallback;
        }
        return new ChainedCallback(countingCallback, callback);
    }

    /** Emits through the delegate as-is: no batch id prefix, no counting, no locking. */
    public List<Integer> emitByDelegate(String streamId, List<Object> tuple, Object messageId) {
        return emitByDelegate(streamId, tuple, messageId, null);
    }

    /** Emits through the delegate as-is: no batch id prefix, no counting, no locking. */
    public List<Integer> emitByDelegate(String streamId, List<Object> tuple, Object messageId, ICollectorCallback callback) {
        return delegate.emit(streamId, tuple, messageId, callback);
    }

    /** Direct-emits through the delegate as-is: no batch id prefix, no counting, no locking. */
    public void emitDirectByDelegate(int taskId, String streamId, List<Object> tuple, Object messageId) {
        emitDirectByDelegate(taskId, streamId, tuple, messageId, null);
    }

    /** Direct-emits through the delegate as-is: no batch id prefix, no counting, no locking. */
    public void emitDirectByDelegate(int taskId, String streamId, List<Object> tuple, Object messageId, ICollectorCallback callback) {
        delegate.emitDirect(taskId, streamId, tuple, messageId, callback);
    }

    @Override
    public void reportError(Throwable error) {
        delegate.reportError(error);
    }

    /**
     * Closes the current batch. The method flushes pending output, then sends every tracked task a
     * barrier carrying the number of tuples it received in this batch. It then resets those counts
     * and advances to the next batch id.
     *
     * @return a copy of the batch info as it was before advancing
     */
    public BatchInfo flushBarrier() {
        lock.writeLock().lock();
        try {
            BatchInfo ret = new BatchInfo(currBatchInfo);
            // flush pending messages in the output collector before the barrier
            delegate.flush();
            BatchGroupId batchGroupId = new BatchGroupId(groupId, currBatchId);
            for (Entry<Integer, Integer> entry : msgCount.entrySet()) {
                int taskId = entry.getKey();
                int count = entry.setValue(0); // read this batch's count and reset it in one step
                BatchSnapshot barrierSnapshot = new BatchSnapshot(batchGroupId, count);
                emitDirectByDelegate(taskId, TransactionCommon.BARRIER_STREAM_ID, new Values(batchGroupId, barrierSnapshot), null, null);
            }
            delegate.flush();
            moveToNextBatch();
            return ret;
        } finally {
            lock.writeLock().unlock();
        }
    }

    /** Increments the current batch id. {@link #flushBarrier()} calls this while holding the write lock. */
    public void moveToNextBatch() {
        currBatchId++;
        currBatchInfo.batchId = currBatchId;
    }

    /**
     * Resets all per-task counts and sends every tracked task a barrier for
     * {@link TransactionCommon#INIT_BATCH_ID} with a count of zero. The current batch id is unchanged.
     */
    public void flushInitBarrier() {
        lock.writeLock().lock();
        try {
            // flush pending messages in the output collector before the barrier
            delegate.flush();
            BatchGroupId batchGroupId = new BatchGroupId(groupId, TransactionCommon.INIT_BATCH_ID);
            BatchSnapshot barrierSnapshot = new BatchSnapshot(batchGroupId, 0);
            for (Entry<Integer, Integer> entry : msgCount.entrySet()) {
                entry.setValue(0);
                emitDirectByDelegate(entry.getKey(), TransactionCommon.BARRIER_STREAM_ID, new Values(batchGroupId, barrierSnapshot), null, null);
            }
            delegate.flush();
        } finally {
            lock.writeLock().unlock();
        }
    }
}