/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.utils;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.operators.Keys;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.Utils;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.functions.SampleInCoordinator;
import org.apache.flink.api.java.functions.SampleInPartition;
import org.apache.flink.api.java.functions.SampleWithFraction;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.summarize.aggregation.SummaryAggregatorFactory;
import org.apache.flink.api.java.summarize.aggregation.TupleSummaryAggregator;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
 * Utility methods for {@link DataSet}s: counting elements per partition, zipping elements with a
 * consecutive index or a unique identifier, sampling, range partitioning, summarizing data sets
 * of tuples, and computing checksums.
 */
@PublicEvolving
public final class DataSetUtils {

	/**
	 * Method that goes over all the elements in each partition in order to retrieve
	 * the total number of elements.
	 *
	 * @param input the DataSet received as input
	 * @return a data set containing {@code (subtask index, number of elements)} tuples, one per
	 *         processed partition.
	 */
	public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
			@Override
			public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
				long counter = 0;
				for (T ignored : values) {
					counter++;
				}
				out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
			}
		});
	}

	/**
	 * Method that assigns a unique {@link Long} value to all elements in the input data set. The generated values are
	 * consecutive.
	 *
	 * <p>The element count of every partition is computed first (see
	 * {@link #countElementsPerPartition(DataSet)}) and broadcast to the zipping operator. Each subtask
	 * starts numbering at the sum of the counts of all subtasks with a smaller index.
	 *
	 * @param input the input data set
	 * @return a data set of tuple 2 consisting of consecutive ids and initial values.
	 */
	public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {
		DataSet<Tuple2<Integer, Long>> elementCount = countElementsPerPartition(input);

		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
			long start = 0;

			@Override
			public void open(Configuration parameters) throws Exception {
				super.open(parameters);

				// open() can run more than once on the same instance (for functions inside an
				// iteration it is invoked at the beginning of every superstep), so the offset
				// must be recomputed from scratch instead of accumulated onto a stale value.
				start = 0;

				List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariableWithInitializer(
					"counts",
					new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
						@Override
						public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
							// sort the list by task id to calculate the correct offset
							List<Tuple2<Integer, Long>> sortedData = new ArrayList<>();
							for (Tuple2<Integer, Long> datum : data) {
								sortedData.add(datum);
							}
							Collections.sort(sortedData, new Comparator<Tuple2<Integer, Long>>() {
								@Override
								public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
									return o1.f0.compareTo(o2.f0);
								}
							});
							return sortedData;
						}
					});

				// the offset of this partition is the total element count of all preceding partitions
				for (int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
					start += offsets.get(i).f1;
				}
			}

			@Override
			public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
				for (T value : values) {
					out.collect(new Tuple2<>(start++, value));
				}
			}
		}).withBroadcastSet(elementCount, "counts");
	}

	/**
	 * Method that assigns a unique {@link Long} value to all elements in the input data set in the following way:
	 * <ul>
	 *  <li> a map function is applied to the input data set
	 *  <li> each map task holds a counter c which is increased for each record
	 *  <li> c is shifted left by n bits, where n is the number of bits needed to represent the
	 *       largest task index (number of parallel tasks - 1)
	 *  <li> to create a unique ID among all tasks, the task id is added to the shifted counter
	 *  <li> for each record, the resulting value is collected
	 * </ul>
	 *
	 * <p>The job fails with an exception ("Exceeded Long value range while generating labels") if a
	 * label would no longer fit into the positive {@code long} range.
	 *
	 * @param input the input data set
	 * @return a data set of tuple 2 consisting of ids and initial values.
	 */
	public static <T> DataSet<Tuple2<Long, T>> zipWithUniqueId (DataSet <T> input) {
		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
			final long maxBitSize = getBitSize(Long.MAX_VALUE);
			long shifter = 0;
			long start = 0;
			long taskId = 0;

			@Override
			public void open(Configuration parameters) throws Exception {
				super.open(parameters);
				shifter = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
				taskId = getRuntimeContext().getIndexOfThisSubtask();
				// open() can run more than once on the same instance (e.g. once per iteration
				// superstep), so the per-partition counter must restart at zero each time.
				start = 0;
			}

			@Override
			public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
				for (T value : values) {
					// the label occupies getBitSize(start) + shifter bits; reject it before it can
					// overflow into the sign bit
					if (getBitSize(start) + shifter >= maxBitSize) {
						throw new Exception("Exceeded Long value range while generating labels");
					}
					out.collect(new Tuple2<>((start << shifter) + taskId, value));
					start++;
				}
			}
		});
	}

	// --------------------------------------------------------------------------------------------
	//  Sample
	// --------------------------------------------------------------------------------------------

	/**
	 * Generate a sample of DataSet by the probability fraction of each element.
	 *
	 * @param input           The DataSet to sample from.
	 * @param withReplacement Whether element can be selected more than once.
	 * @param fraction        Probability that each element is chosen, should be [0,1] without replacement,
	 *                        and [0, ∞) with replacement. While fraction is larger than 1, the elements are
	 *                        expected to be selected multiple times into sample on average.
	 * @return The sampled DataSet
	 */
	public static <T> MapPartitionOperator<T, T> sample(
		DataSet <T> input,
		final boolean withReplacement,
		final double fraction) {

		return sample(input, withReplacement, fraction, Utils.RNG.nextLong());
	}

	/**
	 * Generate a sample of DataSet by the probability fraction of each element.
	 *
	 * @param input           The DataSet to sample from.
	 * @param withReplacement Whether element can be selected more than once.
	 * @param fraction        Probability that each element is chosen, should be [0,1] without replacement,
	 *                        and [0, ∞) with replacement. While fraction is larger than 1, the elements are
	 *                        expected to be selected multiple times into sample on average.
	 * @param seed            random number generator seed.
	 * @return The sampled DataSet
	 */
	public static <T> MapPartitionOperator<T, T> sample(
		DataSet <T> input,
		final boolean withReplacement,
		final double fraction,
		final long seed) {

		return input.mapPartition(new SampleWithFraction<T>(withReplacement, fraction, seed));
	}

	/**
	 * Generate a sample of DataSet which contains fixed size elements.
	 * <p>
	 * <strong>NOTE:</strong> Sample with fixed size is not as efficient as sample with fraction, use sample with
	 * fraction unless you need exact precision.
	 * </p>
	 *
	 * @param input           The DataSet to sample from.
	 * @param withReplacement Whether element can be selected more than once.
	 * @param numSamples      The expected sample size.
	 * @return The sampled DataSet
	 */
	public static <T> DataSet<T> sampleWithSize(
		DataSet <T> input,
		final boolean withReplacement,
		final int numSamples) {

		return sampleWithSize(input, withReplacement, numSamples, Utils.RNG.nextLong());
	}

	/**
	 * Generate a sample of DataSet which contains fixed size elements.
	 * <p>
	 * <strong>NOTE:</strong> Sample with fixed size is not as efficient as sample with fraction, use sample with
	 * fraction unless you need exact precision.
	 * </p>
	 *
	 * <p>Each partition is first sampled locally ({@link SampleInPartition}); the partial samples
	 * are then merged into the final sample by a single {@link SampleInCoordinator}.
	 *
	 * @param input           The DataSet to sample from.
	 * @param withReplacement Whether element can be selected more than once.
	 * @param numSamples      The expected sample size.
	 * @param seed            Random number generator seed.
	 * @return The sampled DataSet
	 */
	public static <T> DataSet<T> sampleWithSize(
		DataSet <T> input,
		final boolean withReplacement,
		final int numSamples,
		final long seed) {

		SampleInPartition<T> sampleInPartition = new SampleInPartition<>(withReplacement, numSamples, seed);
		SampleInCoordinator<T> sampleInCoordinator = new SampleInCoordinator<>(withReplacement, numSamples, seed);

		// There is no previous group, so the parallelism of GroupReduceOperator is always 1.
		String callLocation = Utils.getCallLocationName();

		// The partition sampler's output is passed straight into the coordinator so that the
		// intermediate element type is inferred instead of going through a raw operator type.
		return new GroupReduceOperator<>(
			input.mapPartition(sampleInPartition), input.getType(), sampleInCoordinator, callLocation);
	}

	// --------------------------------------------------------------------------------------------
	//  Partition
	// --------------------------------------------------------------------------------------------

	/**
	 * Range-partitions a DataSet on the specified tuple field positions.
	 *
	 * @param input        The DataSet to partition.
	 * @param distribution The data distribution used to determine the range boundaries.
	 * @param fields       The tuple field positions to partition on.
	 * @return The range-partitioned DataSet.
	 */
	public static <T> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, int... fields) {
		return new PartitionOperator<>(
			input,
			PartitionOperatorBase.PartitionMethod.RANGE,
			new Keys.ExpressionKeys<>(fields, input.getType(), false),
			distribution,
			Utils.getCallLocationName());
	}

	/**
	 * Range-partitions a DataSet on the specified fields.
	 *
	 * @param input        The DataSet to partition.
	 * @param distribution The data distribution used to determine the range boundaries.
	 * @param fields       The field expressions to partition on.
	 * @return The range-partitioned DataSet.
	 */
	public static <T> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, String... fields) {
		return new PartitionOperator<>(
			input,
			PartitionOperatorBase.PartitionMethod.RANGE,
			new Keys.ExpressionKeys<>(fields, input.getType()),
			distribution,
			Utils.getCallLocationName());
	}

	/**
	 * Range-partitions a DataSet using the specified key selector function.
	 *
	 * @param input        The DataSet to partition.
	 * @param distribution The data distribution used to determine the range boundaries.
	 * @param keyExtractor The function extracting the partitioning key from each element.
	 * @return The range-partitioned DataSet.
	 */
	public static <T, K extends Comparable<K>> PartitionOperator<T> partitionByRange(DataSet<T> input, DataDistribution distribution, KeySelector<T, K> keyExtractor) {
		final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, input.getType());
		return new PartitionOperator<>(
			input,
			PartitionOperatorBase.PartitionMethod.RANGE,
			new Keys.SelectorFunctionKeys<>(input.clean(keyExtractor), input.getType(), keyType),
			distribution,
			Utils.getCallLocationName());
	}

	// --------------------------------------------------------------------------------------------
	//  Summarize
	// --------------------------------------------------------------------------------------------

	/**
	 * Summarize a DataSet of Tuples by collecting single pass statistics for all columns.
	 *
	 * <p>Each partition is aggregated locally, the partial aggregates are combined with a reduce,
	 * and the final result is fetched to the client via {@link DataSet#collect()}.
	 *
	 * <p>Example usage:
	 * <pre>
	 * {@code
	 * DataSet<Tuple3<Double, String, Boolean>> input = // [...]
	 * Tuple3<NumericColumnSummary, StringColumnSummary, BooleanColumnSummary> summary = DataSetUtils.summarize(input);
	 *
	 * summary.f0.getStandardDeviation();
	 * summary.f1.getMaxLength();
	 * }
	 * </pre>
	 *
	 * @param input the DataSet of tuples to summarize
	 * @return the summary as a Tuple the same width as input rows
	 * @throws IllegalArgumentException if the input's type is not a tuple type
	 * @throws Exception if collecting the result fails
	 */
	public static <R extends Tuple, T extends Tuple> R summarize(DataSet<T> input) throws Exception {
		if (!input.getType().isTupleType()) {
			throw new IllegalArgumentException("summarize() is only implemented for DataSet's of Tuples");
		}
		final TupleTypeInfoBase<?> inType = (TupleTypeInfoBase<?>) input.getType();

		DataSet<TupleSummaryAggregator<R>> result = input.mapPartition(new MapPartitionFunction<T, TupleSummaryAggregator<R>>() {
			@Override
			public void mapPartition(Iterable<T> values, Collector<TupleSummaryAggregator<R>> out) throws Exception {
				TupleSummaryAggregator<R> aggregator = SummaryAggregatorFactory.create(inType);
				for (Tuple value : values) {
					aggregator.aggregate(value);
				}
				out.collect(aggregator);
			}
		}).reduce(new ReduceFunction<TupleSummaryAggregator<R>>() {
			@Override
			public TupleSummaryAggregator<R> reduce(TupleSummaryAggregator<R> agg1, TupleSummaryAggregator<R> agg2) throws Exception {
				agg1.combine(agg2);
				return agg1;
			}
		});

		return result.collect().get(0).result();
	}

	// --------------------------------------------------------------------------------------------
	//  Checksum
	// --------------------------------------------------------------------------------------------

	/**
	 * Convenience method to get the count (number of elements) of a DataSet
	 * as well as the checksum (sum over element hashes).
	 *
	 * <p>This attaches an output to {@code input} and calls {@code execute()} on its execution
	 * environment.
	 *
	 * @param input the DataSet to checksum
	 * @return A ChecksumHashCode that represents the count and checksum of elements in the data set.
	 * @throws Exception if the job execution fails
	 * @deprecated replaced with {@code org.apache.flink.graph.asm.dataset.ChecksumHashCode} in Gelly
	 */
	@Deprecated
	public static <T> Utils.ChecksumHashCode checksumHashCode(DataSet<T> input) throws Exception {
		final String id = new AbstractID().toString();

		input.output(new Utils.ChecksumHashCodeHelper<T>(id)).name("ChecksumHashCode");

		JobExecutionResult res = input.getExecutionEnvironment().execute();
		return res.<Utils.ChecksumHashCode> getAccumulatorResult(id);
	}

	// *************************************************************************
	//     UTIL METHODS
	// *************************************************************************

	/**
	 * Returns the number of bits needed to represent {@code value}, i.e. the position of its
	 * highest set bit (0 for {@code value == 0}, 63 for {@link Long#MAX_VALUE}).
	 *
	 * <p>Intended for non-negative values. A negative argument is evaluated on its low 32 bits
	 * only (for example, {@code -1} yields 32).
	 *
	 * @param value the value to measure
	 * @return the bit size of {@code value}
	 */
	public static int getBitSize(long value) {
		if (value > Integer.MAX_VALUE) {
			return 64 - Integer.numberOfLeadingZeros((int) (value >> 32));
		} else {
			return 32 - Integer.numberOfLeadingZeros((int) value);
		}
	}

	/**
	 * Private constructor to prevent instantiation.
	 */
	private DataSetUtils() {
		throw new RuntimeException();
	}
}