/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.migration.runtime.checkpoint.savepoint;

import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.core.fs.Path;
import org.apache.flink.migration.runtime.checkpoint.KeyGroupState;
import org.apache.flink.migration.runtime.checkpoint.SubtaskState;
import org.apache.flink.migration.runtime.checkpoint.TaskState;
import org.apache.flink.migration.runtime.state.AbstractStateBackend;
import org.apache.flink.migration.runtime.state.KvStateSnapshot;
import org.apache.flink.migration.runtime.state.StateHandle;
import org.apache.flink.migration.runtime.state.filesystem.AbstractFileStateHandle;
import org.apache.flink.migration.runtime.state.memory.SerializedStateHandle;
import org.apache.flink.migration.state.MigrationKeyGroupStateHandle;
import org.apache.flink.migration.state.MigrationStreamStateHandle;
import org.apache.flink.migration.streaming.runtime.tasks.StreamTaskState;
import org.apache.flink.migration.streaming.runtime.tasks.StreamTaskStateList;
import org.apache.flink.migration.util.SerializedValue;
import org.apache.flink.runtime.checkpoint.savepoint.SavepointSerializer;
import org.apache.flink.runtime.checkpoint.savepoint.SavepointV2;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.ChainedStateHandle;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.MultiStreamStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.filesystem.FileStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
import org.apache.flink.util.IOUtils;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.Preconditions;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * <p>Deserializer for savepoints written in the legacy format (version 0, Flink 1.1.x).
 *
 * <p>In contrast to previous savepoint versions, this serializer makes sure
 * that no default Java serialization is used for serialization. Therefore, we
 * don't rely on any involved Java classes to stay the same.
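 *
 * <p>{@link #deserialize(DataInputStream, ClassLoader)} reads the legacy {@link TaskState}
 * structures and converts them on the fly into {@link SavepointV2} task states;
 * {@link #serialize(SavepointV2, DataOutputStream)} is unsupported because this class only
 * exists for reading old savepoints.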
 */
@SuppressWarnings("deprecation")
public class SavepointV0Serializer implements SavepointSerializer<SavepointV2> {

	public static final SavepointV0Serializer INSTANCE = new SavepointV0Serializer();

	private static final StreamStateHandle SIGNAL_0 = new ByteStreamStateHandle("SIGNAL_0", new byte[]{0});
	private static final StreamStateHandle SIGNAL_1 = new ByteStreamStateHandle("SIGNAL_1", new byte[]{1});

	private static final int MAX_SIZE = 4 * 1024 * 1024;

	private SavepointV0Serializer() {
	}

	@Override
	public void serialize(SavepointV2 savepoint, DataOutputStream dos) throws IOException {
		throw new UnsupportedOperationException("This serializer is read-only and only exists for backwards compatibility");
	}

	@Override
	public SavepointV2 deserialize(DataInputStream dis, ClassLoader userClassLoader) throws IOException {
		long checkpointId = dis.readLong();

		// Task states
		int numTaskStates = dis.readInt();
		List<TaskState> taskStates = new ArrayList<>(numTaskStates);

		for (int i = 0; i < numTaskStates; i++) {
			JobVertexID jobVertexId = new JobVertexID(dis.readLong(), dis.readLong());
			int parallelism = dis.readInt();

			// Add task state
			TaskState taskState = new TaskState(jobVertexId, parallelism);
			taskStates.add(taskState);

			// Sub task states
			int numSubTaskStates = dis.readInt();
			for (int j = 0; j < numSubTaskStates; j++) {
				int subtaskIndex = dis.readInt();

				SerializedValue<StateHandle<?>> serializedValue = readSerializedValueStateHandle(dis);

				long stateSize = dis.readLong();
				long duration = dis.readLong();

				SubtaskState subtaskState = new SubtaskState(
						serializedValue,
						stateSize,
						duration);

				taskState.putState(subtaskIndex, subtaskState);
			}

			// Key group states
			int numKvStates = dis.readInt();
			for (int j = 0; j < numKvStates; j++) {
				int keyGroupIndex = dis.readInt();

				SerializedValue<StateHandle<?>> serializedValue = readSerializedValueStateHandle(dis);

				long stateSize = dis.readLong();
				long duration = dis.readLong();

				KeyGroupState keyGroupState = new KeyGroupState(
						serializedValue,
						stateSize,
						duration);

				taskState.putKvState(keyGroupIndex, keyGroupState);
			}
		}

		try {
			return convertSavepoint(taskStates, userClassLoader, checkpointId);
		} catch (Exception e) {
			throw new IOException(e);
		}
	}

	private static SerializedValue<StateHandle<?>> readSerializedValueStateHandle(DataInputStream dis)
			throws IOException {

		int length = dis.readInt();

		SerializedValue<StateHandle<?>> serializedValue;

		if (length == -1) {
			serializedValue = new SerializedValue<>(null);
		} else {
			byte[] serializedData = new byte[length];
			dis.readFully(serializedData, 0, length);
			serializedValue = SerializedValue.fromBytes(serializedData);
		}

		return serializedValue;
	}

	private SavepointV2 convertSavepoint(
			List<TaskState> taskStates,
			ClassLoader userClassLoader,
			long checkpointID) throws Exception {

		List<org.apache.flink.runtime.checkpoint.TaskState> newTaskStates = new ArrayList<>(taskStates.size());

		for (TaskState taskState : taskStates) {
			newTaskStates.add(convertTaskState(taskState, userClassLoader, checkpointID));
		}

		return new SavepointV2(checkpointID, newTaskStates);
	}

	private org.apache.flink.runtime.checkpoint.TaskState convertTaskState(
			TaskState taskState,
			ClassLoader userClassLoader,
			long checkpointID) throws Exception {

		JobVertexID jobVertexID = taskState.getJobVertexID();
		int parallelism = taskState.getParallelism();
		int chainLength = determineOperatorChainLength(taskState, userClassLoader);

		org.apache.flink.runtime.checkpoint.TaskState newTaskState =
				new org.apache.flink.runtime.checkpoint.TaskState(
						jobVertexID,
						parallelism,
						parallelism,
						chainLength);

		if (chainLength > 0) {

			Map<Integer, SubtaskState> subtaskStates = taskState.getSubtaskStatesById();

			for (Map.Entry<Integer, SubtaskState> subtaskState : subtaskStates.entrySet()) {
				int parallelInstanceIdx = subtaskState.getKey();
				newTaskState.putState(parallelInstanceIdx, convertSubtaskState(
						subtaskState.getValue(),
						parallelInstanceIdx,
						userClassLoader,
						checkpointID));
			}
		}

		return newTaskState;
	}

	private org.apache.flink.runtime.checkpoint.SubtaskState convertSubtaskState(
			SubtaskState subtaskState,
			int parallelInstanceIdx,
			ClassLoader userClassLoader,
			long checkpointID) throws Exception {

		SerializedValue<StateHandle<?>> serializedValue = subtaskState.getState();

		StreamTaskStateList stateList = (StreamTaskStateList) serializedValue.deserializeValue(userClassLoader);
		StreamTaskState[] streamTaskStates = stateList.getState(userClassLoader);

		List<StreamStateHandle> newChainStateList = Arrays.asList(new StreamStateHandle[streamTaskStates.length]);
		KeyGroupsStateHandle newKeyedState = null;

		for (int chainIdx = 0; chainIdx < streamTaskStates.length; ++chainIdx) {

			StreamTaskState streamTaskState = streamTaskStates[chainIdx];
			if (streamTaskState == null) {
				continue;
			}

			newChainStateList.set(chainIdx, convertOperatorAndFunctionState(streamTaskState));
			HashMap<String, KvStateSnapshot<?, ?, ?, ?>> oldKeyedState = streamTaskState.getKvStates();

			if (null != oldKeyedState) {
				Preconditions.checkState(null == newKeyedState, "Found more than one keyed state in chain");
				newKeyedState = convertKeyedBackendState(oldKeyedState, parallelInstanceIdx, checkpointID);
			}
		}

		ChainedStateHandle<StreamStateHandle> newChainedState = new ChainedStateHandle<>(newChainStateList);

		// legacy state carries no managed/raw operator state, so use placeholder chains of matching length
		ChainedStateHandle<OperatorStateHandle> nopChain =
				new ChainedStateHandle<>(Arrays.asList(new OperatorStateHandle[newChainedState.getLength()]));

		return new org.apache.flink.runtime.checkpoint.SubtaskState(
				newChainedState,
				nopChain,
				nopChain,
				newKeyedState,
				null);
	}

	/**
	 * This is public so that we can use it when restoring a legacy snapshot
	 * in {@code AbstractStreamOperatorTestHarness}.
	 */
	public static StreamStateHandle convertOperatorAndFunctionState(StreamTaskState streamTaskState) throws Exception {

		List<StreamStateHandle> mergeStateHandles = new ArrayList<>(4);

		StateHandle<Serializable> functionState = streamTaskState.getFunctionState();
		StateHandle<?> operatorState = streamTaskState.getOperatorState();

		// a leading one-byte signal handle records whether function state is present in the merged stream
		if (null != functionState) {
			mergeStateHandles.add(SIGNAL_1);
			mergeStateHandles.add(convertStateHandle(functionState));
		} else {
			mergeStateHandles.add(SIGNAL_0);
		}

		if (null != operatorState) {
			mergeStateHandles.add(convertStateHandle(operatorState));
		}

		return new MigrationStreamStateHandle(new MultiStreamStateHandle(mergeStateHandles));
	}

	/**
	 * This is public so that we can use it when restoring a legacy snapshot
	 * in {@code AbstractStreamOperatorTestHarness}.
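	 *
	 * <p>The legacy keyed-state map is serialized into an in-memory checkpoint stream and
	 * wrapped as a {@link MigrationKeyGroupStateHandle} covering the single key group that
	 * corresponds to the subtask index.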
	 */
	public static KeyGroupsStateHandle convertKeyedBackendState(
			HashMap<String, KvStateSnapshot<?, ?, ?, ?>> oldKeyedState,
			int parallelInstanceIdx,
			long checkpointID) throws Exception {

		if (null != oldKeyedState) {

			CheckpointStreamFactory checkpointStreamFactory = new MemCheckpointStreamFactory(MAX_SIZE);

			CheckpointStreamFactory.CheckpointStateOutputStream keyedStateOut =
					checkpointStreamFactory.createCheckpointStateOutputStream(checkpointID, 0L);

			try {
				final long offset = keyedStateOut.getPos();

				InstantiationUtil.serializeObject(keyedStateOut, oldKeyedState);
				StreamStateHandle streamStateHandle = keyedStateOut.closeAndGetHandle();
				keyedStateOut = null; // makes IOUtils.closeQuietly(...) ignore this

				if (null != streamStateHandle) {
					KeyGroupRangeOffsets keyGroupRangeOffsets =
							new KeyGroupRangeOffsets(parallelInstanceIdx, parallelInstanceIdx, new long[]{offset});

					return new MigrationKeyGroupStateHandle(keyGroupRangeOffsets, streamStateHandle);
				}
			} finally {
				IOUtils.closeQuietly(keyedStateOut);
			}
		}
		return null;
	}

	private int determineOperatorChainLength(
			TaskState taskState,
			ClassLoader userClassLoader) throws IOException, ClassNotFoundException {

		Collection<SubtaskState> subtaskStates = taskState.getStates();

		if (subtaskStates == null || subtaskStates.isEmpty()) {
			return 0;
		}

		SubtaskState firstSubtaskState = subtaskStates.iterator().next();
		Object toCastTaskStateList = firstSubtaskState.getState().deserializeValue(userClassLoader);

		if (toCastTaskStateList instanceof StreamTaskStateList) {
			StreamTaskStateList taskStateList = (StreamTaskStateList) toCastTaskStateList;
			StreamTaskState[] streamTaskStates = taskStateList.getState(userClassLoader);

			return streamTaskStates.length;
		}
		return 0;
	}

	/**
	 * This is public so that we can use it when restoring a legacy snapshot
	 * in {@code AbstractStreamOperatorTestHarness}.
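	 *
	 * <p>Legacy file-based handles become {@link FileStateHandle}s, serialized and byte-array
	 * handles become {@link ByteStreamStateHandle}s, and {@code DataInputViewHandle}s are
	 * unwrapped and converted recursively.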
	 */
	public static StreamStateHandle convertStateHandle(StateHandle<?> oldStateHandle) throws Exception {
		if (oldStateHandle instanceof AbstractFileStateHandle) {
			Path path = ((AbstractFileStateHandle) oldStateHandle).getFilePath();
			return new FileStateHandle(path, oldStateHandle.getStateSize());
		} else if (oldStateHandle instanceof SerializedStateHandle) {
			byte[] data = ((SerializedStateHandle<?>) oldStateHandle).getSerializedData();
			return new ByteStreamStateHandle(String.valueOf(System.identityHashCode(data)), data);
		} else if (oldStateHandle instanceof org.apache.flink.migration.runtime.state.memory.ByteStreamStateHandle) {
			byte[] data = ((org.apache.flink.migration.runtime.state.memory.ByteStreamStateHandle) oldStateHandle).getData();
			return new ByteStreamStateHandle(String.valueOf(System.identityHashCode(data)), data);
		} else if (oldStateHandle instanceof AbstractStateBackend.DataInputViewHandle) {
			return convertStateHandle(
					((AbstractStateBackend.DataInputViewHandle) oldStateHandle).getStreamStateHandle());
		}
		throw new IllegalArgumentException("Unknown state handle type: " + oldStateHandle);
	}

	@VisibleForTesting
	public void serializeOld(SavepointV0 savepoint, DataOutputStream dos) throws IOException {
		dos.writeLong(savepoint.getCheckpointId());

		Collection<org.apache.flink.migration.runtime.checkpoint.TaskState> taskStates = savepoint.getOldTaskStates();
		dos.writeInt(taskStates.size());

		for (org.apache.flink.migration.runtime.checkpoint.TaskState taskState : taskStates) {
			// Vertex ID
			dos.writeLong(taskState.getJobVertexID().getLowerPart());
			dos.writeLong(taskState.getJobVertexID().getUpperPart());

			// Parallelism
			int parallelism = taskState.getParallelism();
			dos.writeInt(parallelism);

			// Sub task states
			dos.writeInt(taskState.getNumberCollectedStates());

			for (int i = 0; i < parallelism; i++) {
				SubtaskState subtaskState = taskState.getState(i);
				if (subtaskState != null) {
					dos.writeInt(i);

					SerializedValue<?> serializedValue = subtaskState.getState();
					if (serializedValue == null) {
						dos.writeInt(-1); // null
					} else {
						byte[] serialized = serializedValue.getByteArray();
						dos.writeInt(serialized.length);
						dos.write(serialized, 0, serialized.length);
					}

					dos.writeLong(subtaskState.getStateSize());
					dos.writeLong(subtaskState.getDuration());
				}
			}

			// Key group states
			dos.writeInt(taskState.getNumberCollectedKvStates());

			for (int i = 0; i < parallelism; i++) {
				KeyGroupState keyGroupState = taskState.getKvState(i);
				if (keyGroupState != null) {
					// key group index, read back via readInt() in deserialize(...)
					dos.writeInt(i);

					SerializedValue<?> serializedValue = keyGroupState.getKeyGroupState();
					if (serializedValue == null) {
						dos.writeInt(-1); // null
					} else {
						byte[] serialized = serializedValue.getByteArray();
						dos.writeInt(serialized.length);
						dos.write(serialized, 0, serialized.length);
					}

					dos.writeLong(keyGroupState.getStateSize());
					dos.writeLong(keyGroupState.getDuration());
				}
			}
		}
	}
}