/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.dstream;
import java.util.Iterator;
import java.util.List;
import java.util.ServiceLoader;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Stream;
import io.dstream.SerializableStreamAssets.SerFunction;
import io.dstream.support.CollectionFactory;
import io.dstream.utils.Assert;
import io.dstream.utils.Tuples.Tuple;
import io.dstream.utils.Tuples.Tuple2;
/**
 * Implementation of {@link SerFunction} which joins multiple streams into a single
 * stream of tuples (every value of one stream combined with every value of the next),
 * applying user functionality at check points (see {@link #addCheckPoint(int)}).
 */
class StreamJoinerFunction extends AbstractStreamMergingFunction {
	private static final long serialVersionUID = -3615487628958776468L;

	/**
	 * Provider of the {@link List} used to cache the values of each joining stream.
	 * Discovered once via {@link ServiceLoader} using the system class loader; class
	 * initialization fails if no provider is found.
	 */
	private static final CollectionFactory collectionFactory;

	static {
		Iterator<CollectionFactory> sl = ServiceLoader
				.load(CollectionFactory.class, ClassLoader.getSystemClassLoader()).iterator();
		if (!sl.hasNext()){
			throw new IllegalStateException("Failed to find '" + CollectionFactory.class.getName() + "' provider.");
		}
		collectionFactory = sl.next();
	}

	/**
	 * @param streamPreProcessingFunction function handed unchanged to the
	 *        {@link AbstractStreamMergingFunction} constructor
	 */
	public StreamJoinerFunction(SerFunction<Stream<?>, Stream<?>> streamPreProcessingFunction) {
		super(streamPreProcessingFunction);
	}

	/**
	 * Joins all streams in {@code streamsList}, in list order.
	 *
	 * @param streamsList the streams to join; must not be null and must contain at least
	 *        2 streams. The list itself is not modified.
	 * @return the joined stream
	 */
	@Override
	protected Stream<?> doApply(List<Stream<?>> streamsList) {
		Assert.notNull(streamsList, "'streamsList' must not be null");
		Assert.isTrue(streamsList.size() >= 2, "There must be 2+ streams available to perform join. Was " + streamsList.size());
		return this.join(streamsList);
	}

	/**
	 * Left-folds {@code streams} with {@link #doJoin(Stream, Stream)}.
	 * <p>
	 * After n streams have been joined (the first join yields n == 2), the function of
	 * every check point whose stream count equals n is applied to the joined stream (a
	 * check point with a null function is ignored). A check point stays pending until its
	 * stream count is reached, so it is never skipped just because a later one exists.
	 * <p>
	 * NOTE(review): check points are assumed to be registered in ascending stream-count
	 * order. Any check point whose count is already behind the current count is passed
	 * over — TODO confirm against addCheckPoint callers.
	 *
	 * @param streams the streams to join (size >= 2); not modified
	 * @return the joined stream
	 */
	@SuppressWarnings("unchecked")
	private Stream<?> join(List<Stream<?>> streams) {
		Stream<?> joinedStream = streams.get(0);
		int checkPointIndex = 0;
		for (int i = 1; i < streams.size(); i++) {
			joinedStream = this.doJoin(joinedStream, streams.get(i));
			int joinedStreamsCount = i + 1;
			while (checkPointIndex < this.checkPointProcedures.size()) {
				Tuple2<Integer, Object> checkPoint = this.checkPointProcedures.get(checkPointIndex);
				int checkPointStreamCount = checkPoint._1();
				if (checkPointStreamCount > joinedStreamsCount) {
					// not reached yet; keep it pending for a later iteration
					break;
				}
				checkPointIndex++;
				if (checkPointStreamCount == joinedStreamsCount) {
					SerFunction<Stream<?>, Stream<?>> checkPointFunction =
							(SerFunction<Stream<?>, Stream<?>>) checkPoint._2();
					if (checkPointFunction != null) {
						joinedStream = checkPointFunction.apply(joinedStream);
					}
				}
			}
		}
		return joinedStream;
	}

	/**
	 * Combines every value of {@code joinedStream} with every value of
	 * {@code joiningStream} via {@link #mergeValues(Object, Object)}.
	 * <p>
	 * A {@link Stream} can be consumed only once. The right-hand values are therefore
	 * copied into a list from {@link CollectionFactory#newList()} while being consumed for
	 * the first left value, and replayed from that list for every subsequent left value.
	 * Consumption is tracked with an explicit flag, not by checking whether the cache is
	 * non-empty. As a result, an empty {@code joiningStream} yields an empty result rather
	 * than an attempt to re-consume an already consumed stream.
	 * <p>
	 * NOTE(review): assumes {@code joinedStream} is sequential; the cache is filled
	 * while the first left value is processed and read for later ones.
	 *
	 * @param joinedStream the left-hand stream
	 * @param joiningStream the right-hand stream; consumed lazily, at most once
	 * @return a lazily evaluated stream of merged tuples
	 */
	private Stream<?> doJoin(Stream<?> joinedStream, Stream<?> joiningStream) {
		List<Object> joiningStreamCache = collectionFactory.newList();
		AtomicBoolean joiningStreamConsumed = new AtomicBoolean();
		return joinedStream.flatMap(lVal -> {
			if (joiningStreamConsumed.getAndSet(true)) {
				return joiningStreamCache.stream().map(rVal -> this.mergeValues(lVal, rVal));
			}
			return joiningStream.map(rVal -> {
				joiningStreamCache.add(rVal);
				return this.mergeValues(lVal, rVal);
			});
		});
	}

	/**
	 * Appends {@code right} to the tuple represented by {@code left}.
	 * <p>
	 * If {@code left} is not a {@link MergableTuple} (i.e. not produced by a previous
	 * merge), it is wrapped in a new one. A tuple holding more than one value is cloned
	 * before {@code right} is added. {@link #doJoin(Stream, Stream)} merges the same left
	 * value with every right value, so the shared tuple must not be mutated.
	 *
	 * @param left the left value or previously merged tuple
	 * @param right the value to append
	 * @return a tuple containing the values of {@code left} followed by {@code right}
	 */
	private Tuple mergeValues(Object left, Object right) {
		Tuple current = left instanceof MergableTuple ? (MergableTuple) left : new MergableTuple(left);
		Tuple merged = current.size() > 1 ? current.clone() : current;
		merged.add(right);
		return merged;
	}

	/**
	 * Marker subtype of {@link Tuple} for tuples created by
	 * {@link #mergeValues(Object, Object)}. Subsequent merges append to such tuples
	 * instead of wrapping them as a single value.
	 */
	private static class MergableTuple extends Tuple {
		private static final long serialVersionUID = 6081720376172843799L;

		MergableTuple(Object... values){
			super(values);
		}
	}
}