package org.test4j.module.dbfit.fixture.fit;

import java.sql.CallableStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

import org.test4j.module.database.environment.normalise.NameNormaliser;
import org.test4j.module.database.utility.DBHelper;
import org.test4j.module.dbfit.environment.DbFitEnvironment;
import org.test4j.module.dbfit.exception.HasMarkedException;
// NOTE(review): the simple name imported here ("test4jFixture") differs in case from the
// superclass "Test4JFixture" used below; Java names are case-sensitive -- confirm which
// spelling matches the real class.
import org.test4j.module.dbfit.fixture.test4jFixture;
import org.test4j.module.dbfit.model.DbParameterAccessor;
import org.test4j.module.dbfit.model.SymbolAccessQueryBinding;
import org.test4j.module.dbfit.model.SymbolAccessSetBinding;

import fit.Binding;
import fit.Parse;

/**
 * Fit fixture that calls a stored procedure (or function) through JDBC.
 *
 * <p>The first row handed to {@link #doRows(Parse)} names the procedure parameters; a column
 * whose header ends in {@code ?} is treated as an output column, any other column as an input
 * column. Every following row sets the input values, executes the call and then evaluates the
 * output columns through their bindings. If {@code doRows} receives no rows at all, the
 * procedure is called once without parameters.
 *
 * <p>When an exception is expected, each execution is instead checked for a
 * {@link SQLException}, optionally with a specific database error code.
 */
public class ExecuteProcedureFixture extends Test4JFixture {
    private DbFitEnvironment environment;

    private String procName;

    /** Parameter accessors, one per column of the parameter-name row. */
    private DbParameterAccessor[] accessors;

    /** Fit bindings, index-aligned with {@link #accessors}. */
    private Binding[] columnBindings;

    private boolean exceptionExpected = false;

    /** Whether {@link #excNumberExpected} must also match when an exception is expected. */
    private boolean excNumberDefined = false;

    private int excNumberExpected;

    /** First row of the table currently being processed; set in {@link #doTable(Parse)}. */
    private Parse headerRow;

    /**
     * Creates a fixture that expects the call to fail with a specific database error code.
     *
     * @param dbEnvironment     environment used to connect and to look up procedure parameters
     * @param procName          name of the procedure to call
     * @param expectedErrorCode error code that {@code dbEnvironment.getExceptionCode(...)} must
     *                          report for the thrown {@link SQLException}
     */
    public ExecuteProcedureFixture(DbFitEnvironment dbEnvironment, String procName, int expectedErrorCode) {
        this.procName = procName;
        this.environment = dbEnvironment;
        this.exceptionExpected = true;
        this.excNumberDefined = true;
        this.excNumberExpected = expectedErrorCode;
    }

    /**
     * Creates a fixture that optionally expects the call to fail with any {@link SQLException}.
     *
     * @param dbEnvironment     environment used to connect and to look up procedure parameters
     * @param procName          name of the procedure to call
     * @param exceptionExpected {@code true} if every execution is expected to throw
     */
    public ExecuteProcedureFixture(DbFitEnvironment dbEnvironment, String procName, boolean exceptionExpected) {
        this.procName = procName;
        this.environment = dbEnvironment;
        this.exceptionExpected = exceptionExpected;
        this.excNumberDefined = false;
    }

    /**
     * Creates a fixture that expects the call to succeed.
     *
     * @param dbEnvironment environment used to connect and to look up procedure parameters
     * @param procName      name of the procedure to call
     */
    public ExecuteProcedureFixture(DbFitEnvironment dbEnvironment, String procName) {
        this(dbEnvironment, procName, false);
    }

    /** Orders accessors by ascending parameter position. */
    private static class PositionComparator implements Comparator<DbParameterAccessor> {
        public int compare(DbParameterAccessor o1, DbParameterAccessor o2) {
            // Relational comparison instead of subtraction: cannot overflow.
            if (o1.getPosition() < o2.getPosition()) {
                return -1;
            }
            if (o1.getPosition() > o2.getPosition()) {
                return 1;
            }
            return 0;
        }
    }

