package com.caseystella.input;

import com.caseystella.util.ConversionUtils;
import com.google.common.base.Splitter;
import com.google.common.collect.Iterables;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.DataFrameReader;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.*;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * An InputHandler that reads CSV data via the spark-csv data source and
 * optionally runs a SQL query against the result.
 */
public class CSVHandler implements InputHandler {

  /**
   * The properties recognized by this handler, keyed by their property names.
   */
  public enum Options {
      QUERY("query")
    , HAS_HEADER("hasHeader")
    , INFER_SCHEMA("inferSchema")
    , EXPLICIT_SCHEMA("withSchema")
    ;
    String optionName;

    Options(String optionName) {
      this.optionName = optionName;
    }

    public String get(Map<String, String> properties) {
      return properties.get(optionName);
    }

    /**
     * Returns the property value converted to the requested type, or the
     * default if the property is absent or cannot be converted.
     */
    public <T> T get(Map<String, String> properties, T defaultVal, Class<T> clazz) {
      String s = properties.get(optionName);
      if (s == null) {
        return defaultVal;
      }
      T ret = ConversionUtils.convert(s, clazz);
      return ret == null ? defaultVal : ret;
    }

    public boolean has(Map<String, String> properties) {
      return properties.containsKey(optionName);
    }
  }

  /**
   * The column types supported in an explicit schema definition, mapped to
   * their Spark SQL DataTypes.
   */
  public enum SqlTypes {
      STRING(DataTypes.StringType)
    , INTEGER(DataTypes.IntegerType)
    , DATE(DataTypes.DateType)
    , DOUBLE(DataTypes.DoubleType)
    , FLOAT(DataTypes.FloatType)
    , LONG(DataTypes.LongType)
    ;
    DataType dt;

    SqlTypes(DataType dt) {
      this.dt = dt;
    }

    public DataType getDataType() {
      return dt;
    }
  }

  /**
   * Parses a schema definition of the form "name[:type],name[:type],..."
   * (e.g. "name:string,age:integer").  Columns without an explicit type
   * default to string; all columns are nullable.
   */
  private StructType customSchema(String schemaDef) {
    List<Tuple2<String, DataType>> schema = new ArrayList<>();
    for (String column : Splitter.on(",").split(schemaDef)) {
      String columnName = column;
      DataType dt = DataTypes.StringType;
      if (column.contains(":")) {
        Iterable<String> tokens = Splitter.on(":").split(column);
        columnName = Iterables.getFirst(tokens, null);
        dt = SqlTypes.valueOf(Iterables.getLast(tokens, "").toUpperCase()).getDataType();
      }
      schema.add(new Tuple2<>(columnName, dt));
    }
    StructField[] fields = new StructField[schema.size()];
    for (int i = 0; i < schema.size(); ++i) {
      Tuple2<String, DataType> s = schema.get(i);
      fields[i] = new StructField(s._1, s._2, true, Metadata.empty());
    }
    return new StructType(fields);
  }

  @Override
  public DataFrame open(String inputName, JavaSparkContext sc, Map<String, String> properties) {
    SQLContext sqlContext = new SQLContext(sc);
    DataFrameReader reader = sqlContext.read()
                                       .format("com.databricks.spark.csv")
                                       .option("header", Options.HAS_HEADER.get(properties, "true", String.class))
                                       .option("inferSchema", Options.INFER_SCHEMA.get(properties, "true", String.class))
                                       ;
    if (Options.EXPLICIT_SCHEMA.has(properties)) {
      reader = reader.schema(customSchema(Options.EXPLICIT_SCHEMA.get(properties)));
    }
    if (Options.QUERY.has(properties)) {
      DataFrame df = reader.load(inputName);
      // Register the frame as a temp table named after the file's base name
      // (e.g. "/data/people.csv" => "people") so the query can refer to it.
      String tableName = Iterables.getFirst(
          Splitter.on('.').split(Iterables.getLast(Splitter.on('/').split(inputName), inputName)),
          inputName);
      System.out.println("Registering " + tableName + "...");
      df.registerTempTable(tableName);
      return df.sqlContext().sql(Options.QUERY.get(properties));
    } else {
      return reader.load(inputName);
    }
  }
}
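
/*
 * A minimal usage sketch, not part of this class: the SparkConf settings,
 * file path, and option values below are hypothetical, and the snippet
 * additionally assumes org.apache.spark.SparkConf and java.util.HashMap
 * imports.
 *
 *   JavaSparkContext sc =
 *       new JavaSparkContext(new SparkConf().setAppName("csv-demo").setMaster("local[*]"));
 *   Map<String, String> props = new HashMap<>();
 *   props.put("hasHeader", "true");
 *   props.put("withSchema", "name:string,age:integer");
 *   props.put("query", "SELECT name, age FROM people WHERE age > 30");
 *   // The temp table is named after the file's base name, so
 *   // "/data/people.csv" is queryable as "people".
 *   DataFrame result = new CSVHandler().open("/data/people.csv", sc, props);
 *   result.show();
 */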