/*
* Copyright (C) 2015 SoftIndex LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.datakernel.aggregation;
import com.google.common.base.Function;
import com.google.common.collect.Sets;
import io.datakernel.aggregation.annotation.Key;
import io.datakernel.aggregation.annotation.Measures;
import io.datakernel.aggregation.fieldtype.FieldType;
import io.datakernel.aggregation.measure.Measure;
import io.datakernel.aggregation.util.BiPredicate;
import io.datakernel.aggregation.util.Predicates;
import io.datakernel.codegen.*;
import io.datakernel.serializer.BufferSerializer;
import io.datakernel.serializer.SerializerBuilder;
import io.datakernel.serializer.asm.SerializerGenClass;
import io.datakernel.stream.processor.StreamMap;
import io.datakernel.stream.processor.StreamReducers;
import io.datakernel.util.WithValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
import static com.google.common.base.Functions.forMap;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.propagate;
import static com.google.common.collect.Iterables.concat;
import static com.google.common.collect.Lists.newArrayList;
import static com.google.common.collect.Maps.toMap;
import static com.google.common.collect.Maps.transformValues;
import static com.google.common.collect.Sets.newLinkedHashSet;
import static io.datakernel.codegen.Expressions.*;
/**
 * Static helpers for working with the structure of an aggregation.
 * <p>
 * Provides projections of key/field/measure maps, generation of dynamic classes
 * (key classes, record classes, comparators, mappers, reducers, pre-aggregators and
 * partition predicates), construction of buffer serializers for record classes,
 * and discovery of key and measure members of input classes via annotations.
 */
@SuppressWarnings({"rawtypes", "unchecked"})
public class AggregationUtils {
    private final static Logger logger = LoggerFactory.getLogger(AggregationUtils.class);

    /** Maps a field type to the Java class used to store its values in generated classes. */
    private static final Function<FieldType, Class<?>> INTERNAL_DATA_TYPE = new Function<FieldType, Class<?>>() {
        @Override
        public Class<?> apply(FieldType fieldType) {
            return fieldType.getInternalDataType();
        }
    };

    private AggregationUtils() {
    }

    /**
     * Restricts {@code keyTypes} to the given keys.
     *
     * @throws IllegalArgumentException if some of {@code keys} are not present in {@code keyTypes}
     */
    public static Map<String, FieldType> projectKeys(Map<String, FieldType> keyTypes, List<String> keys) {
        return projectMap(keyTypes, keys);
    }

    /**
     * Restricts {@code fieldTypes} to the given fields.
     *
     * @throws IllegalArgumentException if some of {@code fields} are not present in {@code fieldTypes}
     */
    public static Map<String, FieldType> projectFields(Map<String, FieldType> fieldTypes, List<String> fields) {
        return projectMap(fieldTypes, fields);
    }

    /**
     * Restricts {@code measures} to the given fields.
     *
     * @throws IllegalArgumentException if some of {@code fields} are not present in {@code measures}
     */
    public static Map<String, Measure> projectMeasures(Map<String, Measure> measures, List<String> fields) {
        return projectMap(measures, fields);
    }

    /**
     * Returns a view of {@code measures} in which each measure is replaced by its field type.
     * The result is a live view (Guava {@code transformValues}), not a copy.
     */
    public static Map<String, FieldType> measuresAsFields(Map<String, Measure> measures) {
        return transformValues(measures, new Function<Measure, FieldType>() {
            @Override
            public FieldType apply(Measure input) {
                return input.getFieldType();
            }
        });
    }

    /**
     * Returns a new map with the entries of {@code map} whose keys appear in {@code keys}.
     * Entries keep the iteration order of {@code map}, not the order of {@code keys}.
     *
     * @throws IllegalArgumentException listing the keys that are not present in {@code map}
     */
    private static <K, V> Map<K, V> projectMap(Map<K, V> map, Collection<K> keys) {
        Set<K> keySet = new HashSet<>(keys);
        if (!map.keySet().containsAll(keySet)) {
            // Build the diagnostic only on failure; list unknown keys in the caller's order.
            throw new IllegalArgumentException("Unknown fields: " + Sets.difference(newLinkedHashSet(keys), map.keySet()));
        }
        LinkedHashMap<K, V> result = new LinkedHashMap<>();
        for (Map.Entry<K, V> entry : map.entrySet()) {
            if (keySet.contains(entry.getKey())) {
                result.put(entry.getKey(), entry.getValue());
            }
        }
        return result;
    }

    /** Generates a key class for the given keys of {@code aggregation}. */
    public static Class<?> createKeyClass(Aggregation aggregation, List<String> keys, DefiningClassLoader classLoader) {
        return createKeyClass(projectKeys(aggregation.getKeyTypes(), keys), classLoader);
    }

