/*
* Copyright © 2015 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
 */

package co.cask.cdap.etl.batch.spark;

import co.cask.cdap.api.data.batch.Output;
import co.cask.cdap.api.data.batch.OutputFormatProvider;
import co.cask.cdap.api.spark.JavaSparkExecutionContext;
import com.google.common.collect.ImmutableMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.spark.api.java.JavaPairRDD;

import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Handles writes to batch sinks. Maintains a mapping from each sink to its outputs and handles
 * serialization and deserialization of those mappings.
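 *
 * <p>A typical round trip, as a sketch (the streams and the {@code output}, {@code rdd},
 * {@code sec}, {@code keyClass} and {@code valueClass} values are assumed to be supplied by the
 * surrounding pipeline code):
 * <pre>{@code
 *   SparkBatchSinkFactory factory = new SparkBatchSinkFactory();
 *   factory.addOutput("sinkStage", output);
 *   factory.serialize(outputStream);
 *
 *   SparkBatchSinkFactory copy = SparkBatchSinkFactory.deserialize(inputStream);
 *   copy.writeFromRDD(rdd, sec, "sinkStage", keyClass, valueClass);
 * }</pre>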
*/
final class SparkBatchSinkFactory {
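
  /**
   * Recreates a {@link SparkBatchSinkFactory} from a stream previously written by
   * {@link #serialize(OutputStream)}, reading back the output format providers, dataset infos,
   * and sink-to-outputs mapping in the order they were written.
   */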
static SparkBatchSinkFactory deserialize(InputStream inputStream) throws IOException {
DataInput input = new DataInputStream(inputStream);
Map<String, OutputFormatProvider> outputFormatProviders = Serializations.deserializeMap(
input, new Serializations.ObjectReader<OutputFormatProvider>() {
@Override
public OutputFormatProvider read(DataInput input) throws IOException {
return new BasicOutputFormatProvider(input.readUTF(),
Serializations.deserializeMap(input,
Serializations.createStringObjectReader()));
}
});
Map<String, DatasetInfo> datasetInfos = Serializations.deserializeMap(
input, new Serializations.ObjectReader<DatasetInfo>() {
@Override
public DatasetInfo read(DataInput input) throws IOException {
return DatasetInfo.deserialize(input);
}
});
Map<String, Set<String>> sinkOutputs = Serializations.deserializeMap(
input, Serializations.createStringSetObjectReader());
return new SparkBatchSinkFactory(outputFormatProviders, datasetInfos, sinkOutputs);
  }

  private final Map<String, OutputFormatProvider> outputFormatProviders;
  private final Map<String, DatasetInfo> datasetInfos;
  private final Map<String, Set<String>> sinkOutputs;

SparkBatchSinkFactory() {
this.outputFormatProviders = new HashMap<>();
this.datasetInfos = new HashMap<>();
this.sinkOutputs = new HashMap<>();
  }

private SparkBatchSinkFactory(Map<String, OutputFormatProvider> providers,
Map<String, DatasetInfo> datasetInfos,
Map<String, Set<String>> sinkOutputs) {
this.outputFormatProviders = providers;
this.datasetInfos = datasetInfos;
this.sinkOutputs = sinkOutputs;
}
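
  /**
   * Registers the given {@link Output} for a sink stage, dispatching on the concrete subtype:
   * an {@link Output.DatasetOutput} is recorded as a dataset to save to, while an
   * {@link Output.OutputFormatProviderOutput} is recorded as a {@link BasicOutputFormatProvider}.
   */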
void addOutput(String stageName, Output output) {
if (output instanceof Output.DatasetOutput) {
      // Note: if the output format provider is trackable, it comes in as a DatasetOutput
Output.DatasetOutput datasetOutput = (Output.DatasetOutput) output;
addOutput(stageName, datasetOutput.getName(), datasetOutput.getAlias(), datasetOutput.getArguments());
} else if (output instanceof Output.OutputFormatProviderOutput) {
Output.OutputFormatProviderOutput ofpOutput = (Output.OutputFormatProviderOutput) output;
addOutput(stageName, ofpOutput.getAlias(),
new BasicOutputFormatProvider(ofpOutput.getOutputFormatProvider().getOutputFormatClassName(),
ofpOutput.getOutputFormatProvider().getOutputFormatConfiguration()));
} else {
      throw new IllegalArgumentException("Unknown output type: " + output.getClass().getCanonicalName());
}
}
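
  /**
   * Registers an output format provider under the given alias for a stage. The provider's class
   * name and configuration are copied into an immutable {@link BasicOutputFormatProvider}.
   */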
void addOutput(String stageName, String alias, OutputFormatProvider outputFormatProvider) {
addOutput(stageName, alias,
new BasicOutputFormatProvider(outputFormatProvider.getOutputFormatClassName(),
outputFormatProvider.getOutputFormatConfiguration()));
}
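
  /**
   * Registers a dataset output for a stage, using the dataset name as its own alias.
   */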
void addOutput(String stageName, String datasetName, Map<String, String> datasetArgs) {
addOutput(stageName, datasetName, datasetName, datasetArgs);
  }

private void addOutput(String stageName, String alias,
BasicOutputFormatProvider outputFormatProvider) {
if (outputFormatProviders.containsKey(alias) || datasetInfos.containsKey(alias)) {
throw new IllegalArgumentException("Output already configured: " + alias);
}
outputFormatProviders.put(alias, outputFormatProvider);
addStageOutput(stageName, alias);
  }

private void addOutput(String stageName, String datasetName, String alias, Map<String, String> datasetArgs) {
if (outputFormatProviders.containsKey(alias) || datasetInfos.containsKey(alias)) {
throw new IllegalArgumentException("Output already configured: " + alias);
}
datasetInfos.put(alias, new DatasetInfo(datasetName, datasetArgs, null));
addStageOutput(stageName, alias);
}
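
  /**
   * Writes this factory's state to the given stream: output format providers first, then dataset
   * infos, then the sink-to-outputs mapping. {@link #deserialize(InputStream)} is the inverse.
   */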
void serialize(OutputStream outputStream) throws IOException {
DataOutput output = new DataOutputStream(outputStream);
Serializations.serializeMap(outputFormatProviders, new Serializations.ObjectWriter<OutputFormatProvider>() {
@Override
public void write(OutputFormatProvider outputFormatProvider, DataOutput output) throws IOException {
output.writeUTF(outputFormatProvider.getOutputFormatClassName());
Serializations.serializeMap(outputFormatProvider.getOutputFormatConfiguration(),
Serializations.createStringObjectWriter(), output);
}
}, output);
Serializations.serializeMap(datasetInfos, new Serializations.ObjectWriter<DatasetInfo>() {
@Override
public void write(DatasetInfo datasetInfo, DataOutput output) throws IOException {
datasetInfo.serialize(output);
}
}, output);
Serializations.serializeMap(sinkOutputs, Serializations.createStringSetObjectWriter(), output);
}
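
  /**
   * Writes the given {@link JavaPairRDD} to every output registered for the named sink: once per
   * configured output format provider via {@code saveAsNewAPIHadoopDataset}, and once per
   * configured dataset via the execution context. Note that {@code keyClass} and
   * {@code valueClass} are part of the signature but are currently unused.
   */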
<K, V> void writeFromRDD(JavaPairRDD<K, V> rdd, JavaSparkExecutionContext sec, String sinkName,
Class<K> keyClass, Class<V> valueClass) {
Set<String> outputNames = sinkOutputs.get(sinkName);
    if (outputNames == null || outputNames.isEmpty()) {
// should never happen if validation happened correctly at pipeline configure time
throw new IllegalArgumentException(sinkName + " has no outputs. " +
"Please check that the sink calls addOutput at some point.");
}
for (String outputName : outputNames) {
OutputFormatProvider outputFormatProvider = outputFormatProviders.get(outputName);
if (outputFormatProvider != null) {
        // Start from an empty Configuration: clear() drops the default resources loaded by the
        // constructor, so only the provider's settings and the output format class are applied.
        Configuration hConf = new Configuration();
        hConf.clear();
for (Map.Entry<String, String> entry : outputFormatProvider.getOutputFormatConfiguration().entrySet()) {
hConf.set(entry.getKey(), entry.getValue());
}
hConf.set(MRJobConfig.OUTPUT_FORMAT_CLASS_ATTR, outputFormatProvider.getOutputFormatClassName());
rdd.saveAsNewAPIHadoopDataset(hConf);
}
DatasetInfo datasetInfo = datasetInfos.get(outputName);
if (datasetInfo != null) {
sec.saveAsDataset(rdd, datasetInfo.getDatasetName(), datasetInfo.getDatasetArgs());
}
}
}

  private void addStageOutput(String stageName, String outputName) {
    Set<String> outputs = sinkOutputs.get(stageName);
    if (outputs == null) {
      outputs = new HashSet<>();
      sinkOutputs.put(stageName, outputs);
    }
    outputs.add(outputName);
  }
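
  /**
   * A simple immutable {@link OutputFormatProvider} that holds an output format class name and
   * its configuration.
   */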
private static final class BasicOutputFormatProvider implements OutputFormatProvider {
private final String outputFormatClassName;
    private final Map<String, String> configuration;

private BasicOutputFormatProvider(String outputFormatClassName, Map<String, String> configuration) {
this.outputFormatClassName = outputFormatClassName;
this.configuration = ImmutableMap.copyOf(configuration);
    }

@Override
public String getOutputFormatClassName() {
return outputFormatClassName;
    }

@Override
public Map<String, String> getOutputFormatConfiguration() {
return configuration;
}
}
}