/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package backtype.storm.coordination; import static backtype.storm.utils.Utils.get; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Map.Entry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import backtype.storm.Constants; import backtype.storm.generated.GlobalStreamId; import backtype.storm.generated.Grouping; import backtype.storm.task.ICollectorCallback; import backtype.storm.task.IOutputCollector; import backtype.storm.task.OutputCollector; import backtype.storm.task.TopologyContext; import backtype.storm.topology.FailedException; import backtype.storm.topology.IRichBolt; import backtype.storm.topology.OutputFieldsDeclarer; import backtype.storm.tuple.Fields; import backtype.storm.tuple.Tuple; import backtype.storm.tuple.Values; import backtype.storm.utils.TimeCacheMap; import backtype.storm.utils.Utils; /** * Coordination requires the request ids to be globally unique for a while * so that it doesn't get confused in the case of retries. */ public class CoordinatedBolt implements IRichBolt { public static Logger LOG = LoggerFactory.getLogger(CoordinatedBolt.class); public interface FinishedCallback { void finishedId(Object id); } public interface TimeoutCallback { void timeoutId(Object id); } public static class SourceArgs implements Serializable { public boolean singleCount; protected SourceArgs(boolean singleCount) { this.singleCount = singleCount; } public static SourceArgs single() { return new SourceArgs(true); } public static SourceArgs all() { return new SourceArgs(false); } @Override public String toString() { return "<Single: " + singleCount + ">"; } } public class CoordinatedOutputCollector implements IOutputCollector { OutputCollector _delegate; public CoordinatedOutputCollector(OutputCollector delegate) { _delegate = delegate; } public List<Integer> emit(String stream, Collection<Tuple> anchors, List<Object> tuple) { return _delegate.emit(stream, anchors, tuple, new CollectorCb(tuple.get(0))); } public void emitDirect(int task, String stream, Collection<Tuple> anchors, List<Object> tuple) { updateTaskCounts(tuple.get(0), Arrays.asList(task)); _delegate.emitDirect(task, stream, anchors, tuple); // due to here do updateTaskCounts, so we do flush operation _delegate.flush(); } public void ack(Tuple tuple) { Object id = tuple.getValue(0); synchronized (_tracked) { TrackingInfo track = _tracked.get(id); if (track != null) track.receivedTuples++; } boolean failed = checkFinishId(tuple, TupleType.REGULAR); if (failed) { _delegate.fail(tuple); } else { _delegate.ack(tuple); } } public void fail(Tuple tuple) { Object id = tuple.getValue(0); synchronized (_tracked) { TrackingInfo track = _tracked.get(id); if (track != null) track.failed = true; } checkFinishId(tuple, TupleType.REGULAR); _delegate.fail(tuple); } public void reportError(Throwable error) { _delegate.reportError(error); } private void updateTaskCounts(Object id, List<Integer> tasks) { synchronized (_tracked) { TrackingInfo track = _tracked.get(id); if (track != null) { Map<Integer, Integer> taskEmittedTuples = track.taskEmittedTuples; for (Integer task : tasks) { int newCount = get(taskEmittedTuples, task, 0) + 1; taskEmittedTuples.put(task, newCount); } } } } class CollectorCb implements ICollectorCallback { Object id; public CollectorCb(Object id) { this.id = id; } @Override public void execute(String stream, List<Integer> outTasks, List values) { updateTaskCounts(id, outTasks); } } } private Map<String, SourceArgs> _sourceArgs; private IdStreamSpec _idStreamSpec; private IRichBolt _delegate; private Integer _numSourceReports; private List<Integer> _countOutTasks = new ArrayList<>(); private OutputCollector _collector; private TimeCacheMap<Object, TrackingInfo> _tracked; public static class TrackingInfo { int reportCount = 0; int expectedTupleCount = 0; int receivedTuples = 0; boolean failed = false; Map<Integer, Integer> taskEmittedTuples = new HashMap<>(); boolean receivedId = false; boolean finished = false; List<Tuple> ackTuples = new ArrayList<>(); @Override public String toString() { return "reportCount: " + reportCount + "\n" + "expectedTupleCount: " + expectedTupleCount + "\n" + "receivedTuples: " + receivedTuples + "\n" + "failed: " + failed + "\n" + taskEmittedTuples.toString(); } } public static class IdStreamSpec implements Serializable { GlobalStreamId _id; public GlobalStreamId getGlobalStreamId() { return _id; } public static IdStreamSpec makeDetectSpec(String component, String stream) { return new IdStreamSpec(component, stream); } protected IdStreamSpec(String component, String stream) { _id = new GlobalStreamId(component, stream); } } public CoordinatedBolt(IRichBolt delegate) { this(delegate, null, null); } public CoordinatedBolt(IRichBolt delegate, String sourceComponent, SourceArgs sourceArgs, IdStreamSpec idStreamSpec) { this(delegate, singleSourceArgs(sourceComponent, sourceArgs), idStreamSpec); } public CoordinatedBolt(IRichBolt delegate, Map<String, SourceArgs> sourceArgs, IdStreamSpec idStreamSpec) { _sourceArgs = sourceArgs; if (_sourceArgs == null) _sourceArgs = new HashMap<>(); _delegate = delegate; _idStreamSpec = idStreamSpec; } public void prepare(Map config, TopologyContext context, OutputCollector collector) { TimeCacheMap.ExpiredCallback<Object, TrackingInfo> callback = null; if (_delegate instanceof TimeoutCallback) { callback = new TimeoutItems(); } _tracked = new TimeCacheMap<>(context.maxTopologyMessageTimeout(), callback); _collector = collector; _delegate.prepare(config, context, new OutputCollector(new CoordinatedOutputCollector(collector))); for (String component : Utils.get(context.getThisTargets(), Constants.COORDINATED_STREAM_ID, new HashMap<String, Grouping>()).keySet()) { for (Integer task : context.getComponentTasks(component)) { _countOutTasks.add(task); } } if (!_sourceArgs.isEmpty()) { _numSourceReports = 0; for (Entry<String, SourceArgs> entry : _sourceArgs.entrySet()) { if (entry.getValue().singleCount) { _numSourceReports += 1; } else { _numSourceReports += context.getComponentTasks(entry.getKey()).size(); } } } } private boolean checkFinishId(Tuple tup, TupleType type) { _collector.flush(); Object id = tup.getValue(0); boolean failed = false; synchronized (_tracked) { TrackingInfo track = _tracked.get(id); try { if (track != null) { boolean delayed = false; if (_idStreamSpec == null && type == TupleType.COORD || _idStreamSpec != null && type == TupleType.ID) { track.ackTuples.add(tup); delayed = true; } if (track.failed) { failed = true; for (Tuple t : track.ackTuples) { _collector.fail(t); } _tracked.remove(id); } else if (track.receivedId && (_sourceArgs.isEmpty() || track.reportCount == _numSourceReports && track.expectedTupleCount == track.receivedTuples)) { if (_delegate instanceof FinishedCallback) { ((FinishedCallback) _delegate).finishedId(id); } if (!(_sourceArgs.isEmpty() || type != TupleType.REGULAR)) { throw new IllegalStateException("Coordination condition met with a non-coordinating tuple. Should be impossible"); } Iterator<Integer> outTasks = _countOutTasks.iterator(); _collector.flush(); while (outTasks.hasNext()) { int task = outTasks.next(); int numTuples = get(track.taskEmittedTuples, task, 0); _collector.emitDirect(task, Constants.COORDINATED_STREAM_ID, tup, new Values(id, numTuples)); } for (Tuple t : track.ackTuples) { _collector.ack(t); } track.finished = true; _tracked.remove(id); } if (!delayed && type != TupleType.REGULAR) { if (track.failed) { _collector.fail(tup); } else { _collector.ack(tup); } } } else { if (type != TupleType.REGULAR) _collector.fail(tup); } } catch (FailedException e) { LOG.error("Failed to finish batch", e); for (Tuple t : track.ackTuples) { _collector.fail(t); } _tracked.remove(id); failed = true; } finally { _collector.flush(); } } return failed; } public void execute(Tuple tuple) { Object id = tuple.getValue(0); TrackingInfo track; TupleType type = getTupleType(tuple); synchronized (_tracked) { track = _tracked.get(id); if (track == null) { track = new TrackingInfo(); if (_idStreamSpec == null) track.receivedId = true; _tracked.put(id, track); } } if (type == TupleType.ID) { synchronized (_tracked) { track.receivedId = true; } checkFinishId(tuple, type); } else if (type == TupleType.COORD) { int count = (Integer) tuple.getValue(1); synchronized (_tracked) { track.reportCount++; track.expectedTupleCount += count; } checkFinishId(tuple, type); } else { synchronized (_tracked) { _delegate.execute(tuple); } } } public void cleanup() { _delegate.cleanup(); _tracked.cleanup(); } public void declareOutputFields(OutputFieldsDeclarer declarer) { _delegate.declareOutputFields(declarer); declarer.declareStream(Constants.COORDINATED_STREAM_ID, true, new Fields("id", "count")); } @Override public Map<String, Object> getComponentConfiguration() { return _delegate.getComponentConfiguration(); } private static Map<String, SourceArgs> singleSourceArgs(String sourceComponent, SourceArgs sourceArgs) { Map<String, SourceArgs> ret = new HashMap<>(); ret.put(sourceComponent, sourceArgs); return ret; } private class TimeoutItems implements TimeCacheMap.ExpiredCallback<Object, TrackingInfo> { @Override public void expire(Object id, TrackingInfo val) { synchronized (_tracked) { // the combination of the lock and the finished flag ensure that // an id is never timed out if it has been finished val.failed = true; if (!val.finished) { ((TimeoutCallback) _delegate).timeoutId(id); } } } } private TupleType getTupleType(Tuple tuple) { if (_idStreamSpec != null && tuple.getSourceGlobalStreamid().equals(_idStreamSpec._id)) { return TupleType.ID; } else if (!_sourceArgs.isEmpty() && tuple.getSourceStreamId().equals(Constants.COORDINATED_STREAM_ID)) { return TupleType.COORD; } else { return TupleType.REGULAR; } } enum TupleType { REGULAR, ID, COORD } }