    /**
     * Generates a class implementing {@link Comparable} with one field per key, typed by
     * {@link FieldType#getInternalDataType()}, and generated {@code compareTo}, {@code equals},
     * {@code hashCode} and {@code toString} methods over those fields (in the map's key order).
     */
    public static Class<?> createKeyClass(Map<String, FieldType> keys, DefiningClassLoader classLoader) {
        List<String> keyList = new ArrayList<>(keys.keySet());
        return ClassBuilder.create(classLoader, Comparable.class)
                .withFields(transformValues(keys, INTERNAL_DATA_TYPE))
                .withMethod("compareTo", compareTo(keyList))
                .withMethod("equals", asEquals(keyList))
                .withMethod("hashCode", hashCodeOfThis(keyList))
                .withMethod("toString", asString(keyList)).build();
    }

    /**
     * Generates a {@link Comparator} for instances of {@code recordClass} whose {@code compare}
     * method is produced by {@code Expressions.compare(recordClass, keys)}
     * (presumably comparing the key fields in order — confirm against the codegen library).
     */
    public static Comparator createKeyComparator(Class<?> recordClass, List<String> keys, DefiningClassLoader classLoader) {
        return ClassBuilder.create(classLoader, Comparator.class)
                .withMethod("compare", compare(recordClass, keys))
                .buildClassAndCreateNewInstance();
    }

    /**
     * Generates a mapper whose {@code apply} creates a new {@code resultClass} instance and copies
     * every field named in {@code keys} and then {@code fields} from its argument
     * (cast to {@code recordClass}) into it.
     */
    public static StreamMap.MapperProjection createMapper(Class<?> recordClass, Class<?> resultClass,
                                                          List<String> keys, List<String> fields,
                                                          DefiningClassLoader classLoader) {
        Expression result = let(constructor(resultClass));
        ExpressionSequence apply = ExpressionSequence.create();
        for (String fieldName : concat(keys, fields)) {
            apply.add(set(
                    field(result, fieldName),
                    field(cast(arg(0), recordClass), fieldName)));
        }
        apply.add(result);
        return ClassBuilder.create(classLoader, StreamMap.MapperProjection.class)
                .withMethod("apply", apply)
                .buildClassAndCreateNewInstance();
    }

    /**
     * Generates a {@link Function} whose {@code apply} creates a new {@code keyClass} instance and
     * copies every field named in {@code keys} from its argument (cast to {@code recordClass}).
     */
    public static Function createKeyFunction(Class<?> recordClass, Class<?> keyClass,
                                             List<String> keys,
                                             DefiningClassLoader classLoader) {
        Expression key = let(constructor(keyClass));
        ExpressionSequence apply = ExpressionSequence.create();
        for (String keyName : keys) {
            apply.add(set(
                    field(key, keyName),
                    field(cast(arg(0), recordClass), keyName)));
        }
        apply.add(key);
        return ClassBuilder.create(classLoader, Function.class)
                .withMethod("apply", apply)
                .buildClassAndCreateNewInstance();
    }

    /**
     * Generates a record class for the given keys and measure fields of {@code aggregation}.
     *
     * @throws IllegalArgumentException if a key or field is not defined in the aggregation
     */
    public static Class<?> createRecordClass(Aggregation aggregation,
                                             List<String> keys, List<String> fields,
                                             DefiningClassLoader classLoader) {
        return createRecordClass(
                projectKeys(aggregation.getKeyTypes(), keys),
                projectFields(aggregation.getMeasureTypes(), fields),
                classLoader);
    }

    /**
     * Generates a class with one field per key and then one per field, each typed by
     * {@link FieldType#getInternalDataType()}, plus a generated {@code toString} over all of them.
     */
    public static Class<?> createRecordClass(Map<String, FieldType> keys, Map<String, FieldType> fields,
                                             DefiningClassLoader classLoader) {
        return ClassBuilder.create(classLoader, Object.class)
                .withFields(transformValues(keys, INTERNAL_DATA_TYPE))
                .withFields(transformValues(fields, INTERNAL_DATA_TYPE))
                .withMethod("toString", asString(newArrayList(concat(keys.keySet(), fields.keySet()))))
                .build();
    }

    /**
     * Builds a serializer for {@code recordClass} covering the given keys and measure fields,
     * using the serializers of their types as declared in {@code aggregation}.
     *
     * @throws IllegalArgumentException if a key or field is not defined in the aggregation
     *                                  (raised by Guava {@code forMap})
     */
    public static <T> BufferSerializer<T> createBufferSerializer(Aggregation aggregation, Class<T> recordClass,
                                                                 List<String> keys, List<String> fields,
                                                                 DefiningClassLoader classLoader) {
        return createBufferSerializer(recordClass,
                toMap(keys, forMap(aggregation.getKeyTypes())),
                toMap(fields, forMap(aggregation.getMeasureTypes())),
                classLoader);
    }

