/*
* Copyright © 2015 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package co.cask.cdap.etl.batch.spark;

import co.cask.cdap.api.data.batch.Split;
import com.google.common.base.Objects;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.reflect.TypeParameter;
import com.google.common.reflect.TypeToken;
import com.google.gson.Gson;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;

/**
 * Information required to read from or write to a dataset in the Spark program.
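 *
 * <p>Illustrative round trip (the dataset name and arguments are placeholder values; {@code out} and
 * {@code in} are a caller-supplied {@link DataOutput} and {@link DataInput}):
 * <pre>{@code
 * DatasetInfo info = new DatasetInfo("myDataset", ImmutableMap.of("arg", "value"), null);
 * info.serialize(out);
 * DatasetInfo restored = DatasetInfo.deserialize(in);
 * }</pre>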
*/
final class DatasetInfo {
private final String datasetName;
private final Map<String, String> datasetArgs;
private final List<Split> splits;
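
  /**
   * Deserializes a {@link DatasetInfo} previously written by {@link #serialize(DataOutput)}.
   */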
public static DatasetInfo deserialize(DataInput input) throws IOException {
return new DatasetInfo(input.readUTF(),
Serializations.deserializeMap(input, Serializations.createStringObjectReader()),
deserializeSplits(input));
}
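
  /**
   * Creates a new instance. The dataset arguments are copied into an immutable map; the splits may be
   * {@code null}.
   */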
DatasetInfo(String datasetName, Map<String, String> datasetArgs, @Nullable List<Split> splits) {
this.datasetName = datasetName;
this.datasetArgs = ImmutableMap.copyOf(datasetArgs);
this.splits = splits == null ? null : ImmutableList.copyOf(splits);
  }

  public String getDatasetName() {
    return datasetName;
  }

  public Map<String, String> getDatasetArgs() {
    return datasetArgs;
  }

  @Nullable
  public List<Split> getSplits() {
    return splits;
}
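
  /**
   * Writes this object to the given {@link DataOutput} so that it can be restored later via
   * {@link #deserialize(DataInput)}.
   */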
void serialize(DataOutput output) throws IOException {
output.writeUTF(getDatasetName());
Serializations.serializeMap(getDatasetArgs(), Serializations.createStringObjectWriter(), output);
serializeSplits(getSplits(), output);
}
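
  /**
   * Writes the splits as a count followed by the split class name and the Gson-serialized list,
   * or just a count of zero if there are no splits.
   */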
  private static void serializeSplits(@Nullable List<Split> splits, DataOutput output) throws IOException {
    if (splits == null || splits.isEmpty()) {
      output.writeInt(0);
      return;
    }
    output.writeInt(splits.size());
    // A bit hacky since we grab the split class name from the first element; assumes all splits share the same class.
    output.writeUTF(splits.get(0).getClass().getName());
    output.writeUTF(new Gson().toJson(splits));
}
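
  /**
   * Reads back splits written by {@link #serializeSplits(List, DataOutput)}, returning {@code null}
   * if no splits were serialized.
   */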
@Nullable
private static List<Split> deserializeSplits(DataInput input) throws IOException {
int size = input.readInt();
if (size == 0) {
return null;
}
ClassLoader classLoader = Objects.firstNonNull(Thread.currentThread().getContextClassLoader(),
SparkBatchSourceFactory.class.getClassLoader());
try {
Class<?> splitClass = classLoader.loadClass(input.readUTF());
return new Gson().fromJson(input.readUTF(), getListType(splitClass));
} catch (ClassNotFoundException e) {
throw new IOException("Unable to deserialize splits", e);
}
}
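
  /**
   * Returns a {@link Type} representing {@code List<T>} for the given element class, suitable for Gson
   * deserialization.
   */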
private static <T> Type getListType(Class<T> elementType) {
return new TypeToken<List<T>>() { }.where(new TypeParameter<T>() { }, elementType).getType();
}
}