    /**
     * Returns the distinct parameter names in position order.
     *
     * <p>Accessors are sorted by position, and a name equal to the immediately preceding one is
     * skipped. This collapses the two accessors that {@link #initParameters(Parse)} creates when
     * an INPUT_OUTPUT parameter is split into an input and an output column.
     *
     * @param accessors accessors to order; the array itself is not modified
     * @return parameter names in ascending position order, adjacent duplicates removed
     */
    List<String> getSortedAccessorNames(DbParameterAccessor[] accessors) {
        DbParameterAccessor[] sorted = accessors.clone();
        Arrays.sort(sorted, new PositionComparator());
        List<String> nameList = new ArrayList<String>();
        String lastName = null;
        for (DbParameterAccessor p : sorted) {
            String name = p.getName();
            // Compare by value: reference comparison (!=) misses equal names held in
            // distinct String instances and would emit duplicate placeholders.
            boolean sameAsLast = (name == null) ? lastName == null : name.equals(lastName);
            if (!sameAsLast) {
                lastName = name;
                nameList.add(name);
            }
        }
        return nameList;
    }

    /** Returns {@code true} if any accessor has direction {@code RETURN_VALUE}. */
    private boolean containsReturnValue(DbParameterAccessor[] accessors) {
        for (DbParameterAccessor ac : accessors) {
            if (ac.getDirection() == DbParameterAccessor.RETURN_VALUE) {
                return true;
            }
        }
        return false;
    }

    /**
     * Prepares a JDBC call of the form {@code { [? =] call procName[(?,...)]}} and binds every
     * accessor to the placeholder of its parameter name.
     *
     * @param procName  procedure to call
     * @param accessors accessors to bind; several accessors may share one name
     * @return the prepared statement; the caller is responsible for closing it
     * @throws SQLException if the statement cannot be prepared
     */
    public CallableStatement buildCommand(String procName, DbParameterAccessor[] accessors) throws SQLException {
        List<String> accessorNames = getSortedAccessorNames(accessors);
        boolean isFunction = containsReturnValue(accessors);
        StringBuilder ins = new StringBuilder("{ ");
        if (isFunction) {
            ins.append("? =");
        }
        ins.append("call ").append(procName);
        String comma = "(";
        boolean hasArguments = false;
        // One name slot is consumed by the leading "? =" placeholder of a function.
        // NOTE(review): assumes the return-value accessor sorts first by position -- confirm.
        for (int i = (isFunction ? 1 : 0); i < accessorNames.size(); i++) {
            ins.append(comma);
            ins.append("?");
            comma = ",";
            hasArguments = true;
        }
        if (hasArguments) {
            ins.append(")");
        }
        ins.append("}");
        CallableStatement cs = environment.connect().prepareCall(ins.toString());
        boolean bound = false;
        try {
            for (DbParameterAccessor ac : accessors) {
                int realindex = accessorNames.indexOf(ac.getName());
                ac.bindTo(this, cs, realindex + 1); // jdbc params are 1-based
            }
            bound = true;
        } finally {
            // Do not leak the statement if binding fails; on success the caller owns it.
            if (!bound) {
                DBHelper.closeStatement(cs);
            }
        }
        return cs;
    }

    /**
     * Processes the table, reporting any uncaught failure on its first row.
     */
    public void doTable(Parse table) {
        this.headerRow = table.parts;
        try {
            super.doTable(table);
        } catch (Throwable e) {
            exception(headerRow, e);
        }
    }

    /**
     * Executes the procedure once per data row, or once without parameters when
     * {@code rows} is {@code null}.
     */
    public void doRows(Parse rows) {
        // If no procedure name was given to the constructor, fall back to the first fixture
        // argument. args is an inherited field; guard against it not having been populated.
        if ((procName == null || procName.trim().length() == 0) && args != null && args.length > 0) {
            procName = args[0];
        }
        if (rows != null) {
            executeStatementForEachRow(rows);
        } else {
            executeUsingHeaderRow();
        }
    }

    /**
     * Calls the procedure once without parameters. A failure is marked on the header row and
     * its message is appended to that row as an extra cell.
     */
    private void executeUsingHeaderRow() {
        CallableStatement statement = null;
        try {
            accessors = new DbParameterAccessor[0];
            statement = buildCommand(procName, accessors);
            if (!exceptionExpected) {
                statement.execute();
            } else {
                executeExpectingException(statement, headerRow);
            }
        } catch (SQLException e) {
            exception(headerRow, e);
            headerRow.parts.last().more = new Parse("td", e.getMessage(), null, null);
            e.printStackTrace();
        } finally {
            DBHelper.closeStatement(statement);
        }
    }