    private static <T> BufferSerializer<T> createBufferSerializer(Class<T> recordClass,
                                                                  Map<String, FieldType> keys, Map<String, FieldType> fields,
                                                                  DefiningClassLoader classLoader) {
        SerializerGenClass serializerGenClass = new SerializerGenClass(recordClass);
        addSerializedFields(serializerGenClass, recordClass, keys);
        addSerializedFields(serializerGenClass, recordClass, fields);
        return SerializerBuilder.create(classLoader).build(serializerGenClass);
    }

    /**
     * Registers each public field of {@code recordClass} named in {@code fieldTypes} with the
     * serializer of its type. A missing field is rethrown unchecked via Guava {@code propagate}.
     */
    private static void addSerializedFields(SerializerGenClass serializerGenClass, Class<?> recordClass,
                                            Map<String, FieldType> fieldTypes) {
        for (Map.Entry<String, FieldType> entry : fieldTypes.entrySet()) {
            try {
                Field recordField = recordClass.getField(entry.getKey());
                serializerGenClass.addField(recordField, entry.getValue().getSerializer(), -1, -1);
            } catch (NoSuchFieldException e) {
                throw propagate(e);
            }
        }
    }

    /**
     * Generates a reducer that merges records of {@code inputClass} into accumulators of
     * {@code outputClass}.
     * <ul>
     * <li>{@code onFirstItem} creates a new accumulator, copies the keys from argument 2 and
     * initialises each measure with {@code Measure.initAccumulatorWithAccumulator}.</li>
     * <li>{@code onNextItem} folds argument 2 into the accumulator passed as argument 3 with
     * {@code Measure.reduce} and returns that accumulator.</li>
     * <li>{@code onComplete} calls {@code onData(arg 2)} on argument 0 (presumably the downstream
     * receiver — confirm against {@code StreamReducers.Reducer}).</li>
     * </ul>
     */
    public static StreamReducers.Reducer aggregationReducer(Aggregation aggregation, Class<?> inputClass, Class<?> outputClass,
                                                            List<String> keys, List<String> fields,
                                                            DefiningClassLoader classLoader) {
        Expression accumulator = let(constructor(outputClass));
        ExpressionSequence onFirstItem = ExpressionSequence.create();
        ExpressionSequence onNextItem = ExpressionSequence.create();
        for (String key : keys) {
            onFirstItem.add(set(
                    field(accumulator, key),
                    field(cast(arg(2), inputClass), key)));
        }
        for (String field : fields) {
            Measure aggregateFunction = aggregation.getMeasure(field);
            onFirstItem.add(aggregateFunction.initAccumulatorWithAccumulator(
                    field(accumulator, field),
                    field(cast(arg(2), inputClass), field)
            ));
            onNextItem.add(aggregateFunction.reduce(
                    field(cast(arg(3), outputClass), field),
                    field(cast(arg(2), inputClass), field)
            ));
        }
        onFirstItem.add(accumulator);
        onNextItem.add(arg(3));
        return ClassBuilder.create(classLoader, StreamReducers.Reducer.class)
                .withMethod("onFirstItem", onFirstItem)
                .withMethod("onNextItem", onNextItem)
                .withMethod("onComplete", call(arg(0), "onData", arg(2)))
                .buildClassAndCreateNewInstance();
    }

    /**
     * Generates an {@link Aggregate} that turns {@code inputClass} records into
     * {@code outputClass} accumulators.
     *
     * @param keyFields     accumulator key name to input member name
     * @param measureFields measure name to input member name; a {@code null} value means the
     *                      measure has no input member, and {@code null} is passed to the
     *                      measure's init/accumulate expressions
     */
    public static Aggregate createPreaggregator(Aggregation aggregation, Class<?> inputClass, Class<?> outputClass,
                                                Map<String, String> keyFields, Map<String, String> measureFields,
                                                DefiningClassLoader classLoader) {
        Expression accumulator = let(constructor(outputClass));
        ExpressionSequence createAccumulator = ExpressionSequence.create();
        ExpressionSequence accumulate = ExpressionSequence.create();
        for (Map.Entry<String, String> keyEntry : keyFields.entrySet()) {
            createAccumulator.add(set(
                    field(accumulator, keyEntry.getKey()),
                    field(cast(arg(0), inputClass), keyEntry.getValue())));
        }
        for (Map.Entry<String, String> measureEntry : measureFields.entrySet()) {
            String measure = measureEntry.getKey();
            String inputField = measureEntry.getValue();
            Measure aggregateFunction = aggregation.getMeasure(measure);
            createAccumulator.add(aggregateFunction.initAccumulatorWithValue(
                    field(accumulator, measure),
                    inputField == null ? null : field(cast(arg(0), inputClass), inputField)));
            accumulate.add(aggregateFunction.accumulate(
                    field(cast(arg(0), outputClass), measure),
                    inputField == null ? null : field(cast(arg(1), inputClass), inputField)));
        }
        createAccumulator.add(accumulator);
        return ClassBuilder.create(classLoader, Aggregate.class)
                .withMethod("createAccumulator", createAccumulator)
                .withMethod("accumulate", accumulate)
                .buildClassAndCreateNewInstance();
    }

