/*
* Copyright (C) 2015 SoftIndex LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.datakernel.aggregation;
import com.google.common.base.Function;
import io.datakernel.aggregation.util.AsyncResultsTracker;
import io.datakernel.aggregation.util.AsyncResultsTracker.AsyncResultsTrackerList;
import io.datakernel.aggregation.util.BiPredicate;
import io.datakernel.async.ResultCallback;
import io.datakernel.codegen.DefiningClassLoader;
import io.datakernel.eventloop.Eventloop;
import io.datakernel.stream.AbstractStreamConsumer;
import io.datakernel.stream.StreamDataReceiver;
import io.datakernel.stream.StreamProducer;
import io.datakernel.stream.StreamProducers;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import static com.google.common.collect.Iterables.transform;
/**
 * Stream consumer that groups incoming records by key and pre-aggregates them in memory before
 * handing them to {@link AggregationChunker}.
 * <p>
 * Each record is mapped to a key by {@code keyFunction}. Records sharing a key are folded into a single
 * accumulator through {@link Aggregate}. The in-memory buffer is flushed:
 * <ul>
 * <li>once it holds at least {@code chunkSize} distinct keys,</li>
 * <li>at end of stream, and</li>
 * <li>on demand via {@link #flush()}.</li>
 * </ul>
 * A flush sorts the buffered accumulators by key and streams them into a new {@link AggregationChunker}.
 * Results of all flushes are gathered by an {@link AsyncResultsTrackerList} and reported to the
 * {@code chunksCallback} given to the constructor.
 * <p>
 * Back-pressure: when a flush leaves more than {@code MAX_OUTPUT_STREAMS} operations outstanding in the
 * results tracker, the upstream producer is suspended. It is resumed after a chunker completes and the
 * tracker reports no more than {@code MAX_OUTPUT_STREAMS} operations.
 *
 * @param <T> type of the incoming records
 */
final class AggregationGroupReducer<T> extends AbstractStreamConsumer<T> implements StreamDataReceiver<T> {
	private static final Logger logger = LoggerFactory.getLogger(AggregationGroupReducer.class);
	private static final int MAX_OUTPUT_STREAMS = 1;

	private final AggregationChunkStorage storage;
	private final AggregationMetadataStorage metadataStorage;
	private final Aggregation aggregation;
	private final List<String> keys;
	private final List<String> fields;
	private final BiPredicate<T, T> partitionPredicate;
	private final Class<?> recordClass;
	private final Function<T, Comparable<?>> keyFunction;
	private final Aggregate aggregate;
	private final AsyncResultsTrackerList<AggregationChunk.NewChunk> resultsTracker;
	private final AggregationOperationTracker operationTracker;
	private final DefiningClassLoader classLoader;

	/** Flush threshold (number of distinct keys); adjustable at runtime via {@link #setChunkSize(int)}. */
	private int chunkSize;

	/** Buffer of key to accumulator; drained on every flush. */
	private final Map<Comparable<?>, Object> map = new HashMap<>();

	/**
	 * Creates the reducer and reports its start to {@code operationTracker}.
	 *
	 * @param eventloop          event loop this consumer runs on
	 * @param storage            chunk storage passed to each {@link AggregationChunker}
	 * @param operationTracker   notified on construction and again when all flushes have finished (successfully or not)
	 * @param metadataStorage    metadata storage passed to each {@link AggregationChunker}
	 * @param aggregation        aggregation passed to each {@link AggregationChunker}
	 * @param keys               key names passed to each {@link AggregationChunker}
	 * @param fields             field names passed to each {@link AggregationChunker}
	 * @param recordClass        record class passed to each {@link AggregationChunker}
	 * @param partitionPredicate partition predicate passed to each {@link AggregationChunker}
	 * @param keyFunction        extracts the grouping key from a record
	 * @param aggregate          creates and updates per-key accumulators
	 * @param chunkSize          number of distinct buffered keys that triggers a flush; also passed to the chunker
	 * @param classLoader        class loader passed to each {@link AggregationChunker}
	 * @param chunksCallback     receives the list of produced chunks, or the first failure
	 */
	public AggregationGroupReducer(Eventloop eventloop, AggregationChunkStorage storage,
	                               final AggregationOperationTracker operationTracker,
	                               AggregationMetadataStorage metadataStorage,
	                               Aggregation aggregation, List<String> keys, List<String> fields,
	                               Class<?> recordClass, BiPredicate<T, T> partitionPredicate,
	                               Function<T, Comparable<?>> keyFunction, Aggregate aggregate,
	                               int chunkSize, DefiningClassLoader classLoader,
	                               final ResultCallback<List<AggregationChunk.NewChunk>> chunksCallback) {
		super(eventloop);
		this.storage = storage;
		this.metadataStorage = metadataStorage;
		this.keys = keys;
		this.fields = fields;
		this.partitionPredicate = partitionPredicate;
		this.recordClass = recordClass;
		this.keyFunction = keyFunction;
		this.aggregate = aggregate;
		this.chunkSize = chunkSize;
		this.operationTracker = operationTracker;
		this.aggregation = aggregation;
		this.resultsTracker = AsyncResultsTracker.ofList(new ResultCallback<List<AggregationChunk.NewChunk>>() {
			@Override
			public void onResult(List<AggregationChunk.NewChunk> result) {
				operationTracker.reportCompletion(AggregationGroupReducer.this);
				chunksCallback.setResult(result);
			}

			@Override
			public void onException(Exception e) {
				operationTracker.reportCompletion(AggregationGroupReducer.this);
				chunksCallback.setException(e);
			}
		});
		this.classLoader = classLoader;
		operationTracker.reportStart(this);
	}

