/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.jstorm.transactional.state;
import static org.slf4j.LoggerFactory.getLogger;
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeSet;
import org.slf4j.Logger;
import backtype.storm.task.TopologyContext;
import backtype.storm.utils.Utils;
import com.alibaba.jstorm.client.ConfigExtension;
import com.alibaba.jstorm.transactional.TransactionCommon;
import com.alibaba.jstorm.transactional.state.TransactionState.State;
import com.alibaba.jstorm.transactional.state.task.ITaskStateInitOperator;
import com.alibaba.jstorm.utils.RotatingMap;
public class SnapshotState implements Serializable {
private static final long serialVersionUID = -4997799343186429338L;
private static final Logger LOG = getLogger(SnapshotState.class);
// Topology context handed to the constructor; transient, so it is not serialized.
private transient TopologyContext context;
// Configuration obtained from the context via getStormConf().
private transient Map conf;
// taskId -> componentId, obtained from the context via getTaskToComponent().
private Map<Integer, String> taskToComponentId;
// componentId -> task set
private transient Map<String, Set<Integer>> sourceTasks;
private transient Map<String, Set<Integer>> statefulTasks;
private transient Map<String, Set<Integer>> nonStatefulTasks;
// Task ids of the end bolts whose acks complete a batch.
private transient Set<Integer> endTasks;
// Tracker of the most recent batch that was marked successful.
private BatchStateTracker lastSuccessfulSnapshot;
// batchId -> tracker of batches still collecting commits/acks; entries expire through rotate().
private transient RotatingMap<Long, BatchStateTracker> inprogressSnapshots;
// Batch id recorded by setLastSuccessfulBatch(); isRunning() compares it with the current last successful batch id.
private transient long prevLastSuccessfulBatchId = 0;
// Current lifecycle state: set to ACTIVE by the constructors and to ROLLBACK by rollback().
private State state;
// Topology-level state operator; consulted in getRebuiltState() when no per-component init operator is registered.
private transient ITopologyStateOperator stateOperator;
// Map<ComponentId, TaskUserStateInitOperator>
private transient Map<String, ITaskStateInitOperator> taskStateInitOperators;
// Map<ComponentId, TaskSysStateInitOperator>
private transient Map<String, ITaskStateInitOperator> taskSysStateInitOperators;
public class BatchStateTracker implements Serializable {
private static final long serialVersionUID = -3503873401193374360L;
// Id of the batch followed by this tracker.
private long batchId;
// componentId -> (taskId -> committed state); a null value marks a task that has not committed yet.
private Map<String, Map<Integer, TransactionState>> spouts = new HashMap<String, Map<Integer, TransactionState>>();
private Map<String, Map<Integer, TransactionState>> statefulBolts = new HashMap<String, Map<Integer, TransactionState>>();
// End-bolt taskId -> whether its ack has been received; transient, so it is not serialized.
private transient Map<Integer, Boolean> endBolts = new HashMap<Integer, Boolean>();
// Counters of first-time commits/acks received so far (incremented by the update* methods).
private transient int receivedSpoutCount = 0;
private transient int receivedStatefulBoltCount = 0;
private transient int receivedEndBoltCount = 0;
// Number of tasks expected to commit/ack, computed in the constructor from the supplied task sets.
private transient int expectedSpoutCount;
private transient int expectedStatefulBoltCount;
private transient int expectedEndBoltCount;
/**
 * Creates a tracker for one batch and pre-registers every expected task.
 * Spout and stateful bolt tasks start with a {@code null} state, end bolts
 * start as not acked.
 *
 * @param batchId id of the batch to track
 * @param spouts spout componentId -> task ids expected to commit
 * @param statefulBolts stateful bolt componentId -> task ids expected to commit
 * @param endBolts task ids of the end bolts expected to ack
 */
public BatchStateTracker(long batchId, Map<String, Set<Integer>> spouts, Map<String, Set<Integer>> statefulBolts, Set<Integer> endBolts) {
    this.batchId = batchId;

    // Spouts: one empty slot per task, grouped by component.
    for (Entry<String, Set<Integer>> spoutComponent : spouts.entrySet()) {
        Map<Integer, TransactionState> slots = new HashMap<Integer, TransactionState>();
        for (Integer spoutTask : spoutComponent.getValue()) {
            slots.put(spoutTask, null);
        }
        this.spouts.put(spoutComponent.getKey(), slots);
        expectedSpoutCount += slots.size();
    }

    // Stateful bolts: same layout as the spouts.
    for (Entry<String, Set<Integer>> boltComponent : statefulBolts.entrySet()) {
        Map<Integer, TransactionState> slots = new HashMap<Integer, TransactionState>();
        for (Integer boltTask : boltComponent.getValue()) {
            slots.put(boltTask, null);
        }
        this.statefulBolts.put(boltComponent.getKey(), slots);
        expectedStatefulBoltCount += slots.size();
    }

    // End bolts only ack, so a boolean flag per task is enough.
    for (Integer endTask : endBolts) {
        this.endBolts.put(endTask, Boolean.FALSE);
    }
    expectedEndBoltCount += endBolts.size();
}
/**
 * @return id of the batch followed by this tracker
 */
public long getBatchId() {
    return this.batchId;
}
/**
 * @return the live (not copied) map of spout componentId -> (taskId -> state)
 */
public Map<String, Map<Integer, TransactionState>> getSpouts() {
    return this.spouts;
}
/**
 * @return the live (not copied) map of stateful bolt componentId -> (taskId -> state)
 */
public Map<String, Map<Integer, TransactionState>> getStatefulBolts() {
    return this.statefulBolts;
}
/**
 * Merges the per-component task states into a single taskId -> state map.
 *
 * @param componentStats componentId -> (taskId -> state)
 * @return a new map holding the states of all tasks of all components
 */
private Map<Integer, TransactionState> flatComponentStates(Map<String, Map<Integer, TransactionState>> componentStats) {
    Map<Integer, TransactionState> merged = new HashMap<>();
    for (Entry<String, Map<Integer, TransactionState>> component : componentStats.entrySet()) {
        Map<Integer, TransactionState> taskStates = component.getValue();
        merged.putAll(taskStates);
    }
    return merged;
}
/**
 * @return a new taskId -> state map covering the tasks of all spout components
 */
public Map<Integer, TransactionState> getSpoutStates() {
    Map<Integer, TransactionState> allSpoutStates = flatComponentStates(this.spouts);
    return allSpoutStates;
}
/**
 * @return a new taskId -> state map covering the tasks of all stateful bolt components
 */
public Map<Integer, TransactionState> getStatefulBoltStates() {
    Map<Integer, TransactionState> allBoltStates = flatComponentStates(this.statefulBolts);
    return allBoltStates;
}
/**
 * Looks up the task states of one component, checking spouts first and
 * stateful bolts second.
 *
 * @param componentId component id
 * @return taskId -> state of that component, or null if the component is not tracked
 */
public Map<Integer, TransactionState> getComponentStates(String componentId) {
    if (spouts.containsKey(componentId)) {
        return spouts.get(componentId);
    }
    if (statefulBolts.containsKey(componentId)) {
        return statefulBolts.get(componentId);
    }
    return null;
}
/**
 * Returns the state recorded in THIS tracker for the given task.
 * The lookup uses this tracker's own component states rather than the
 * enclosing object's last successful snapshot, so the result is correct for
 * whichever tracker the method is invoked on.
 *
 * @param taskId task id
 * @return the recorded state, or null if the task's component is not tracked
 *         or the task has no recorded state
 */
public TransactionState getStateByTaskId(int taskId) {
    String componentId = taskToComponentId.get(taskId);
    Map<Integer, TransactionState> componentStates = getComponentStates(componentId);
    // statesInfo() builds a potentially large string; only pay for it when debug logging is on.
    if (LOG.isDebugEnabled()) {
        LOG.debug("taskId={}, componentId={}, states={}, componentStates={}", taskId, componentId, statesInfo(), componentStates);
    }
    return componentStates != null ? componentStates.get(taskId) : null;
}
/**
 * Records the state committed by a spout task for this batch.
 * A commit from a task that this tracker does not expect is logged and
 * ignored (same policy as updateEndBolt), instead of throwing a
 * NullPointerException or inflating the received counter.
 *
 * @param taskId id of the committing spout task
 * @param state committed state
 */
public void updateSpout(int taskId, TransactionState state) {
    Map<Integer, TransactionState> spoutStates = spouts.get(taskToComponentId.get(taskId));
    if (spoutStates == null || !spoutStates.containsKey(taskId)) {
        LOG.warn("Received unexpected task-{} when updating spouts", taskId);
        return;
    }
    // A non-null previous value means this task already committed for this batch.
    if (spoutStates.put(taskId, state) != null) {
        LOG.warn("Duplicated state commit for spout-{}, state={}", taskId, state);
    } else {
        receivedSpoutCount++;
    }
}
/**
 * Records the state committed by a stateful bolt task for this batch.
 * A commit from a task that this tracker does not expect is logged and
 * ignored (same policy as updateEndBolt), instead of throwing a
 * NullPointerException or inflating the received counter.
 *
 * @param taskId id of the committing stateful bolt task
 * @param state committed state
 */
public void updateStatefulBolt(int taskId, TransactionState state) {
    Map<Integer, TransactionState> statefulBoltStates = statefulBolts.get(taskToComponentId.get(taskId));
    if (statefulBoltStates == null || !statefulBoltStates.containsKey(taskId)) {
        LOG.warn("Received unexpected task-{} when updating statefulBolts", taskId);
        return;
    }
    // A non-null previous value means this task already committed for this batch.
    if (statefulBoltStates.put(taskId, state) != null) {
        LOG.warn("Duplicated state commit for statefulBolt-{}, state={}", taskId, state);
    } else {
        receivedStatefulBoltCount++;
    }
}
/**
 * Marks an end bolt task as acked for this batch. Acks from unknown tasks
 * are logged and ignored; repeated acks are logged and not counted again.
 *
 * @param taskId id of the acking end bolt task
 */
public void updateEndBolt(int taskId) {
    boolean expectedTask = endBolts.containsKey(taskId);
    if (expectedTask) {
        boolean ackedBefore = endBolts.put(taskId, true);
        if (ackedBefore) {
            LOG.warn("Duplicated ack for endBolt-{}", taskId);
        } else {
            receivedEndBoltCount++;
        }
    } else {
        LOG.warn("Received unexpected task-{} when updating endBolts", taskId);
    }
}
/**
 * Tells whether this batch is complete.
 * For a regular batch in ACTIVE state, all spouts, stateful bolts and end
 * bolts must have reported and the batch must be the next expected
 * successful one. Otherwise (rollback or init batch) only the end bolt acks
 * are required.
 *
 * @return true if the batch is done
 */
public boolean isFinished() {
    boolean regularBatch = state.equals(State.ACTIVE) && batchId != TransactionCommon.INIT_BATCH_ID;
    if (!regularBatch) {
        // restart is done for rollback or initState
        return receivedEndBoltCount == expectedEndBoltCount;
    }

    boolean everyoneReported = receivedSpoutCount == expectedSpoutCount
            && receivedStatefulBoltCount == expectedStatefulBoltCount
            && receivedEndBoltCount == expectedEndBoltCount;
    if (!everyoneReported) {
        return false;
    }

    long expectedNextBatchId = getNextExpectedSuccessfulBatch();
    if (batchId != expectedNextBatchId) {
        LOG.info("Unexpected forward batch-{} was finished. But the expected is batch-{}", batchId, expectedNextBatchId);
        return false;
    }
    return true;
}
/**
 * @return the batch id followed by the received commit/ack counters
 */
@Override
public String toString() {
    StringBuilder text = new StringBuilder();
    text.append("batchId=").append(batchId);
    text.append(", receivedSpoutMsgCount=").append(receivedSpoutCount);
    text.append(", receivedStatefulBoltMsgCount=").append(receivedStatefulBoltCount);
    text.append(", receivedEndBoltMsgCount=").append(receivedEndBoltCount);
    return text.toString();
}
/**
 * @return a two-line dump of the spout states and the stateful bolt states
 */
public String statesInfo() {
    StringBuilder text = new StringBuilder();
    text.append("[Spout States]: ").append(spouts);
    text.append("\n");
    text.append("[Bolt States]: ").append(statefulBolts);
    return text.toString();
}
/**
 * Collects the tasks whose state slot is still null, i.e. that have not
 * committed a (non-null) state yet.
 *
 * @param componentToStates componentId -> (taskId -> state)
 * @return ids of the tasks without a committed state
 */
private Set<Integer> getNotCommittedTasks(Map<String, Map<Integer, TransactionState>> componentToStates) {
    Set<Integer> pending = new HashSet<Integer>();
    for (Map<Integer, TransactionState> taskStates : componentToStates.values()) {
        for (Integer taskId : taskStates.keySet()) {
            if (taskStates.get(taskId) == null) {
                pending.add(taskId);
            }
        }
    }
    return pending;
}
/**
 * @return a one-line summary listing the spout, stateful bolt and end bolt
 *         tasks that have not reported for this batch
 */
public String hasNotCommittedTasksInfo() {
    Set<Integer> pendingSpouts = getNotCommittedTasks(spouts);
    Set<Integer> pendingBolts = getNotCommittedTasks(statefulBolts);

    Set<Integer> pendingEndBolts = new HashSet<Integer>();
    for (Entry<Integer, Boolean> ack : endBolts.entrySet()) {
        boolean acked = ack.getValue();
        if (acked) {
            continue;
        }
        pendingEndBolts.add(ack.getKey());
    }

    StringBuilder text = new StringBuilder("Not Committed tasks: Spouts=");
    text.append(pendingSpouts);
    text.append(", Bolts=").append(pendingBolts);
    text.append(", EndBolts=").append(pendingEndBolts);
    return text.toString();
}
/**
 * @return true if at least one spout commit, stateful bolt commit or end
 *         bolt ack has been counted for this batch
 */
public boolean isAnyCommittedTasks() {
    boolean nothingReceived = receivedSpoutCount == 0 && receivedStatefulBoltCount == 0 && receivedEndBoltCount == 0;
    return !nothingReceived;
}
}
/**
 * Creates an empty snapshot state that starts in ACTIVE state.
 * No other field is initialized by this constructor.
 */
public SnapshotState() {
    state = State.ACTIVE;
}
/**
 * Creates the snapshot state of a topology: stores the task layout, creates
 * the initial "last successful" tracker for the init batch, and instantiates
 * the per-component user/system state init operators registered in the
 * configuration.
 *
 * @param context topology context
 * @param spouts spout componentId -> task ids
 * @param statefulBolts stateful bolt componentId -> task ids
 * @param nonStatefulBolts non-stateful bolt componentId -> task ids
 * @param endBolts task ids of the end bolts
 * @param stateOperator topology-level state operator
 */
public SnapshotState(TopologyContext context, Map<String, Set<Integer>> spouts, Map<String, Set<Integer>> statefulBolts,
        Map<String, Set<Integer>> nonStatefulBolts, Set<Integer> endBolts, ITopologyStateOperator stateOperator) {
    this.context = context;
    this.conf = context.getStormConf();
    this.taskToComponentId = context.getTaskToComponent();

    this.sourceTasks = spouts;
    this.statefulTasks = statefulBolts;
    this.nonStatefulTasks = nonStatefulBolts;
    this.endTasks = endBolts;

    this.lastSuccessfulSnapshot = new BatchStateTracker(TransactionCommon.INIT_BATCH_ID, spouts, statefulBolts, endBolts);
    this.inprogressSnapshots = new RotatingMap<Long, BatchStateTracker>(3, true);
    this.stateOperator = stateOperator;

    // User state init operators, keyed by component id.
    this.taskStateInitOperators = new HashMap<String, ITaskStateInitOperator>();
    // NOTE(review): the empty log message below is kept exactly as in the original.
    LOG.info("");
    Map<String, String> userRegistry = ConfigExtension.getTransactionUserTaskInitRegisterMap(conf);
    if (userRegistry != null) {
        for (String componentId : userRegistry.keySet()) {
            Object operator = Utils.newInstance(userRegistry.get(componentId));
            taskStateInitOperators.put(componentId, (ITaskStateInitOperator) operator);
        }
    }

    // System state init operators, keyed by component id.
    this.taskSysStateInitOperators = new HashMap<String, ITaskStateInitOperator>();
    Map<String, String> sysRegistry = ConfigExtension.getTransactionSysTaskInitRegisterMap(conf);
    if (sysRegistry != null) {
        for (String componentId : sysRegistry.keySet()) {
            Object operator = Utils.newInstance(sysRegistry.get(componentId));
            taskSysStateInitOperators.put(componentId, (ITaskStateInitOperator) operator);
        }
    }

    this.state = State.ACTIVE;
}
/**
 * Returns the in-progress tracker of a batch, creating and registering it on
 * first use. A non-init batch whose id is not greater than the last
 * successful batch id is treated as expired: it is dropped from the
 * in-progress map and null is returned.
 *
 * @param batchId batch id
 * @return the tracker, or null for an expired batch
 */
private BatchStateTracker getStateTracker(long batchId) {
    boolean expired = batchId != TransactionCommon.INIT_BATCH_ID && batchId <= lastSuccessfulSnapshot.getBatchId();
    if (expired) {
        LOG.warn("Received expired event for batchId-{}", batchId);
        LOG.warn("Current inprogress snapshots: {}", inprogressSnapshots);
        // The in-progress map may still hold an entry for this batch; drop it.
        inprogressSnapshots.remove(batchId);
        return null;
    }

    BatchStateTracker tracker = inprogressSnapshots.get(batchId);
    if (tracker != null) {
        return tracker;
    }
    tracker = new BatchStateTracker(batchId, sourceTasks, statefulTasks, endTasks);
    inprogressSnapshots.put(batchId, tracker);
    return tracker;
}
/**
 * Records a state commit from a spout or stateful bolt task. Commits from
 * tasks of any other component are ignored.
 *
 * @param batchId batch id
 * @param taskId task id
 * @param state state
 * @return true if current batch is done
 */
public boolean commit(long batchId, int taskId, TransactionState state) {
    // Ignore any batch's ack when the current state is NOT active
    boolean initBatch = batchId == TransactionCommon.INIT_BATCH_ID;
    if (!initBatch && !isActive()) {
        return false;
    }

    String componentId = taskToComponentId.get(taskId);
    if (sourceTasks.containsKey(componentId)) {
        return commitSource(batchId, taskId, state);
    }
    if (statefulTasks.containsKey(componentId)) {
        return commitStatefulBolt(batchId, taskId, state);
    }
    return false;
}
/**
 * Records a spout task's commit in the tracker of the given batch.
 *
 * @param batchId batch id
 * @param taskId task id
 * @param state state
 * @return true if current batch is done; false if not done or the batch is expired
 */
private boolean commitSource(long batchId, int taskId, TransactionState state) {
    BatchStateTracker tracker = getStateTracker(batchId);
    if (tracker == null) {
        return false;
    }
    tracker.updateSpout(taskId, state);
    return tracker.isFinished();
}
/**
 * Records a stateful bolt task's commit in the tracker of the given batch.
 *
 * @param batchId batch id
 * @param taskId task id
 * @param state state
 * @return true if current batch is done; false if not done or the batch is expired
 */
private boolean commitStatefulBolt(long batchId, int taskId, TransactionState state) {
    BatchStateTracker tracker = getStateTracker(batchId);
    if (tracker == null) {
        return false;
    }
    tracker.updateStatefulBolt(taskId, state);
    return tracker.isFinished();
}
/**
 * Records an end bolt task's ack for the given batch.
 *
 * @param batchId batch id
 * @param taskId task id
 * @return true if current batch is done
 */
public boolean ackEndBolt(long batchId, int taskId) {
    // Ignore any batch's ack when the current state is NOT active
    boolean initBatch = batchId == TransactionCommon.INIT_BATCH_ID;
    if (!initBatch && !isActive()) {
        return false;
    }

    BatchStateTracker tracker = getStateTracker(batchId);
    if (tracker == null) {
        return false;
    }
    tracker.updateEndBolt(taskId);
    return tracker.isFinished();
}
/**
 * Switches to ROLLBACK state, discards every in-progress batch, and starts
 * tracking the init batch again.
 *
 * @return taskId -> state of the last successful snapshot; when no batch has
 *         succeeded yet, every spout and stateful bolt task maps to null
 */
public Map<Integer, TransactionState> rollback() {
    inprogressSnapshots.clear();
    state = State.ROLLBACK;
    getStateTracker(TransactionCommon.INIT_BATCH_ID);

    Map<Integer, TransactionState> restored = new HashMap<Integer, TransactionState>();
    boolean noSuccessfulBatchYet = lastSuccessfulSnapshot.getBatchId() == TransactionCommon.INIT_BATCH_ID;
    if (noSuccessfulBatchYet) {
        for (Integer spoutTask : getSpoutTasks()) {
            restored.put(spoutTask, null);
        }
        for (Integer boltTask : getStatefulTasks()) {
            restored.put(boltTask, null);
        }
    } else {
        restored.putAll(lastSuccessfulSnapshot.getSpoutStates());
        restored.putAll(lastSuccessfulSnapshot.getStatefulBoltStates());
    }
    return restored;
}
/**
 * Builds the initial state of a task from the given last successful snapshot.
 * The system state and the user state are resolved independently, each in
 * this order: the per-component init operator registered for the task's
 * component, then the topology state operator if it implements
 * ITopologyStateInitOperator, and finally the checkpoint stored in the
 * task's previous state (null when there is none).
 *
 * @param taskId task id
 * @param lastSuccessfulSnapshot tracker holding the previous states
 * @return a new state carrying the snapshot's batch id and the resolved
 *         system and user states
 */
private TransactionState getRebuiltState(int taskId, BatchStateTracker lastSuccessfulSnapshot) {
    Object systemState = null;
    Object userState = null;
    long batchId = lastSuccessfulSnapshot.getBatchId();
    String componentId = taskToComponentId.get(taskId);
    Map<Integer, TransactionState> componentStates = lastSuccessfulSnapshot.getComponentStates(componentId);
    Set<Integer> currTasks = new HashSet<Integer>(context.getComponentTasks(componentId));
    TransactionState state = componentStates != null ? componentStates.get(taskId) : null;
    LOG.debug("States of component={}: {}", componentId, componentStates);
    LOG.debug("taskStateInitOperators: {}", taskStateInitOperators);
    LOG.debug("Prev state of task-{}: {}", taskId, state);
    LOG.debug("currTasks: {}", currTasks);

    // Get task's system state
    ITaskStateInitOperator taskSysInitOperator = taskSysStateInitOperators.get(componentId);
    if (taskSysInitOperator != null) {
        // The operator's result must be kept; discarding it left the system state null.
        systemState = taskSysInitOperator.getTaskInitState(conf, taskId, currTasks, componentStates);
    } else {
        if (stateOperator instanceof ITopologyStateInitOperator) {
            ITopologyStateInitOperator richStateOperator = (ITopologyStateInitOperator) stateOperator;
            if (sourceTasks.containsKey(componentId)) {
                systemState = richStateOperator.getInitSpoutSysState(taskId, currTasks, componentStates);
            } else if (statefulTasks.containsKey(componentId)) {
                systemState = richStateOperator.getInitBoltSysState(taskId, currTasks, componentStates);
            }
        } else {
            systemState = state != null ? state.systemCheckpoint : null;
        }
    }

    // Get task's user state
    ITaskStateInitOperator taskInitOperator = taskStateInitOperators.get(componentId);
    if (taskInitOperator != null) {
        userState = taskInitOperator.getTaskInitState(conf, taskId, currTasks, componentStates);
    } else {
        if (stateOperator instanceof ITopologyStateInitOperator) {
            ITopologyStateInitOperator richStateOperator = (ITopologyStateInitOperator) stateOperator;
            if (sourceTasks.containsKey(componentId)) {
                userState = richStateOperator.getInitSpoutUserState(taskId, sourceTasks.get(componentId), componentStates);
            } else if (statefulTasks.containsKey(componentId)) {
                userState = richStateOperator.getInitBoltUserState(taskId, statefulTasks.get(componentId), componentStates);
            }
        } else {
            userState = state != null ? state.userCheckpoint : null;
        }
    }

    TransactionState ret = new TransactionState(batchId, systemState, userState);
    LOG.debug("Initial state={}", ret);
    return ret;
}
/**
 * Replaces, in place, the state of every task in the given map with the
 * state rebuilt from the supplied snapshot.
 *
 * @param componentStates taskId -> state of one component; its values are overwritten
 * @param lastSuccessfulSnapshot tracker holding the previous states
 */
private void rebuildComponentStates(Map<Integer, TransactionState> componentStates, BatchStateTracker lastSuccessfulSnapshot) {
    for (Entry<Integer, TransactionState> slot : componentStates.entrySet()) {
        TransactionState rebuilt = getRebuiltState(slot.getKey(), lastSuccessfulSnapshot);
        slot.setValue(rebuilt);
    }
}
/**
 * Returns the state a task should start from.
 *
 * @param taskId task id
 * @return the task's state in the last successful snapshot, or null if the
 *         task does not belong to this snapshot state
 */
public TransactionState getInitState(int taskId) {
    boolean knownTask = containsTask(taskId);
    if (!knownTask) {
        LOG.warn("Received unexpected init state request from task-{}", taskId);
        return null;
    }
    // An earlier variant first tried the latest state committed by an
    // in-progress batch (getLastestCommittedState); it is disabled, so only
    // the last successful snapshot is consulted.
    return getLastSuccessState(taskId);
}
/**
 * Searches the in-progress batches, from the highest batch id downwards,
 * for the first non-null state committed by the given task.
 *
 * @param taskId task id
 * @return the most recent committed state, or null if there is none
 */
private TransactionState getLastestCommittedState(int taskId) {
    String componentId = context.getComponentId(taskId);
    TreeSet<Long> orderedBatchIds = new TreeSet<Long>(inprogressSnapshots.keySet());
    for (Long candidateId : orderedBatchIds.descendingSet()) {
        BatchStateTracker tracker = inprogressSnapshots.get(candidateId);
        Map<Integer, TransactionState> states = tracker.getComponentStates(componentId);
        if (states == null) {
            continue;
        }
        TransactionState committed = states.get(taskId);
        if (committed != null) {
            return committed;
        }
    }
    return null;
}
/**
 * @param taskId task id
 * @return the task's state in the last successful snapshot, or a new state
 *         carrying only that snapshot's batch id when none is recorded
 */
private TransactionState getLastSuccessState(int taskId) {
    TransactionState recorded = lastSuccessfulSnapshot.getStateByTaskId(taskId);
    if (recorded == null) {
        return new TransactionState(lastSuccessfulSnapshot.batchId);
    }
    return recorded;
}
/**
 * Marks a batch as successful and starts tracking the following batch.
 * For the init batch only its in-progress entry is dropped. For a regular
 * batch, its tracker becomes the last successful snapshot and all
 * in-progress batches between the previous successful batch and this one are
 * discarded. If no in-progress tracker exists for the batch, the call is
 * logged and ignored so that the last successful snapshot is never replaced
 * by null.
 *
 * @param batchId id of the batch that succeeded
 */
public void successBatch(long batchId) {
    long nextBatchId = 1;
    if (batchId == TransactionCommon.INIT_BATCH_ID) {
        inprogressSnapshots.remove(batchId);
        if (lastSuccessfulSnapshot != null) {
            nextBatchId = lastSuccessfulSnapshot.batchId + 1;
        }
    } else {
        BatchStateTracker finished = (BatchStateTracker) inprogressSnapshots.remove(batchId);
        if (finished == null) {
            LOG.warn("No inprogress snapshot found for batch-{}, keep last successful batch-{}", batchId, getLastSuccessfulBatchId());
            return;
        }
        // Drop the batches skipped between the previous successful batch and this one.
        for (long id = getLastSuccessfulBatchId() + 1; id < batchId; id++) {
            inprogressSnapshots.remove(id);
        }
        lastSuccessfulSnapshot = finished;
        nextBatchId = batchId + 1;
    }
    // Prepare to track next batch
    if (!inprogressSnapshots.containsKey(nextBatchId)) {
        getStateTracker(nextBatchId);
    }
}
/**
 * @return the id following the last successful batch id
 */
private long getNextExpectedSuccessfulBatch() {
    long lastSuccessfulId = lastSuccessfulSnapshot.getBatchId();
    return lastSuccessfulId + 1;
}
/**
 * Checks whether the next expected batch is already finished.
 *
 * @return the id of the next expected batch if it is in progress and
 *         finished, otherwise -1
 */
public long getPendingSuccessBatch() {
    BatchStateTracker candidate = inprogressSnapshots.get(getNextExpectedSuccessfulBatch());
    boolean ready = candidate != null && candidate.isFinished();
    return ready ? candidate.batchId : -1;
}
/**
 * Collects the task ids of all components into one set.
 *
 * @param componentToTasks componentId -> task ids
 * @return a new set with every task id of every component
 */
private Set<Integer> flatComponentsToTasks(Map<String, Set<Integer>> componentToTasks) {
    Set<Integer> allTasks = new HashSet<Integer>();
    for (Map.Entry<String, Set<Integer>> component : componentToTasks.entrySet()) {
        allTasks.addAll(component.getValue());
    }
    return allTasks;
}
/**
 * @param taskId task id
 * @return true if the task's component is a known spout, stateful bolt or
 *         non-stateful bolt component
 */
private boolean containsTask(int taskId) {
    String componentId = taskToComponentId.get(taskId);
    return sourceTasks.containsKey(componentId)
            || statefulTasks.containsKey(componentId)
            || nonStatefulTasks.containsKey(componentId);
}
/**
 * @return a new set with the task ids of all spout components
 */
public Set<Integer> getSpoutTasks() {
    Set<Integer> spoutTaskIds = flatComponentsToTasks(sourceTasks);
    return spoutTaskIds;
}
/**
 * @return a new set with the task ids of all stateful bolt components
 */
public Set<Integer> getStatefulTasks() {
    Set<Integer> statefulTaskIds = flatComponentsToTasks(statefulTasks);
    return statefulTaskIds;
}
/**
 * @return a new set with the task ids of all non-stateful bolt components
 */
public Set<Integer> getNonStatefulTasks() {
    Set<Integer> nonStatefulTaskIds = flatComponentsToTasks(nonStatefulTasks);
    return nonStatefulTaskIds;
}
/**
 * Rotates the in-progress batches and, if any batch expired by doing so,
 * logs the expired batches and triggers a rollback.
 *
 * @return Map[taskId, LastSuccessfulTransactionState] produced by the
 *         rollback, or an empty map when nothing expired
 */
public Map<Integer, TransactionState> expiredCheck() {
    Map<Long, BatchStateTracker> expired = inprogressSnapshots.rotate();
    if (expired.size() <= 0) {
        return new HashMap<Integer, TransactionState>();
    }

    LOG.info("Found expired batch!");
    for (BatchStateTracker expiredTracker : expired.values()) {
        LOG.info("{}, {}", expiredTracker, expiredTracker.hasNotCommittedTasksInfo());
    }
    return rollback();
}
/**
 * @return true if the current state is ACTIVE
 */
public boolean isActive() {
    boolean active = state.equals(State.ACTIVE);
    return active;
}
/**
 * @return true if the current state is ROLLBACK
 */
public boolean isRollback() {
    boolean rollingBack = state.equals(State.ROLLBACK);
    return rollingBack;
}
/**
 * Switches the current state to ACTIVE.
 */
public void setActive() {
    this.state = State.ACTIVE;
}
/**
 * @return the current state
 */
public State getState() {
    return this.state;
}
/**
 * Replaces the current state.
 *
 * @param state new state
 */
public void setState(State state) {
    this.state = state;
}
/**
 * Initializes this snapshot state from a previously stored one.
 * The last successful tracker is rebuilt against the current task layout
 * (in case of any scaling-out/in), the stored lifecycle state is adopted,
 * and tracking starts at the batch following the stored last successful
 * batch. With a null argument, tracking simply starts at batch 1.
 *
 * @param state previously stored snapshot state, may be null
 */
public void initState(SnapshotState state) {
    if (state == null) {
        // Nothing to restore: move straight to the first batch.
        getStateTracker(1);
        return;
    }

    BatchStateTracker previous = state.getLastSuccessfulBatch();
    BatchStateTracker rebuilt = new BatchStateTracker(previous.batchId, sourceTasks, statefulTasks, endTasks);

    // Rebuild spout states from the old last successful states.
    for (Map<Integer, TransactionState> spoutStates : rebuilt.getSpouts().values()) {
        rebuildComponentStates(spoutStates, previous);
    }
    // Rebuild stateful bolt states from the old last successful states.
    for (Map<Integer, TransactionState> boltStates : rebuilt.getStatefulBolts().values()) {
        rebuildComponentStates(boltStates, previous);
    }

    setLastSuccessfulBatch(rebuilt);
    LOG.info("Old statesInfo: {}", previous.statesInfo());
    LOG.info("New statesInfo: {}", rebuilt.statesInfo());
    setState(state.getState());

    // Move to next batch
    long followingBatchId = state.getLastSuccessfulBatch().batchId + 1;
    getStateTracker(followingBatchId);
}
/**
 * @return the tracker of the last successful batch, possibly null
 */
public BatchStateTracker getLastSuccessfulBatch() {
    return this.lastSuccessfulSnapshot;
}
/**
 * Installs the given tracker as the last successful snapshot and records
 * its batch id as the baseline used by isRunning().
 *
 * @param tracker tracker of the last successful batch
 */
public void setLastSuccessfulBatch(BatchStateTracker tracker) {
    lastSuccessfulSnapshot = tracker;
    long baselineBatchId = tracker.getBatchId();
    prevLastSuccessfulBatchId = baselineBatchId;
}
/**
 * @return the id of the last successful batch, or 0 when there is none
 */
public long getLastSuccessfulBatchId() {
    if (lastSuccessfulSnapshot == null) {
        return 0;
    }
    return lastSuccessfulSnapshot.getBatchId();
}
/**
 * Replaces the topology-level state operator.
 *
 * @param op new state operator
 */
public void setStateOperator(ITopologyStateOperator op) {
    stateOperator = op;
}
/**
 * @return true if the last successful batch id differs from the id recorded
 *         by the most recent setLastSuccessfulBatch() call
 */
public boolean isRunning() {
    long currentBatchId = getLastSuccessfulBatchId();
    return currentBatchId != prevLastSuccessfulBatchId;
}
/**
 * Renders the lifecycle state, the task layout, the in-progress batches and
 * the last successful batch id as "key=value" pairs. Safe to call on an
 * instance without a last successful snapshot (the id is reported as 0).
 *
 * @return a human-readable summary of this snapshot state
 */
@Override
public String toString() {
    return "state=" + state + ", sourceTasks=" + sourceTasks + ", statefulTasks=" + statefulTasks + ", nonStatefulTasks=" + nonStatefulTasks
            + ", endTasks=" + endTasks + ", inprogressSnapshots=" + inprogressSnapshots + ", lastSuccessfulBatchId=" + getLastSuccessfulBatchId();
}
}