    /**
     * Generates a predicate that is true when both arguments (cast to {@code recordClass}) have
     * equal values for every field in {@code partitioningKey}. An empty partitioning key yields
     * a predicate that always returns true.
     */
    public static BiPredicate createPartitionPredicate(Class recordClass, List<String> partitioningKey,
                                                       DefiningClassLoader classLoader) {
        if (partitioningKey.isEmpty())
            return Predicates.alwaysTrue();

        PredicateDefAnd predicate = PredicateDefAnd.create();
        for (String keyComponent : partitioningKey) {
            predicate.add(cmpEq(
                    field(cast(arg(0), recordClass), keyComponent),
                    field(cast(arg(1), recordClass), keyComponent)));
        }
        return ClassBuilder.create(classLoader, BiPredicate.class)
                .withMethod("test", predicate)
                .buildClassAndCreateNewInstance();
    }

    /**
     * Collects the public fields and then the public methods of {@code inputClass} annotated
     * with {@link Key}. Each is mapped from its key name (the annotation value, or the member
     * name when the value is empty) to the member name. A later member with the same key name
     * overwrites an earlier one.
     *
     * @throws IllegalArgumentException if no {@link Key} annotations are found
     */
    public static <T> Map<String, String> scanKeyFields(Class<T> inputClass) {
        Map<String, String> keyFields = new LinkedHashMap<>();
        for (Field field : inputClass.getFields()) {
            Key key = field.getAnnotation(Key.class);
            if (key != null) {
                keyFields.put(nameOrDefault(key.value(), field.getName()), field.getName());
            }
        }
        for (Method method : inputClass.getMethods()) {
            Key key = method.getAnnotation(Key.class);
            if (key != null) {
                keyFields.put(nameOrDefault(key.value(), method.getName()), method.getName());
            }
        }
        checkArgument(!keyFields.isEmpty(), "Missing @Key annotations in %s", inputClass);
        return keyFields;
    }

    /**
     * Collects measure names declared on {@code inputClass}, mapped to the name of the member
     * providing their input value. Sources are scanned in this order, and later entries for the
     * same measure overwrite earlier ones:
     * <ol>
     * <li>{@link Measures} on the class itself: mapped to {@code null} (no input member);</li>
     * <li>{@link Measures} on public fields: mapped to the field name;</li>
     * <li>{@code @Measure} on public fields: mapped to the field name;</li>
     * <li>{@code @Measure} on public methods: mapped to the method name.</li>
     * </ol>
     * For {@code @Measure}, an empty annotation value means the member name is the measure name.
     *
     * @throws IllegalArgumentException if no measure annotations are found
     */
    public static <T> Map<String, String> scanMeasureFields(Class<T> inputClass) {
        Map<String, String> measureFields = new LinkedHashMap<>();

        Measures classMeasures = inputClass.getAnnotation(Measures.class);
        if (classMeasures != null) {
            for (String measure : classMeasures.value()) {
                measureFields.put(measure, null);
            }
        }

        Field[] fields = inputClass.getFields();
        for (Field field : fields) {
            Measures fieldMeasures = field.getAnnotation(Measures.class);
            if (fieldMeasures != null) {
                for (String measure : fieldMeasures.value()) {
                    measureFields.put(measure, field.getName());
                }
            }
        }

        // Kept as a separate pass so that @Measure entries are inserted after all @Measures entries.
        for (Field field : fields) {
            io.datakernel.aggregation.annotation.Measure measure =
                    field.getAnnotation(io.datakernel.aggregation.annotation.Measure.class);
            if (measure != null) {
                measureFields.put(nameOrDefault(measure.value(), field.getName()), field.getName());
            }
        }

        for (Method method : inputClass.getMethods()) {
            io.datakernel.aggregation.annotation.Measure measure =
                    method.getAnnotation(io.datakernel.aggregation.annotation.Measure.class);
            if (measure != null) {
                measureFields.put(nameOrDefault(measure.value(), method.getName()), method.getName());
            }
        }

        checkArgument(!measureFields.isEmpty(), "Missing @Measure(s) annotations in %s", inputClass);
        return measureFields;
    }

    /** Returns {@code value}, or {@code memberName} when the annotation value was left empty. */
    private static String nameOrDefault(String value, String memberName) {
        return "".equals(value) ? memberName : value;
    }
}