/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.checkpoint;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.util.Preconditions;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Current default implementation of {@link OperatorStateRepartitioner} that redistributes state in a
 * round-robin fashion.
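 *
 * <p>Typical usage, for illustration only ({@code previousStates} and {@code newParallelism} are
 * placeholder names, not part of this class):
 * <pre>{@code
 * // previousStates: one OperatorStateHandle per previous parallel subtask
 * List<Collection<OperatorStateHandle>> redistributed =
 *     RoundRobinOperatorStateRepartitioner.INSTANCE.repartitionState(previousStates, newParallelism);
 * // redistributed.get(i) contains the state assigned to new parallel subtask i
 * }</pre>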
*/
public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepartitioner {
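/** Singleton instance; the repartitioner keeps no per-call state, so it can be shared safely. */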
public static final OperatorStateRepartitioner INSTANCE = new RoundRobinOperatorStateRepartitioner();
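/** If true, consumed input handles are released eagerly to reduce peak memory, at the cost that callers can no longer reuse them. */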
private static final boolean OPTIMIZE_MEMORY_USE = false;
@Override
public List<Collection<OperatorStateHandle>> repartitionState(
List<OperatorStateHandle> previousParallelSubtaskStates,
int parallelism) {
Preconditions.checkNotNull(previousParallelSubtaskStates);
Preconditions.checkArgument(parallelism > 0);
// Reorganize: group by (State Name -> StreamStateHandle + StateMetaInfo)
GroupByStateNameResults nameToStateByMode = groupByStateName(previousParallelSubtaskStates);
if (OPTIMIZE_MEMORY_USE) {
previousParallelSubtaskStates.clear(); // free for GC, at the cost that the old handles are no longer available
}
// Assemble result from all merge maps
List<Collection<OperatorStateHandle>> result = new ArrayList<>(parallelism);
// Do the actual repartitioning for all named states
List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList =
repartition(nameToStateByMode, parallelism);
for (int i = 0; i < mergeMapList.size(); ++i) {
result.add(i, new ArrayList<>(mergeMapList.get(i).values()));
}
return result;
}

/**
* Group by the different named states.
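 * The result holds one map per {@link OperatorStateHandle.Mode}; each map associates a state name
 * with all (StreamStateHandle, StateMetaInfo) pairs written for that name by the previous
 * parallel subtasks.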
*/
@SuppressWarnings({"unchecked", "rawtypes"})
private GroupByStateNameResults groupByStateName(
List<OperatorStateHandle> previousParallelSubtaskStates) {
// Reorganize: group by (State Name -> StreamStateHandle + StateMetaInfo)
EnumMap<OperatorStateHandle.Mode,
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> nameToStateByMode =
new EnumMap<>(OperatorStateHandle.Mode.class);
for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
nameToStateByMode.put(
mode,
new HashMap<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>());
}
for (OperatorStateHandle psh : previousParallelSubtaskStates) {
for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e :
psh.getStateNameToPartitionOffsets().entrySet()) {
OperatorStateHandle.StateMetaInfo metaInfo = e.getValue();
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> nameToState =
nameToStateByMode.get(metaInfo.getDistributionMode());
List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> stateLocations =
nameToState.get(e.getKey());
if (stateLocations == null) {
stateLocations = new ArrayList<>();
nameToState.put(e.getKey(), stateLocations);
}
stateLocations.add(new Tuple2<>(psh.getDelegateStateHandle(), metaInfo));
}
}
return new GroupByStateNameResults(nameToStateByMode);
}

/**
* Repartition all named states.
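 *
 * <p>A worked example of the distribution arithmetic: a named state with 7 partitions repartitioned
 * to parallelism 3 yields {@code baseFraction = 7 / 3 = 2} and {@code remainder = 7 % 3 = 1}, so the
 * subtasks receive 3, 2, and 2 partitions. The next named state then starts its round at the first
 * subtask that received no extra partition, keeping the load balanced across named states.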
*/
private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(
GroupByStateNameResults nameToStateByMode,
int parallelism) {
// One merge map per new parallel subtask; state that shares the same StreamStateHandle is merged into a single OperatorStateHandle
List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList = new ArrayList<>(parallelism);
// Initialize
for (int i = 0; i < parallelism; ++i) {
mergeMapList.add(new HashMap<StreamStateHandle, OperatorStateHandle>());
}
// Start with the state handles we distribute round robin by splitting by offsets
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> distributeNameToState =
nameToStateByMode.getByMode(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
int startParallelOp = 0;
// Iterate all named states and repartition one named state at a time per iteration
for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e :
distributeNameToState.entrySet()) {
List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue();
// Determine actual number of partitions for this named state
int totalPartitions = 0;
for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> offsets : current) {
totalPartitions += offsets.f1.getOffsets().length;
}
// Repartition the state across the parallel operator instances
int lstIdx = 0;
int offsetIdx = 0;
int baseFraction = totalPartitions / parallelism;
int remainder = totalPartitions % parallelism;
int newStartParallelOp = startParallelOp;
for (int i = 0; i < parallelism; ++i) {
// Preparation: calculate the actual index considering wrap around
int parallelOpIdx = (i + startParallelOp) % parallelism;
// Now calculate the number of partitions we will assign to the parallel instance in this round ...
int numberOfPartitionsToAssign = baseFraction;
// ... and distribute the leftover (remainder) partitions while we still have some, one at a time
if (remainder > 0) {
++numberOfPartitionsToAssign;
--remainder;
} else if (remainder == 0) {
// We are out of leftover partitions now and begin the next redistribution round with the current
// parallel operator to ensure a fair load balance
newStartParallelOp = parallelOpIdx;
--remainder;
}
// Now assign the partitions for this parallel instance, merging them into the merge map below
while (numberOfPartitionsToAssign > 0) {
Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithOffsets =
current.get(lstIdx);
long[] offsets = handleWithOffsets.f1.getOffsets();
int remaining = offsets.length - offsetIdx;
// Repartition offsets
long[] offs;
if (remaining > numberOfPartitionsToAssign) {
offs = Arrays.copyOfRange(offsets, offsetIdx, offsetIdx + numberOfPartitionsToAssign);
offsetIdx += numberOfPartitionsToAssign;
} else {
if (OPTIMIZE_MEMORY_USE) {
handleWithOffsets.f1 = null; // GC
}
offs = Arrays.copyOfRange(offsets, offsetIdx, offsets.length);
offsetIdx = 0;
++lstIdx;
}
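// If the current handle had more offsets than requested, the subtraction below turns negative
// and ends the while loop; otherwise the handle is exhausted and, if partitions are still
// missing, the next iteration continues with the following handle.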
numberOfPartitionsToAssign -= remaining;
// As a last step we merge partitions that use the same StreamStateHandle in a single
// OperatorStateHandle
Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(parallelOpIdx);
OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithOffsets.f0);
if (operatorStateHandle == null) {
operatorStateHandle = new OperatorStateHandle(
new HashMap<String, OperatorStateHandle.StateMetaInfo>(),
handleWithOffsets.f0);
mergeMap.put(handleWithOffsets.f0, operatorStateHandle);
}
operatorStateHandle.getStateNameToPartitionOffsets().put(
e.getKey(),
new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
}
}
startParallelOp = newStartParallelOp;
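// Drop the consumed list from the grouping map so it becomes eligible for garbage collection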
e.setValue(null);
}
// Now we also add the state handles marked for broadcast to all parallel instances
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> broadcastNameToState =
nameToStateByMode.getByMode(OperatorStateHandle.Mode.BROADCAST);
for (int i = 0; i < parallelism; ++i) {
Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(i);
for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e :
broadcastNameToState.entrySet()) {
List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue();
for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo : current) {
OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0);
if (operatorStateHandle == null) {
operatorStateHandle = new OperatorStateHandle(
new HashMap<String, OperatorStateHandle.StateMetaInfo>(),
handleWithMetaInfo.f0);
mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle);
}
operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1);
}
}
}
return mergeMapList;
}
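
/** Result of grouping the previous subtasks' states by distribution mode and, within a mode, by state name. */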
private static final class GroupByStateNameResults {
private final EnumMap<OperatorStateHandle.Mode,
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode;
public GroupByStateNameResults(
EnumMap<OperatorStateHandle.Mode,
Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode) {
this.byMode = Preconditions.checkNotNull(byMode);
}
public Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> getByMode(
OperatorStateHandle.Mode mode) {
return byMode.get(mode);
}
}
}