/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.io;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.io.GenericInputFormat;
import org.apache.flink.api.common.io.NonParallelInput;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.io.GenericInputSplit;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
/**
 * An input format that returns objects from a collection.
 *
 * <p>The backing collection is kept in a transient field. During Java serialization it is written
 * as an element count followed by each element encoded with the supplied {@link TypeSerializer};
 * deserialization rebuilds the elements into an {@link ArrayList}.
 *
 * @param <T> the type of the elements produced by this format
 */
@PublicEvolving
public class CollectionInputFormat<T> extends GenericInputFormat<T> implements NonParallelInput {

    private static final long serialVersionUID = 1L;

    /** Length of the string built by {@link #toString()} after which remaining elements are elided. */
    private static final int MAX_TO_STRING_LEN = 100;

    /** Serializer used to write and read the individual elements in the custom serialization hooks. */
    private TypeSerializer<T> serializer;

    /** Input data as collection. Transient, because it will be serialized in a custom way. */
    private transient Collection<T> dataSet;

    /** Cursor over {@link #dataSet}; created in {@link #open(GenericInputSplit)}. */
    private transient Iterator<T> iterator;

    /**
     * Creates an input format that emits the elements of the given collection.
     *
     * @param dataSet the elements to emit; must not be {@code null}
     * @param serializer the serializer used for the elements when this format is serialized
     * @throws NullPointerException if {@code dataSet} is {@code null}
     */
    public CollectionInputFormat(Collection<T> dataSet, TypeSerializer<T> serializer) {
        if (dataSet == null) {
            throw new NullPointerException("The collection must not be null.");
        }

        this.serializer = serializer;
        this.dataSet = dataSet;
    }

    /**
     * Reports whether all elements of the collection have been returned.
     *
     * @return {@code true} if the iterator has no further element
     */
    @Override
    public boolean reachedEnd() throws IOException {
        return !this.iterator.hasNext();
    }

    /**
     * Opens the split and positions a fresh iterator at the start of the collection.
     *
     * @param split the input split handed to the superclass
     */
    @Override
    public void open(GenericInputSplit split) throws IOException {
        super.open(split);

        this.iterator = this.dataSet.iterator();
    }

    /**
     * Returns the next element of the collection.
     *
     * @param record a reuse object; it is ignored by this format
     * @return the next element from the underlying iterator
     */
    @Override
    public T nextRecord(T record) throws IOException {
        return this.iterator.next();
    }

    // --------------------------------------------------------------------------------------------

    /**
     * Custom serialization hook: writes the default fields, then the element count, then every
     * element encoded with {@link #serializer}.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();

        final int size = dataSet.size();
        out.writeInt(size);

        if (size > 0) {
            DataOutputViewStreamWrapper wrapper = new DataOutputViewStreamWrapper(out);
            for (T element : dataSet) {
                serializer.serialize(element, wrapper);
            }
        }
    }

    /**
     * Custom deserialization hook, the counterpart of {@link #writeObject(ObjectOutputStream)}:
     * reads the default fields, then the element count, then that many elements.
     *
     * @throws IOException if the stored element count is negative or an element cannot be
     *     deserialized
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();

        final int collectionLength = in.readInt();
        // A negative count can only come from a corrupt stream. Reject it here with an IOException
        // instead of letting the ArrayList constructor fail with an IllegalArgumentException.
        if (collectionLength < 0) {
            throw new IOException("Invalid collection length in serialized data: " + collectionLength);
        }

        final List<T> list = new ArrayList<T>(collectionLength);

        if (collectionLength > 0) {
            try {
                DataInputViewStreamWrapper wrapper = new DataInputViewStreamWrapper(in);
                for (int i = 0; i < collectionLength; i++) {
                    list.add(serializer.deserialize(wrapper));
                }
            }
            catch (Throwable t) {
                // Any failure while decoding an element is reported as an IOException with its cause.
                throw new IOException("Error while deserializing element from collection", t);
            }
        }

        dataSet = list;
    }

    // --------------------------------------------------------------------------------------------

    /**
     * Builds a bracketed, comma-separated rendering of the elements. Once the text grows beyond
     * {@link #MAX_TO_STRING_LEN} characters after a separator, the remaining elements are replaced
     * by {@code "..."}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append('[');

        // Index of the final element, computed once instead of on every iteration.
        final int lastIndex = dataSet.size() - 1;

        int num = 0;
        for (T e : dataSet) {
            sb.append(e);
            if (num != lastIndex) {
                sb.append(", ");
                // The length is only checked between elements, so a single element is never cut.
                if (sb.length() > MAX_TO_STRING_LEN) {
                    sb.append("...");
                    break;
                }
            }
            num++;
        }
        sb.append(']');
        return sb.toString();
    }

    // --------------------------------------------------------------------------------------------

    /**
     * Verifies that a collection contains no {@code null} elements and that every element is an
     * instance of the given class.
     *
     * @param elements the collection to check
     * @param viewedAs the class every element is expected to be assignable to
     * @param <X> the expected element type
     * @throws NullPointerException if {@code elements} or {@code viewedAs} is {@code null}
     * @throws IllegalArgumentException if an element is {@code null} or is not an instance of
     *     {@code viewedAs}
     */
    public static <X> void checkCollection(Collection<X> elements, Class<X> viewedAs) {
        if (elements == null || viewedAs == null) {
            throw new NullPointerException("The collection and the element class must not be null.");
        }

        for (X elem : elements) {
            if (elem == null) {
                throw new IllegalArgumentException("The collection must not contain null elements.");
            }

            final Class<?> elemClass = elem.getClass();

            // The second part of the condition is a workaround for the situation that can arise from e.g.
            // "env.fromElements((),(),())"
            // In this situation, UnitTypeInfo.getTypeClass returns void.class (when we are in the Java world), but
            // the actual objects that we will be working with, will be BoxedUnits.
            // Note: TypeInformationGenTest.testUnit tests this condition.
            final boolean boxedUnitViewedAsVoid =
                    "scala.runtime.BoxedUnit".equals(elemClass.getName()) && viewedAs.equals(void.class);

            if (!viewedAs.isAssignableFrom(elemClass) && !boxedUnitViewedAsVoid) {
                throw new IllegalArgumentException("The elements in the collection are not all subclasses of " +
                        viewedAs.getCanonicalName());
            }
        }
    }
}