/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.flink.translation.wrappers.streaming.state;
import static org.apache.flink.util.Preconditions.checkArgument;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateTag;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.state.StateContext;
import org.apache.beam.sdk.state.StateContexts;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.state.WatermarkHoldState;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.state.KeyGroupsList;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.streaming.api.operators.HeapInternalTimerService;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.Preconditions;
/**
* {@link StateInternals} that uses {@link KeyGroupCheckpointedOperator}
* to checkpoint state.
*
* <p>Note:
* Ignore index of key.
* Just implement BagState.
*
* <p>Reference from {@link HeapInternalTimerService} to the local key-group range.
*/
public class FlinkKeyGroupStateInternals<K> implements StateInternals {
private final Coder<K> keyCoder;
private final KeyGroupsList localKeyGroupRange;
private KeyedStateBackend keyedStateBackend;
private final int localKeyGroupRangeStartIdx;
// stateName -> namespace -> (valueCoder, value)
private final Map<String, Tuple2<Coder<?>, Map<String, ?>>>[] stateTables;
public FlinkKeyGroupStateInternals(
Coder<K> keyCoder,
KeyedStateBackend keyedStateBackend) {
this.keyCoder = keyCoder;
this.keyedStateBackend = keyedStateBackend;
this.localKeyGroupRange = keyedStateBackend.getKeyGroupRange();
// find the starting index of the local key-group range
int startIdx = Integer.MAX_VALUE;
for (Integer keyGroupIdx : localKeyGroupRange) {
startIdx = Math.min(keyGroupIdx, startIdx);
}
this.localKeyGroupRangeStartIdx = startIdx;
stateTables = (Map<String, Tuple2<Coder<?>, Map<String, ?>>>[])
new Map[localKeyGroupRange.getNumberOfKeyGroups()];
for (int i = 0; i < stateTables.length; i++) {
stateTables[i] = new HashMap<>();
}
}
@Override
public K getKey() {
ByteBuffer keyBytes = (ByteBuffer) keyedStateBackend.getCurrentKey();
try {
return CoderUtils.decodeFromByteArray(keyCoder, keyBytes.array());
} catch (CoderException e) {
throw new RuntimeException("Error decoding key.", e);
}
}
@Override
public <T extends State> T state(
final StateNamespace namespace,
StateTag<T> address) {
return state(namespace, address, StateContexts.nullContext());
}
@Override
public <T extends State> T state(
final StateNamespace namespace,
StateTag<T> address,
final StateContext<?> context) {
return address.bind(
new StateTag.StateBinder() {
@Override
public <T> ValueState<T> bindValue(
StateTag<ValueState<T>> address, Coder<T> coder) {
throw new UnsupportedOperationException(
String.format("%s is not supported", ValueState.class.getSimpleName()));
}
@Override
public <T> BagState<T> bindBag(
StateTag<BagState<T>> address, Coder<T> elemCoder) {
return new FlinkKeyGroupBagState<>(address, namespace, elemCoder);
}
@Override
public <T> SetState<T> bindSet(
StateTag<SetState<T>> address, Coder<T> elemCoder) {
throw new UnsupportedOperationException(
String.format("%s is not supported", SetState.class.getSimpleName()));
}
@Override
public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
StateTag<MapState<KeyT, ValueT>> spec,
Coder<KeyT> mapKeyCoder,
Coder<ValueT> mapValueCoder) {
throw new UnsupportedOperationException(
String.format("%s is not supported", MapState.class.getSimpleName()));
}
@Override
public <InputT, AccumT, OutputT>
CombiningState<InputT, AccumT, OutputT> bindCombiningValue(
StateTag<CombiningState<InputT, AccumT, OutputT>> address,
Coder<AccumT> accumCoder,
Combine.CombineFn<InputT, AccumT, OutputT> combineFn) {
throw new UnsupportedOperationException("bindCombiningValue is not supported.");
}
@Override
public <InputT, AccumT, OutputT>
CombiningState<InputT, AccumT, OutputT> bindCombiningValueWithContext(
StateTag<CombiningState<InputT, AccumT, OutputT>> address,
Coder<AccumT> accumCoder,
CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn) {
throw new UnsupportedOperationException(
"bindCombiningValueWithContext is not supported.");
}
@Override
public WatermarkHoldState bindWatermark(
StateTag<WatermarkHoldState> address,
TimestampCombiner timestampCombiner) {
throw new UnsupportedOperationException(
String.format("%s is not supported", CombiningState.class.getSimpleName()));
}
});
}
/**
* Reference from {@link Combine.CombineFn}.
*
* <p>Accumulators are stored in each KeyGroup, call addInput() when a element comes,
* call extractOutput() to produce the desired value when need to read data.
*/
interface KeyGroupCombiner<InputT, AccumT, OutputT> {
/**
* Returns a new, mutable accumulator value, representing the accumulation
* of zero input values.
*/
AccumT createAccumulator();
/**
* Adds the given input value to the given accumulator, returning the
* new accumulator value.
*/
AccumT addInput(AccumT accumulator, InputT input);
/**
* Returns the output value that is the result of all accumulators from KeyGroups
* that are assigned to this operator.
*/
OutputT extractOutput(Iterable<AccumT> accumulators);
}
private abstract class AbstractKeyGroupState<InputT, AccumT, OutputT> {
private String stateName;
private String namespace;
private Coder<AccumT> coder;
private KeyGroupCombiner<InputT, AccumT, OutputT> keyGroupCombiner;
AbstractKeyGroupState(
String stateName,
String namespace,
Coder<AccumT> coder,
KeyGroupCombiner<InputT, AccumT, OutputT> keyGroupCombiner) {
this.stateName = stateName;
this.namespace = namespace;
this.coder = coder;
this.keyGroupCombiner = keyGroupCombiner;
}
/**
* Choose keyGroup of input and addInput to accumulator.
*/
void addInput(InputT input) {
int keyGroupIdx = keyedStateBackend.getCurrentKeyGroupIndex();
int localIdx = getIndexForKeyGroup(keyGroupIdx);
Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable = stateTables[localIdx];
Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName);
if (tuple2 == null) {
tuple2 = new Tuple2<>();
tuple2.f0 = coder;
tuple2.f1 = new HashMap<>();
stateTable.put(stateName, tuple2);
}
Map<String, AccumT> map = (Map<String, AccumT>) tuple2.f1;
AccumT accumulator = map.get(namespace);
if (accumulator == null) {
accumulator = keyGroupCombiner.createAccumulator();
}
accumulator = keyGroupCombiner.addInput(accumulator, input);
map.put(namespace, accumulator);
}
/**
* Get all accumulators and invoke extractOutput().
*/
OutputT extractOutput() {
List<AccumT> accumulators = new ArrayList<>(stateTables.length);
for (Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable : stateTables) {
Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName);
if (tuple2 != null) {
AccumT accumulator = (AccumT) tuple2.f1.get(namespace);
if (accumulator != null) {
accumulators.add(accumulator);
}
}
}
return keyGroupCombiner.extractOutput(accumulators);
}
/**
* Find the first accumulator and return immediately.
*/
boolean isEmptyInternal() {
for (Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable : stateTables) {
Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName);
if (tuple2 != null) {
AccumT accumulator = (AccumT) tuple2.f1.get(namespace);
if (accumulator != null) {
return false;
}
}
}
return true;
}
/**
* Clear accumulators and clean empty map.
*/
void clearInternal() {
for (Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable : stateTables) {
Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName);
if (tuple2 != null) {
tuple2.f1.remove(namespace);
if (tuple2.f1.size() == 0) {
stateTable.remove(stateName);
}
}
}
}
}
private int getIndexForKeyGroup(int keyGroupIdx) {
checkArgument(localKeyGroupRange.contains(keyGroupIdx),
"Key Group " + keyGroupIdx + " does not belong to the local range.");
return keyGroupIdx - this.localKeyGroupRangeStartIdx;
}
private class KeyGroupBagCombiner<T> implements KeyGroupCombiner<T, List<T>, Iterable<T>> {
@Override
public List<T> createAccumulator() {
return new ArrayList<>();
}
@Override
public List<T> addInput(List<T> accumulator, T input) {
accumulator.add(input);
return accumulator;
}
@Override
public Iterable<T> extractOutput(Iterable<List<T>> accumulators) {
List<T> result = new ArrayList<>();
// maybe can return an unmodifiable view.
for (List<T> list : accumulators) {
result.addAll(list);
}
return result;
}
}
private class FlinkKeyGroupBagState<T> extends AbstractKeyGroupState<T, List<T>, Iterable<T>>
implements BagState<T> {
private final StateNamespace namespace;
private final StateTag<BagState<T>> address;
FlinkKeyGroupBagState(
StateTag<BagState<T>> address,
StateNamespace namespace,
Coder<T> coder) {
super(address.getId(), namespace.stringKey(), ListCoder.of(coder),
new KeyGroupBagCombiner<T>());
this.namespace = namespace;
this.address = address;
}
@Override
public void add(T input) {
addInput(input);
}
@Override
public BagState<T> readLater() {
return this;
}
@Override
public Iterable<T> read() {
Iterable<T> result = extractOutput();
return result != null ? result : Collections.<T>emptyList();
}
@Override
public ReadableState<Boolean> isEmpty() {
return new ReadableState<Boolean>() {
@Override
public Boolean read() {
try {
return isEmptyInternal();
} catch (Exception e) {
throw new RuntimeException("Error reading state.", e);
}
}
@Override
public ReadableState<Boolean> readLater() {
return this;
}
};
}
@Override
public void clear() {
clearInternal();
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FlinkKeyGroupBagState<?> that = (FlinkKeyGroupBagState<?>) o;
return namespace.equals(that.namespace) && address.equals(that.address);
}
@Override
public int hashCode() {
int result = namespace.hashCode();
result = 31 * result + address.hashCode();
return result;
}
}
/**
* Snapshots the state {@code (stateName -> (valueCoder && (namespace -> value)))} for a given
* {@code keyGroupIdx}.
*
* @param keyGroupIdx the id of the key-group to be put in the snapshot.
* @param out the stream to write to.
*/
public void snapshotKeyGroupState(int keyGroupIdx, DataOutputStream out) throws Exception {
int localIdx = getIndexForKeyGroup(keyGroupIdx);
Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable = stateTables[localIdx];
Preconditions.checkState(stateTable.size() <= Short.MAX_VALUE,
"Too many States: " + stateTable.size() + ". Currently at most "
+ Short.MAX_VALUE + " states are supported");
out.writeShort(stateTable.size());
for (Map.Entry<String, Tuple2<Coder<?>, Map<String, ?>>> entry : stateTable.entrySet()) {
out.writeUTF(entry.getKey());
Coder coder = entry.getValue().f0;
InstantiationUtil.serializeObject(out, coder);
Map<String, ?> map = entry.getValue().f1;
out.writeInt(map.size());
for (Map.Entry<String, ?> entry1 : map.entrySet()) {
StringUtf8Coder.of().encode(entry1.getKey(), out);
coder.encode(entry1.getValue(), out);
}
}
}
/**
* Restore the state {@code (stateName -> (valueCoder && (namespace -> value)))}
* for a given {@code keyGroupIdx}.
*
* @param keyGroupIdx the id of the key-group to be put in the snapshot.
* @param in the stream to read from.
* @param userCodeClassLoader the class loader that will be used to deserialize
* the valueCoder.
*/
public void restoreKeyGroupState(int keyGroupIdx, DataInputStream in,
ClassLoader userCodeClassLoader) throws Exception {
int localIdx = getIndexForKeyGroup(keyGroupIdx);
Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable = stateTables[localIdx];
int numStates = in.readShort();
for (int i = 0; i < numStates; ++i) {
String stateName = in.readUTF();
Coder coder = InstantiationUtil.deserializeObject(in, userCodeClassLoader);
Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName);
if (tuple2 == null) {
tuple2 = new Tuple2<>();
tuple2.f0 = coder;
tuple2.f1 = new HashMap<>();
stateTable.put(stateName, tuple2);
}
Map<String, Object> map = (Map<String, Object>) tuple2.f1;
int mapSize = in.readInt();
for (int j = 0; j < mapSize; j++) {
String namespace = StringUtf8Coder.of().decode(in);
Object value = coder.decode(in);
map.put(namespace, value);
}
}
}
}