package org.test4j.module.dbfit.environment; import java.math.BigDecimal; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.regex.Pattern; import org.test4j.module.database.environment.DBEnvironment; import org.test4j.module.database.environment.normalise.NameNormaliser; import org.test4j.module.database.utility.DBHelper; import org.test4j.module.dbfit.model.DbParameterAccessor; @SuppressWarnings({ "rawtypes" }) public class DbFitSqlServerEnvironment extends DbFitAbstractDBEnvironment { public DbFitSqlServerEnvironment(DBEnvironment dbEnvironment) { super(dbEnvironment); } public boolean supportsOuputOnInsert() { return false; } private static String paramNamePattern = "@([A-Za-z0-9_]+)"; private static Pattern paramRegex = Pattern.compile(paramNamePattern); public Pattern getParameterPattern() { return paramRegex; } protected String parseCommandText(String commandText, String[] vars) { if (vars == null || vars.length == 0) { return commandText; } String sql = commandText; for (String var : vars) { sql = sql.replace("@" + var, "?"); } return sql; } public Map<String, DbParameterAccessor> getAllColumns(String tableOrViewName) throws SQLException { String qry = " select c.[name], TYPE_NAME(c.system_type_id) as [Type], c.max_length, " + " 0 As is_output, 0 As is_cursor_ref " + " from sys.columns c " + " where c.object_id = OBJECT_ID(?) " + " order by column_id"; return readIntoParams(tableOrViewName, qry); } private Map<String, DbParameterAccessor> readIntoParams(String objname, String query) throws SQLException { if (objname.contains(".")) { String[] schemaAndName = objname.split("[\\.]", 2); objname = "[" + schemaAndName[0] + "].[" + schemaAndName[1] + "]"; } else { objname = "[" + NameNormaliser.normaliseName(objname) + "]"; } PreparedStatement dc = null; ResultSet rs = null; try { dc = this.connect().prepareStatement(query); dc.setString(1, NameNormaliser.normaliseName(objname)); rs = dc.executeQuery(); Map<String, DbParameterAccessor> allParams = new HashMap<String, DbParameterAccessor>(); int position = 0; while (rs.next()) { String paramName = rs.getString(1); if (paramName == null) paramName = ""; String dataType = rs.getString(2); // int length = rs.getInt(3); int direction = rs.getInt(4); int paramDirection; if (paramName.trim().length() == 0) paramDirection = DbParameterAccessor.RETURN_VALUE; else paramDirection = getParameterDirection(direction); DbParameterAccessor dbp = new DbParameterAccessor(paramName, paramDirection, getSqlType(dataType), getJavaClass(dataType), position++); allParams.put(NameNormaliser.normaliseName(paramName), dbp); } return allParams; } finally { DBHelper.closeResultSet(rs); rs = null; DBHelper.closeStatement(dc); dc = null; } } /** * List interface has sequential search, so using list instead of array to * map types */ private static List<String> stringTypes = Arrays.asList(new String[] { "VARCHAR", "NVARCHAR", "CHAR", "NCHAR", "TEXT", "NTEXT", "UNIQUEIDENTIFIER" }); private static List<String> intTypes = Arrays.asList(new String[] { "INT" }); private static List<String> booleanTypes = Arrays.asList(new String[] { "BIT" }); private static List<String> floatTypes = Arrays.asList(new String[] { "REAL" }); private static List<String> doubleTypes = Arrays.asList(new String[] { "FLOAT" }); private static List<String> longTypes = Arrays.asList(new String[] { "BIGINT" }); private static List<String> shortTypes = Arrays.asList(new String[] { "TINYINT", "SMALLINT" }); private static List<String> decimalTypes = Arrays .asList(new String[] { "DECIMAL", "NUMERIC", "MONEY", "SMALLMONEY" }); private static List<String> timestampTypes = Arrays .asList(new String[] { "SMALLDATETIME", "DATETIME", "TIMESTAMP" }); private static int getParameterDirection(int isOutput) { if (isOutput == 1) return DbParameterAccessor.OUTPUT; else return DbParameterAccessor.INPUT; } private static String normaliseTypeName(String dataType) { dataType = dataType.toUpperCase().trim(); int idx = dataType.indexOf(" "); if (idx >= 0) dataType = dataType.substring(0, idx); idx = dataType.indexOf("("); if (idx >= 0) dataType = dataType.substring(0, idx); return dataType; } private static int getSqlType(String dataType) { // todo:strip everything from first blank dataType = normaliseTypeName(dataType); if (stringTypes.contains(dataType)) return java.sql.Types.VARCHAR; if (decimalTypes.contains(dataType)) return java.sql.Types.NUMERIC; if (intTypes.contains(dataType)) return java.sql.Types.INTEGER; if (timestampTypes.contains(dataType)) return java.sql.Types.TIMESTAMP; if (booleanTypes.contains(dataType)) return java.sql.Types.BOOLEAN; if (floatTypes.contains(dataType)) return java.sql.Types.FLOAT; if (doubleTypes.contains(dataType)) return java.sql.Types.DOUBLE; if (longTypes.contains(dataType)) return java.sql.Types.BIGINT; if (shortTypes.contains(dataType)) return java.sql.Types.SMALLINT; throw new UnsupportedOperationException("Type " + dataType + " is not supported"); } public Class getJavaClass(String dataType) { dataType = normaliseTypeName(dataType); if (stringTypes.contains(dataType)) return String.class; if (decimalTypes.contains(dataType)) return BigDecimal.class; if (intTypes.contains(dataType)) return Integer.class; if (timestampTypes.contains(dataType)) return java.sql.Timestamp.class; if (booleanTypes.contains(dataType)) return Boolean.class; if (floatTypes.contains(dataType)) return Float.class; if (doubleTypes.contains(dataType)) return Double.class; if (longTypes.contains(dataType)) return Long.class; if (shortTypes.contains(dataType)) return Short.class; throw new UnsupportedOperationException("Type " + dataType + " is not supported"); } public Map<String, DbParameterAccessor> getAllProcedureParameters(String procName) throws SQLException { return readIntoParams(procName, "select p.[name], TYPE_NAME(p.system_type_id) as [Type], " + " p.max_length, p.is_output, p.is_cursor_ref from sys.parameters p " + " where p.object_id = OBJECT_ID(?) order by parameter_id "); } public String getFieldQuato() { return ""; } }