	/** This consumer is its own data receiver. */
	@Override
	public StreamDataReceiver<T> getDataReceiver() {
		return this;
	}

	/**
	 * Folds {@code item} into the accumulator for its key, creating a new accumulator if the key is
	 * not buffered yet. Flushes the buffer once it holds at least {@code chunkSize} distinct keys.
	 */
	@Override
	public void onData(T item) {
		Comparable<?> key = keyFunction.apply(item);
		Object accumulator = map.get(key);
		if (accumulator != null) {
			aggregate.accumulate(accumulator, item);
			return;
		}
		map.put(key, aggregate.createAccumulator(item));
		// '>=' rather than '==': setChunkSize() may lower the threshold below the current buffer size.
		// An equality check would then never fire again, and the buffer would grow until end of stream.
		if (map.size() >= chunkSize) {
			doFlush();
		}
	}

	/**
	 * Drains the buffer into a new {@link AggregationChunker}, ordered by key. Does nothing if the
	 * buffer is empty. Suspends upstream when too many flush operations are outstanding.
	 */
	@SuppressWarnings({"unchecked", "rawtypes"})
	private void doFlush() {
		if (map.isEmpty())
			return;

		resultsTracker.startOperation();
		if (resultsTracker.getOperationsCount() > MAX_OUTPUT_STREAMS)
			suspend();

		Iterable<Object> accumulators = drainSortedAccumulators();
		StreamProducer producer = StreamProducers.ofIterable(eventloop, accumulators);
		producer.streamTo(new AggregationChunker<>(eventloop, operationTracker, aggregation, keys, fields, recordClass,
				(BiPredicate) partitionPredicate, storage, metadataStorage, chunkSize, classLoader,
				new ResultCallback<List<AggregationChunk.NewChunk>>() {
					@Override
					protected void onResult(List<AggregationChunk.NewChunk> newChunks) {
						resultsTracker.completeWithResults(newChunks);
						if (resultsTracker.getOperationsCount() <= MAX_OUTPUT_STREAMS)
							resume();
					}

					@Override
					protected void onException(Exception e) {
						logger.error("Streaming to chunker failed", e);
						closeWithError(e);
						resultsTracker.completeWithException(e);
					}
				}));
	}

	/**
	 * Snapshots and clears the buffer.
	 *
	 * @return the buffered accumulators ordered by the natural order of their keys
	 *         (a lazy view over the snapshot)
	 */
	@SuppressWarnings("unchecked")
	private Iterable<Object> drainSortedAccumulators() {
		List<Map.Entry<Comparable<?>, Object>> entries = new ArrayList<>(map.entrySet());
		map.clear();
		// Keys come from keyFunction and are assumed to be mutually comparable.
		Collections.sort(entries, new Comparator<Map.Entry<Comparable<?>, Object>>() {
			@Override
			public int compare(Map.Entry<Comparable<?>, Object> o1, Map.Entry<Comparable<?>, Object> o2) {
				Comparable<Object> key1 = (Comparable<Object>) o1.getKey();
				Comparable<Object> key2 = (Comparable<Object>) o2.getKey();
				return key1.compareTo(key2);
			}
		});
		return transform(entries, new Function<Map.Entry<Comparable<?>, Object>, Object>() {
			@Override
			public Object apply(Map.Entry<Comparable<?>, Object> input) {
				return input.getValue();
			}
		});
	}

	/** Flushes any remaining buffered keys, then calls {@code shutDown()} on the results tracker. */
	@Override
	public void onEndOfStream() {
		doFlush();
		resultsTracker.shutDown();
	}

	/** Propagates a stream error to the results tracker via {@code shutDownWithException}. */
	@Override
	protected void onError(Exception e) {
		resultsTracker.shutDownWithException(e);
	}

	// jmx

	/** Forces a flush of the current buffer, if it is non-empty. */
	public void flush() {
		doFlush();
	}

	/** @return number of distinct keys currently buffered */
	public int getBufferSize() {
		return map.size();
	}

	/**
	 * Changes the flush threshold. Lowering it below the current buffer size triggers a flush on the
	 * next new key. The new value is also used for chunkers created by later flushes.
	 */
	public void setChunkSize(int chunkSize) {
		this.chunkSize = chunkSize;
	}
}