/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.api.java; import org.apache.commons.lang3.StringUtils; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.accumulators.Accumulator; import org.apache.flink.api.common.accumulators.SerializedListAccumulator; import org.apache.flink.api.common.accumulators.SimpleAccumulator; import org.apache.flink.api.common.io.RichOutputFormat; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.typeutils.GenericTypeInfo; import org.apache.flink.configuration.Configuration; import java.io.IOException; import java.lang.reflect.Field; import java.lang.reflect.Modifier; import java.util.Random; import static org.apache.flink.api.java.functions.FunctionAnnotation.SkipCodeAnalysis; /** * Utility class that contains helper methods to work with Java APIs. */ @Internal public final class Utils { public static final Random RNG = new Random(); public static String getCallLocationName() { return getCallLocationName(4); } public static String getCallLocationName(int depth) { StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); if (stackTrace.length < depth) { return "<unknown>"; } StackTraceElement elem = stackTrace[depth]; return String.format("%s(%s:%d)", elem.getMethodName(), elem.getFileName(), elem.getLineNumber()); } // -------------------------------------------------------------------------------------------- /** * Utility sink function that counts elements and writes the count into an accumulator, * from which it can be retrieved by the client. This sink is used by the * {@link DataSet#count()} function. * * @param <T> Type of elements to count. */ @SkipCodeAnalysis public static class CountHelper<T> extends RichOutputFormat<T> { private static final long serialVersionUID = 1L; private final String id; private long counter; public CountHelper(String id) { this.id = id; this.counter = 0L; } @Override public void configure(Configuration parameters) {} @Override public void open(int taskNumber, int numTasks) {} @Override public void writeRecord(T record) { counter++; } @Override public void close() { getRuntimeContext().getLongCounter(id).add(counter); } } /** * Utility sink function that collects elements into an accumulator, * from which it they can be retrieved by the client. This sink is used by the * {@link DataSet#collect()} function. * * @param <T> Type of elements to count. */ @SkipCodeAnalysis public static class CollectHelper<T> extends RichOutputFormat<T> { private static final long serialVersionUID = 1L; private final String id; private final TypeSerializer<T> serializer; private SerializedListAccumulator<T> accumulator; public CollectHelper(String id, TypeSerializer<T> serializer) { this.id = id; this.serializer = serializer; } @Override public void configure(Configuration parameters) {} @Override public void open(int taskNumber, int numTasks) { this.accumulator = new SerializedListAccumulator<>(); } @Override public void writeRecord(T record) throws IOException { accumulator.add(record, serializer); } @Override public void close() { // Important: should only be added in close method to minimize traffic of accumulators getRuntimeContext().addAccumulator(id, accumulator); } } public static class ChecksumHashCode implements SimpleAccumulator<ChecksumHashCode> { private static final long serialVersionUID = 1L; private long count; private long checksum; public ChecksumHashCode() {} public ChecksumHashCode(long count, long checksum) { this.count = count; this.checksum = checksum; } public long getCount() { return count; } public long getChecksum() { return checksum; } @Override public void add(ChecksumHashCode value) { this.count += value.count; this.checksum += value.checksum; } @Override public ChecksumHashCode getLocalValue() { return this; } @Override public void resetLocal() { this.count = 0; this.checksum = 0; } @Override public void merge(Accumulator<ChecksumHashCode, ChecksumHashCode> other) { this.add(other.getLocalValue()); } @Override public ChecksumHashCode clone() { return new ChecksumHashCode(count, checksum); } @Override public boolean equals(Object obj) { if (obj instanceof ChecksumHashCode) { ChecksumHashCode other = (ChecksumHashCode) obj; return this.count == other.count && this.checksum == other.checksum; } else { return false; } } @Override public int hashCode() { return (int) (this.count + this.checksum); } @Override public String toString() { return String.format("ChecksumHashCode 0x%016x, count %d", this.checksum, this.count); } } @SkipCodeAnalysis public static class ChecksumHashCodeHelper<T> extends RichOutputFormat<T> { private static final long serialVersionUID = 1L; private final String id; private long counter; private long checksum; public ChecksumHashCodeHelper(String id) { this.id = id; this.counter = 0L; this.checksum = 0L; } @Override public void configure(Configuration parameters) {} @Override public void open(int taskNumber, int numTasks) {} @Override public void writeRecord(T record) throws IOException { counter++; // convert 32-bit integer to non-negative long checksum += record.hashCode() & 0xffffffffL; } @Override public void close() throws IOException { ChecksumHashCode update = new ChecksumHashCode(counter, checksum); getRuntimeContext().addAccumulator(id, update); } } // -------------------------------------------------------------------------------------------- /** * Debugging utility to understand the hierarchy of serializers created by the Java API. * Tested in GroupReduceITCase.testGroupByGenericType() */ public static <T> String getSerializerTree(TypeInformation<T> ti) { return getSerializerTree(ti, 0); } private static <T> String getSerializerTree(TypeInformation<T> ti, int indent) { String ret = ""; if (ti instanceof CompositeType) { ret += StringUtils.repeat(' ', indent) + ti.getClass().getSimpleName()+"\n"; CompositeType<T> cti = (CompositeType<T>) ti; String[] fieldNames = cti.getFieldNames(); for (int i = 0; i < cti.getArity(); i++) { TypeInformation<?> fieldType = cti.getTypeAt(i); ret += StringUtils.repeat(' ', indent + 2) + fieldNames[i]+":"+getSerializerTree(fieldType, indent); } } else { if (ti instanceof GenericTypeInfo) { ret += StringUtils.repeat(' ', indent) + "GenericTypeInfo ("+ti.getTypeClass().getSimpleName()+")\n"; ret += getGenericTypeTree(ti.getTypeClass(), indent + 4); } else { ret += StringUtils.repeat(' ', indent) + ti.toString()+"\n"; } } return ret; } private static String getGenericTypeTree(Class<?> type, int indent) { String ret = ""; for (Field field : type.getDeclaredFields()) { if (Modifier.isStatic(field.getModifiers()) || Modifier.isTransient(field.getModifiers())) { continue; } ret += StringUtils.repeat(' ', indent) + field.getName() + ":" + field.getType().getName() + (field.getType().isEnum() ? " (is enum)" : "") + "\n"; if (!field.getType().isPrimitive()) { ret += getGenericTypeTree(field.getType(), indent + 4); } } return ret; } /** * Private constructor to prevent instantiation. */ private Utils() { throw new RuntimeException(); } }