package com.amazonaws.bigdatablog.edba.emr;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
import java.util.Random;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import com.amazonaws.auth.AWSSessionCredentials;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Client;
import com.amazonaws.services.s3.model.S3Object;
import com.google.common.collect.ImmutableMap;
import scala.collection.Seq;
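
/**
 * Spark job that aggregates vendor sales transactions and merges the results into an
 * Amazon Redshift table using temporary session credentials rather than static access keys.
 */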
public class ProcessVendorTrasactions implements Serializable {
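
    /**
     * Reads comma-separated sales records from the given input path, sums trans_amount per
     * vendor_id/item_id/trans_date for sale (trans_type 4), tax (5), and discount (6) records,
     * joins the three aggregates, and loads the result into the Redshift table vendortranssummary
     * through a uniquely named staging table.
     *
     * @param jobInputParam path of the input sales data (passed through from main)
     */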
    public static void run(String jobInputParam) throws Exception {
        List<StructField> schemaFields = new ArrayList<StructField>();
        schemaFields.add(DataTypes.createStructField("vendor_id", DataTypes.StringType, true));
        schemaFields.add(DataTypes.createStructField("trans_amount", DataTypes.StringType, true));
        schemaFields.add(DataTypes.createStructField("trans_type", DataTypes.StringType, true));
        schemaFields.add(DataTypes.createStructField("item_id", DataTypes.StringType, true));
        schemaFields.add(DataTypes.createStructField("trans_date", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(schemaFields);
        SparkConf conf = new SparkConf().setAppName("Spark Redshift No Access-Keys");
        SparkSession spark = SparkSession.builder().config(conf).getOrCreate();
        JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());

        String redshiftJDBCURL = props.getProperty("redshift.jdbc.url");
        String s3TempPath = props.getProperty("s3.temp.path");
        System.out.println("props: " + props);
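
        // Parse each comma-separated record into a Row of five string fields matching the schema.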
        JavaRDD<Row> salesRDD = sc.textFile(jobInputParam).map(new Function<String, Row>() {
            public Row call(String saleRec) {
                String[] fields = saleRec.split(",");
                return RowFactory.create(fields[0], fields[1], fields[2], fields[3], fields[4]);
            }
        });
        Dataset<Row> salesDF = spark.createDataFrame(salesRDD, schema);
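
        // Sum trans_amount per vendor/item/date for each transaction type: 4 = sale, 5 = tax, 6 = discount.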
        Dataset<Row> vendorItemSaleAmountDF = salesDF.filter(salesDF.col("trans_type").equalTo("4"))
                .groupBy(salesDF.col("vendor_id"), salesDF.col("item_id"), salesDF.col("trans_date"))
                .agg(ImmutableMap.of("trans_amount", "sum"));
        Dataset<Row> vendorItemTaxAmountDF = salesDF.filter(salesDF.col("trans_type").equalTo("5"))
                .groupBy(salesDF.col("vendor_id"), salesDF.col("item_id"), salesDF.col("trans_date"))
                .agg(ImmutableMap.of("trans_amount", "sum"));
        Dataset<Row> vendorItemDiscountAmountDF = salesDF.filter(salesDF.col("trans_type").equalTo("6"))
                .groupBy(salesDF.col("vendor_id"), salesDF.col("item_id"), salesDF.col("trans_date"))
                .agg(ImmutableMap.of("trans_amount", "sum"));
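
        // Left-outer joins keep vendor/item/date keys that have sales but no matching tax or discount rows.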
        String[] joinColArray = {"vendor_id", "item_id", "trans_date"};
        vendorItemSaleAmountDF.printSchema();
        Seq<String> commonJoinColumns = scala.collection.JavaConversions.asScalaBuffer(Arrays.asList(joinColArray)).seq();
        Dataset<Row> vendorAggregatedDF = vendorItemSaleAmountDF.join(vendorItemTaxAmountDF, commonJoinColumns, "left_outer")
                .join(vendorItemDiscountAmountDF, commonJoinColumns, "left_outer")
                .toDF("vendor_id", "item_id", "trans_date", "sale_amount", "tax_amount", "discount_amount");
        vendorAggregatedDF.printSchema();
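
        // Temporary session credentials resolved from the default provider chain are handed to the
        // Redshift connector below, so the job needs no hard-coded access keys.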
        DefaultAWSCredentialsProviderChain provider = new DefaultAWSCredentialsProviderChain();
        AWSSessionCredentials creds = (AWSSessionCredentials) provider.getCredentials();
        String appendix = System.currentTimeMillis() + "_" + (new Random().nextInt(10) + 1);
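
        // Merge-style load: delete rows in vendortranssummary that match the staging table on
        // vendor_id/item_id/trans_date, insert the staged rows, and drop the staging table, all in
        // one transaction that the connector runs as "postactions" after loading the staging table.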
        String vendorTransSummarySQL = new StringBuilder("begin transaction;delete from vendortranssummary using vendortranssummary_temp")
                .append(appendix)
                .append(" where vendortranssummary.vendor_id=vendortranssummary_temp")
                .append(appendix)
                .append(".vendor_id and vendortranssummary.item_id=vendortranssummary_temp")
                .append(appendix)
                .append(".item_id and vendortranssummary.trans_date = vendortranssummary_temp")
                .append(appendix)
                .append(".trans_date;")
                .append("insert into vendortranssummary select * from vendortranssummary_temp")
                .append(appendix)
                .append(";drop table vendortranssummary_temp")
                .append(appendix)
                .append(";end transaction;").toString();
        vendorAggregatedDF.write().format("com.databricks.spark.redshift").option("url", redshiftJDBCURL)
                .option("dbtable", "vendortranssummary_temp" + appendix)
                .option("usestagingtable", "false")
                .option("postactions", vendorTransSummarySQL)
                .option("temporary_aws_access_key_id", creds.getAWSAccessKeyId())
                .option("temporary_aws_secret_access_key", creds.getAWSSecretKey())
                .option("temporary_aws_session_token", creds.getSessionToken())
                .option("tempdir", s3TempPath).mode(SaveMode.Overwrite).save();
    }
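
    /**
     * Entry point. Expects three arguments: the S3 bucket that holds the job properties file,
     * the key of that properties file, and the input path that is handed to {@link #run(String)}.
     *
     * The properties file is assumed to look roughly like the following (values are illustrative only):
     *
     *   redshift.jdbc.url=jdbc:redshift://example-cluster.abc123.us-east-1.redshift.amazonaws.com:5439/dev
     *   s3.temp.path=s3://example-bucket/temp/
     *
     * A hypothetical spark-submit invocation (jar name, bucket, and paths are placeholders):
     *
     *   spark-submit --class com.amazonaws.bigdatablog.edba.emr.ProcessVendorTrasactions \
     *     edba-emr-job.jar example-config-bucket config/job.properties s3://example-data-bucket/sales/
     */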
    public static void main(String[] args) throws Exception {
        String configBucket = args[0];
        String configPrefix = args[1];
        AmazonS3 s3Client = new AmazonS3Client();
        S3Object s3Object = s3Client.getObject(configBucket, configPrefix);
        props.load(s3Object.getObjectContent());
        ProcessVendorTrasactions.run(args[2]);
    }

    // Job configuration loaded from S3 in main() and read by run().
    static Properties props = new Properties();
}