    /**
     * Builds the statement from the parameter-name row ({@code rows.parts}) and runs it for
     * every following row. Any failure is marked on the parameter-name row and stops
     * processing of the remaining rows.
     */
    private void executeStatementForEachRow(Parse rows) {
        CallableStatement statement = null;
        try {
            initParameters(rows.parts); // init parameters from the first row
            statement = buildCommand(procName, accessors);
            Parse row = rows;
            while ((row = row.more) != null) {
                runRow(statement, row);
            }
        } catch (Throwable e) {
            // NOTE(review): runRow has already marked the failing cell before throwing
            // HasMarkedException; whether this second mark is suppressed depends on the
            // superclass's exception(...) handling -- confirm.
            exception(rows.parts, e);
        } finally {
            DBHelper.closeStatement(statement);
        }
    }

    /**
     * Resolves one accessor and one binding per header cell.
     *
     * @param headerCells first cell of the parameter-name row
     * @throws SQLException if the procedure's parameters cannot be retrieved or a column
     *                      name does not match any parameter
     */
    private void initParameters(Parse headerCells) throws SQLException {
        Map<String, DbParameterAccessor> allParams = environment.getAllProcedureParameters(procName);
        if (allParams.isEmpty()) {
            throw new SQLException("Cannot retrieve list of parameters for " + procName
                    + " - check spelling and access rights");
        }
        accessors = new DbParameterAccessor[headerCells.size()];
        columnBindings = new Binding[headerCells.size()];
        Parse cell = headerCells;
        for (int i = 0; cell != null; i++, cell = cell.more) {
            String name = cell.text();
            String paramName = NameNormaliser.normaliseName(name);
            accessors[i] = allParams.get(paramName);
            if (accessors[i] == null) {
                throw new SQLException("Cannot find parameter for column " + i + " name=\"" + paramName + "\"");
            }
            boolean isOutput = name.endsWith("?");
            if (accessors[i].getDirection() == DbParameterAccessor.INPUT_OUTPUT) {
                // clone, separate into input and output
                accessors[i] = new DbParameterAccessor(accessors[i]);
                accessors[i].setDirection(isOutput ? DbParameterAccessor.OUTPUT : DbParameterAccessor.INPUT);
            }
            if (isOutput) {
                columnBindings[i] = new SymbolAccessQueryBinding();
            } else {
                // SQL Server quirk: an OUTPUT parameter used in an input column is cloned
                // and remapped to INPUT.
                if (accessors[i].getDirection() == DbParameterAccessor.OUTPUT) {
                    accessors[i] = new DbParameterAccessor(accessors[i]);
                    accessors[i].setDirection(DbParameterAccessor.INPUT);
                }
                columnBindings[i] = new SymbolAccessSetBinding();
            }
            columnBindings[i].adapter = accessors[i];
        }
    }

    /**
     * Executes the statement for one data row: binds the input columns, executes, then
     * evaluates the OUTPUT/RETURN_VALUE columns (or checks for the expected exception).
     *
     * @throws HasMarkedException wrapping any failure, after the failure has been marked
     */
    private void runRow(CallableStatement statement, Parse row) {
        Parse cell = row.parts;
        try {
            statement.clearParameters();
            // first set input params
            for (int column = 0; column < accessors.length; column++, cell = cell.more) {
                if (accessors[column].getDirection() == DbParameterAccessor.INPUT) {
                    columnBindings[column].doCell(this, cell);
                }
            }
            if (!exceptionExpected) {
                statement.execute();
                cell = row.parts;
                // next evaluate output params
                for (int column = 0; column < accessors.length; column++, cell = cell.more) {
                    if (accessors[column].getDirection() == DbParameterAccessor.OUTPUT
                            || accessors[column].getDirection() == DbParameterAccessor.RETURN_VALUE) {
                        columnBindings[column].doCell(this, cell);
                    }
                }
            } else {
                executeExpectingException(statement, row);
            }
        } catch (Throwable e) {
            // cell is null when the row has fewer cells than there are parameter columns;
            // mark the row's first cell instead so reporting does not fail with its own NPE.
            exception(cell != null ? cell : row.parts, e);
            throw new HasMarkedException(e);
        }
    }

    /**
     * Executes the statement and marks {@code row} right if it throws a {@link SQLException}
     * (with the expected error code, when one was configured) and wrong otherwise.
     */
    private void executeExpectingException(CallableStatement statement, Parse row) {
        try {
            statement.execute();
            // no exception if we are here, mark whole row
            wrong(row);
        } catch (SQLException sqle) {
            if (!excNumberDefined) {
                right(row);
            } else {
                int realError = environment.getExceptionCode(sqle);
                if (realError == excNumberExpected) {
                    right(row);
                } else {
                    wrong(row);
                    row.parts.addToBody(fit.Fixture.gray(" got error code " + realError));
                }
            }
        }
    }
}