/* * Copyright © 2015 Cask Data, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package co.cask.cdap.etl.batch.spark; import co.cask.cdap.api.data.batch.Input; import co.cask.cdap.api.data.batch.InputFormatProvider; import co.cask.cdap.api.data.batch.Split; import co.cask.cdap.api.data.format.FormatSpecification; import co.cask.cdap.api.data.format.StructuredRecord; import co.cask.cdap.api.data.stream.StreamBatchReadable; import co.cask.cdap.api.spark.JavaSparkExecutionContext; import co.cask.cdap.api.stream.StreamEventDecoder; import com.google.common.base.Objects; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.InputFormat; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import java.io.DataInput; import java.io.DataInputStream; import java.io.DataOutput; import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.URI; import java.util.List; import java.util.Map; import javax.annotation.Nullable; import static java.lang.Thread.currentThread; /** * A POJO class for storing source information being set from {@link SparkBatchSourceContext} and used in * {@link ETLSparkProgram}. */ final class SparkBatchSourceFactory { private enum SourceType { STREAM(1), PROVIDER(2), DATASET(3); private final byte id; SourceType(int id) { this.id = (byte) id; } static SourceType from(byte id) { for (SourceType type : values()) { if (type.id == id) { return type; } } throw new IllegalArgumentException("No SourceType found for id " + id); } } private final StreamBatchReadable streamBatchReadable; private final InputFormatProvider inputFormatProvider; private final DatasetInfo datasetInfo; static SparkBatchSourceFactory create(StreamBatchReadable streamBatchReadable) { return new SparkBatchSourceFactory(streamBatchReadable, null, null); } static SparkBatchSourceFactory create(InputFormatProvider inputFormatProvider) { return new SparkBatchSourceFactory(null, inputFormatProvider, null); } static SparkBatchSourceFactory create(String datasetName) { return create(datasetName, ImmutableMap.<String, String>of()); } static SparkBatchSourceFactory create(String datasetName, Map<String, String> datasetArgs) { return create(datasetName, datasetArgs, null); } static SparkBatchSourceFactory create(String datasetName, Map<String, String> datasetArgs, @Nullable List<Split> splits) { return new SparkBatchSourceFactory(null, null, new DatasetInfo(datasetName, datasetArgs, splits)); } static SparkBatchSourceFactory create(Input input) { if (input instanceof Input.DatasetInput) { // Note if input format provider is trackable then it comes in as DatasetInput Input.DatasetInput datasetInput = (Input.DatasetInput) input; return create(datasetInput.getName(), datasetInput.getArguments(), datasetInput.getSplits()); } else if (input instanceof Input.StreamInput) { Input.StreamInput streamInput = (Input.StreamInput) input; return create(streamInput.getStreamBatchReadable()); } else if (input instanceof Input.InputFormatProviderInput) { Input.InputFormatProviderInput ifpInput = (Input.InputFormatProviderInput) input; return new SparkBatchSourceFactory(null, ifpInput.getInputFormatProvider(), null); } throw new IllegalArgumentException("Unknown input format type: " + input.getClass().getCanonicalName()); } static SparkBatchSourceFactory deserialize(InputStream inputStream) throws IOException { DataInput input = new DataInputStream(inputStream); // Deserialize based on the type switch (SourceType.from(input.readByte())) { case STREAM: return new SparkBatchSourceFactory(new StreamBatchReadable(URI.create(input.readUTF())), null, null); case PROVIDER: return new SparkBatchSourceFactory( null, new BasicInputFormatProvider( input.readUTF(), Serializations.deserializeMap(input, Serializations.createStringObjectReader())), null ); case DATASET: return new SparkBatchSourceFactory(null, null, DatasetInfo.deserialize(input)); } throw new IllegalArgumentException("Invalid input. Failed to decode SparkBatchSourceFactory."); } private SparkBatchSourceFactory(@Nullable StreamBatchReadable streamBatchReadable, @Nullable InputFormatProvider inputFormatProvider, @Nullable DatasetInfo datasetInfo) { this.streamBatchReadable = streamBatchReadable; this.inputFormatProvider = inputFormatProvider; this.datasetInfo = datasetInfo; } public void serialize(OutputStream outputStream) throws IOException { DataOutput output = new DataOutputStream(outputStream); if (streamBatchReadable != null) { output.writeByte(SourceType.STREAM.id); output.writeUTF(streamBatchReadable.toURI().toString()); return; } if (inputFormatProvider != null) { output.writeByte(SourceType.PROVIDER.id); output.writeUTF(inputFormatProvider.getInputFormatClassName()); Serializations.serializeMap(inputFormatProvider.getInputFormatConfiguration(), Serializations.createStringObjectWriter(), output); return; } if (datasetInfo != null) { output.writeByte(SourceType.DATASET.id); datasetInfo.serialize(output); return; } // This should never happen since the constructor is private and it only get calls from static create() methods // which make sure one and only one of those source type will be specified. throw new IllegalStateException("Unknown source type"); } @SuppressWarnings("unchecked") public <K, V> JavaPairRDD<K, V> createRDD(JavaSparkExecutionContext sec, JavaSparkContext jsc, Class<K> keyClass, Class<V> valueClass) { if (streamBatchReadable != null) { FormatSpecification formatSpec = streamBatchReadable.getFormatSpecification(); if (formatSpec != null) { return (JavaPairRDD<K, V>) sec.fromStream(streamBatchReadable.getStreamName(), formatSpec, streamBatchReadable.getStartTime(), streamBatchReadable.getEndTime(), StructuredRecord.class); } String decoderType = streamBatchReadable.getDecoderType(); if (decoderType == null) { return (JavaPairRDD<K, V>) sec.fromStream(streamBatchReadable.getStreamName(), streamBatchReadable.getStartTime(), streamBatchReadable.getEndTime(), valueClass); } else { try { Class<StreamEventDecoder<K, V>> decoderClass = (Class<StreamEventDecoder<K, V>>) Thread.currentThread().getContextClassLoader().loadClass(decoderType); return sec.fromStream(streamBatchReadable.getStreamName(), streamBatchReadable.getStartTime(), streamBatchReadable.getEndTime(), decoderClass, keyClass, valueClass); } catch (Exception e) { throw Throwables.propagate(e); } } } if (inputFormatProvider != null) { Configuration hConf = new Configuration(); hConf.clear(); for (Map.Entry<String, String> entry : inputFormatProvider.getInputFormatConfiguration().entrySet()) { hConf.set(entry.getKey(), entry.getValue()); } ClassLoader classLoader = Objects.firstNonNull(currentThread().getContextClassLoader(), getClass().getClassLoader()); try { @SuppressWarnings("unchecked") Class<InputFormat> inputFormatClass = (Class<InputFormat>) classLoader.loadClass( inputFormatProvider.getInputFormatClassName()); return jsc.newAPIHadoopRDD(hConf, inputFormatClass, keyClass, valueClass); } catch (ClassNotFoundException e) { throw Throwables.propagate(e); } } if (datasetInfo != null) { return sec.fromDataset(datasetInfo.getDatasetName(), datasetInfo.getDatasetArgs()); } // This should never happen since the constructor is private and it only get calls from static create() methods // which make sure one and only one of those source type will be specified. throw new IllegalStateException("Unknown source type"); } private static final class BasicInputFormatProvider implements InputFormatProvider { private final String inputFormatClassName; private final Map<String, String> configuration; private BasicInputFormatProvider(String inputFormatClassName, Map<String, String> configuration) { this.inputFormatClassName = inputFormatClassName; this.configuration = ImmutableMap.copyOf(configuration); } @Override public String getInputFormatClassName() { return inputFormatClassName; } @Override public Map<String, String> getInputFormatConfiguration() { return configuration; } } }