package dbfit.fixture;
import dbfit.api.DBEnvironment;
import dbfit.api.DbEnvironmentFactory;
import dbfit.util.DbParameterAccessor;
import dbfit.util.DbParameterAccessorTypeAdapter;
import dbfit.util.NameNormaliser;
import dbfit.util.SymbolAccessSetBinding;
import fit.Binding;
import fit.Parse;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import dbfit.util.Direction;
public class Update extends fit.Fixture {
private DBEnvironment environment;
private StatementExecution statement;
private String tableName;
private Binding[] columnBindings;
private DbParameterAccessor[] updateAccessors;
private DbParameterAccessor[] selectAccessors;
public Update() {
this.environment = DbEnvironmentFactory.getDefaultEnvironment();
}
public Update(DBEnvironment dbEnvironment) {
this.environment = dbEnvironment;
}
public Update(DBEnvironment dbEnvironment, String tableName) {
this.tableName = tableName;
this.environment = dbEnvironment;
}
private StatementExecution buildUpdateCommand() throws SQLException {
if (updateAccessors.length == 0) {
throw new Error("Update fixture must have at least one field to update. Have you forgotten = after the column name?");
}
StringBuilder s = new StringBuilder("update ").append(tableName).append(" set ");
for (int i = 0; i < updateAccessors.length; i++) {
if (i > 0) {
s.append(", ");
}
s.append(updateAccessors[i].getName()).append("=").append("?");
}
s.append(" where ");
for (int i = 0; i < selectAccessors.length; i++) {
if (i > 0) {
s.append(" and ");
}
s.append(selectAccessors[i].getName()).append("=").append("?");
}
StatementExecution cs =
environment.createStatementExecution(environment.getConnection().prepareStatement(s.toString()));
for (int i = 0; i < updateAccessors.length; i++) {
updateAccessors[i].bindTo(cs, i + 1);
}
for (int j = 0; j < selectAccessors.length; j++) {
selectAccessors[j].bindTo(cs, j + updateAccessors.length + 1);
}
return cs;
}
public void doRows(Parse rows) {
// if table not defined as parameter, read from fixture argument; if still not defined, read from first row
if ((tableName == null || tableName.trim().length() == 0) && args.length > 0) {
tableName = args[0];
} else if (tableName == null) {
tableName = rows.parts.text();
rows = rows.more;
}
try {
initParameters(rows.parts); //init parameters from the first row
try (StatementExecution st = buildUpdateCommand()) {
statement = st;
Parse row = rows;
while ((row = row.more) != null) {
runRow(row);
}
}
} catch (Throwable e) {
e.printStackTrace();
exception(rows.parts, e);
}
}
private void initParameters(Parse headerCells) throws SQLException {
Map<String, DbParameterAccessor> allParams =
environment.getAllColumns(tableName);
if (allParams.isEmpty()) {
throw new SQLException("Cannot retrieve list of columns for " + tableName + " - check spelling and access rights");
}
columnBindings = new Binding[headerCells.size()];
List<DbParameterAccessor> selectAcc = new ArrayList<DbParameterAccessor>();
List<DbParameterAccessor> updateAcc = new ArrayList<DbParameterAccessor>();
for (int i = 0; headerCells != null; i++, headerCells = headerCells.more) {
String name = headerCells.text();
String paramName = NameNormaliser.normaliseName(name);
//need to clone db param accessors here because same column may be in the update and select part
DbParameterAccessor orig = allParams.get(paramName);
if (orig == null) {
wrong(headerCells);
throw new SQLException("Cannot find column " + paramName);
}
//clone parameter because there may be multiple usages of the same column
DbParameterAccessor acc = orig.clone();
acc.setDirection(Direction.INPUT);
if (headerCells.text().endsWith("=")) {
updateAcc.add(acc);
} else {
selectAcc.add(acc);
}
columnBindings[i] = new SymbolAccessSetBinding();
columnBindings[i].adapter = new DbParameterAccessorTypeAdapter(acc, this);
}
// weird jdk syntax, method param is the type of array.
selectAccessors = selectAcc.toArray(new DbParameterAccessor[0]);
updateAccessors = updateAcc.toArray(new DbParameterAccessor[0]);
}
private void runRow(Parse row) throws Throwable {
try {
Parse cell = row.parts;
//first set input params
for (int column = 0; column < columnBindings.length; column++, cell = cell.more) {
columnBindings[column].doCell(this, cell);
}
statement.run();
} catch (SQLException sqle) {
sqle.printStackTrace();
exception(row,sqle);
row.parts.last().more = new Parse("td", sqle.getMessage(), null, null);
}
}
}