/* Copyright (c) 2005 - 2012 Vertica, an HP company -*- Java -*- */ package com.vertica.hadoop; import java.io.IOException; import java.text.DateFormat; import java.text.ParseException; import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.mapreduce.InputFormat; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.JobContext; import org.apache.hadoop.mapreduce.RecordReader; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.TaskAttemptContext; /** * Input formatter that returns the results of a query executed against Vertica. * The key is a record number within the result set of each mapper The value is * a VerticaRecord, which uses a similar interface to JDBC ResultSets for * returning values. * */ public class VerticaInputFormat extends InputFormat<LongWritable, VerticaRecord> { private static final Log LOG = LogFactory.getLog("com.vertica.hadoop"); private String inputQuery = null; private String params = null; public VerticaInputFormat() {} /** * Set a parameterized input query for a job and the query that returns the * parameters. * * @param query * SQL query that has parameters specified by question marks ("?") * @param params * SQL query that returns parameters for the input query or * the parameters to substiture */ public VerticaInputFormat(String query, String params) { inputQuery = query; this.params = params; } /** * Set the input query for a job * * @param job * @param inputQuery * query to run against Vertica */ public static void setInput(Job job, String inputQuery) { job.setInputFormatClass(VerticaInputFormat.class); VerticaConfiguration config = new VerticaConfiguration(job.getConfiguration()); config.setInputQuery(inputQuery); } /** * Set a parameterized input query for a job and the query that returns the * parameters. * * @param job * @param inputQuery * SQL query that has parameters specified by question marks ("?") * @param segmentParamsQuery * SQL query that returns parameters for the input query */ public static void setInput(Job job, String inputQuery, String segmentParamsQuery) { job.setInputFormatClass(VerticaInputFormat.class); VerticaConfiguration config = new VerticaConfiguration(job.getConfiguration()); config.setInputQuery(inputQuery); config.setParamsQuery(segmentParamsQuery); } /** * Set the input query and any number of comma delimited literal list of * parameters * * @param job * @param inputQuery * SQL query that has parameters specified by question marks ("?") * @param segmentParams * any numer of comma delimited strings with literal parameters to * substitute in the input query */ @SuppressWarnings("serial") public static void setInput(Job job, String inputQuery, String... segmentParams) throws IOException { // transform each param set into array DateFormat datefmt = DateFormat.getDateInstance(); Collection<List<Object>> params = new HashSet<List<Object>>() {}; for (String strParams : segmentParams) { List<Object> param = new ArrayList<Object>(); for (String strParam : strParams.split(",")) { strParam = strParam.trim(); if (strParam.charAt(0) == '\'' && strParam.charAt(strParam.length() - 1) == '\'') param.add(strParam.substring(1, strParam.length() - 1)); else { try { param.add(datefmt.parse(strParam)); } catch (ParseException e1) { try { param.add(Integer.parseInt(strParam)); } catch (NumberFormatException e2) { throw new IOException("Error parsing argument " + strParam); } } } } params.add(param); } setInput(job, inputQuery, params); } /** * Set the input query and a collection of parameter lists * * @param job * @param inpuQuery * SQL query that has parameters specified by question marks ("?") * @param segmentParams * collection of ordered lists to subtitute into the input query * @throws IOException */ public static void setInput(Job job, String inpuQuery, Collection<List<Object>> segmentParams) throws IOException { job.setInputFormatClass(VerticaInputFormat.class); VerticaConfiguration config = new VerticaConfiguration(job.getConfiguration()); config.setInputQuery(inpuQuery); config.setInputParams(segmentParams); } /** {@inheritDoc} */ public RecordReader<LongWritable, VerticaRecord> createRecordReader( InputSplit split, TaskAttemptContext context) throws IOException { try { return new VerticaRecordReader((VerticaInputSplit) split, context.getConfiguration()); } catch (Exception e) { throw new IOException(e); } } /** {@inheritDoc} */ public List<InputSplit> getSplits(JobContext context) throws IOException { Configuration conf = context.getConfiguration(); long numSplits = conf.getInt("mapreduce.job.maps", 1); LOG.debug("creating splits up to " + numSplits); List<InputSplit> splits = new ArrayList<InputSplit>(); int i = 0; // This is the fancy part of mapping inputs...here's how we figure out // splits // get the params query or the params VerticaConfiguration config = new VerticaConfiguration(conf); if (inputQuery == null) inputQuery = config.getInputQuery(); if (inputQuery == null) throw new IOException("Vertica input requires query defined by " + VerticaConfiguration.QUERY_PROP); if (params == null) params = config.getParamsQuery(); Collection<List<Object>> paramCollection = config.getInputParameters(); if (params != null && params.startsWith("select")) { LOG.debug("creating splits using paramsQuery :" + params); Connection conn = null; Statement stmt = null; try { conn = config.getConnection(false); stmt = conn.createStatement(); ResultSet rs = stmt.executeQuery(params); ResultSetMetaData rsmd = rs.getMetaData(); while (rs.next()) { List<Object> segmentParams = new ArrayList<Object>(); for (int j = 1; j <= rsmd.getColumnCount(); j++) { segmentParams.add(rs.getObject(j)); } splits.add(new VerticaInputSplit(inputQuery, segmentParams)); } } catch (Exception e) { throw new IOException(e); } finally { try { if (stmt != null) stmt.close(); } catch (SQLException e) { throw new IOException(e); } } } else if (params != null) { LOG.debug("creating splits using " + params + " params"); for (String strParam : params.split(",")) { strParam = strParam.trim(); if (strParam.charAt(0) == '\'' && strParam.charAt(strParam.length() - 1) == '\'') strParam = strParam.substring(1, strParam.length() - 1); List<Object> segmentParams = new ArrayList<Object>(); segmentParams.add(strParam); splits.add(new VerticaInputSplit(inputQuery, segmentParams)); } } else if (paramCollection != null) { LOG.debug("creating splits using " + paramCollection.size() + " params"); for (List<Object> segmentParams : paramCollection) { // if there are more numSplits than params we're going to introduce some // limit and offsets splits.add(new VerticaInputSplit(inputQuery, segmentParams)); } } else { LOG.debug("creating splits using limit and offset"); Connection conn = null; Statement stmt = null; long count = 0; long start = 0; long end = 0; // TODO: limit needs order by unique key // TODO: what if there are more parameters than numsplits? // prep a count(*) wrapper query and then populate the bind params for each String countQuery = "SELECT COUNT(*) FROM (\n" + inputQuery + "\n) count"; try { conn = config.getConnection(false); stmt = conn.createStatement(); ResultSet rs = stmt.executeQuery(countQuery); rs.next(); count = rs.getLong(1); } catch (Exception e) { throw new IOException(e); } finally { try { if (stmt != null) stmt.close(); } catch (SQLException e) { throw new IOException(e); } } long splitSize = count / numSplits; end = splitSize; LOG.debug("creating " + numSplits + " splits for " + count + " records"); for (i = 1; i < numSplits; i++) { splits.add(new VerticaInputSplit(inputQuery, start, end)); LOG.debug("Split(" + i + "), start:" + start + ", end:" + end); start += splitSize; end += splitSize; count -= splitSize; } if (count > 0) { splits.add(new VerticaInputSplit(inputQuery, start, start + count)); } } LOG.debug("returning " + splits.size() + " final splits"); return splits; } }