/***************************************************************** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. ****************************************************************/ package org.apache.cayenne.test.jdbc; import javax.sql.DataSource; import java.sql.Connection; import java.sql.ParameterMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Time; import java.sql.Timestamp; import java.util.ArrayList; import java.util.List; /** * JDBC utility class for setting up and analyzing the DB data sets. DBHelper * intentionally bypasses Cayenne stack. */ public class DBHelper { protected DataSource dataSource; public DBHelper(DataSource dataSource) { this.dataSource = dataSource; } /** * Quotes a SQL identifier as appropriate for the given DB. This * implementation returns the identifier unchanged, while subclasses can * implement a custom quoting strategy. */ protected String quote(String sqlIdentifier) { return sqlIdentifier; } /** * Selects a single row. */ public Object[] select(String table, final String[] columns) throws SQLException { if (columns.length == 0) { throw new IllegalArgumentException("No columns"); } StringBuilder sql = new StringBuilder("select "); sql.append(quote(columns[0])); for (int i = 1; i < columns.length; i++) { sql.append(", ").append(quote(columns[i])); } sql.append(" from ").append(quote(table)); return new RowTemplate<Object[]>(this) { @Override Object[] readRow(ResultSet rs, String sql) throws SQLException { Object[] result = new Object[columns.length]; for (int i = 1; i <= result.length; i++) { result[i - 1] = rs.getObject(i); } return result; } }.execute(sql.toString()); } public List<Object[]> selectAll(String table, final String[] columns) throws SQLException { if (columns.length == 0) { throw new IllegalArgumentException("No columns"); } StringBuilder sql = new StringBuilder("select "); sql.append(quote(columns[0])); for (int i = 1; i < columns.length; i++) { sql.append(", ").append(quote(columns[i])); } sql.append(" from ").append(quote(table)); return new ResultSetTemplate<List<Object[]>>(this) { @Override List<Object[]> readResultSet(ResultSet rs, String sql) throws SQLException { List<Object[]> result = new ArrayList<>(); while (rs.next()) { Object[] row = new Object[columns.length]; for (int i = 1; i <= row.length; i++) { row[i - 1] = rs.getObject(i); } result.add(row); } return result; } }.execute(sql.toString()); } /** * Inserts a single row. Columns types can be null and will be determined * from ParameterMetaData in this case. The later scenario will not work if * values contains nulls and the DB is Oracle. */ public void insert(String table, String[] columns, Object[] values, int[] columnTypes) throws SQLException { if (columns.length != values.length) { throw new IllegalArgumentException("Columns and values arrays have different sizes: " + columns.length + " and " + values.length); } if (columns.length == 0) { throw new IllegalArgumentException("No columns"); } StringBuilder sql = new StringBuilder("INSERT INTO "); sql.append(quote(table)).append(" (").append(quote(columns[0])); for (int i = 1; i < columns.length; i++) { sql.append(", ").append(quote(columns[i])); } sql.append(") VALUES (?"); for (int i = 1; i < values.length; i++) { sql.append(", ?"); } sql.append(")"); try (Connection c = getConnection();) { String sqlString = sql.toString(); UtilityLogger.log(sqlString); ParameterMetaData parameters = null; try (PreparedStatement st = c.prepareStatement(sqlString);) { for (int i = 0; i < values.length; i++) { if (values[i] == null) { int type; if (columnTypes == null) { // check for the right NULL type if (parameters == null) { parameters = st.getParameterMetaData(); } type = parameters.getParameterType(i + 1); } else { type = columnTypes[i]; } st.setNull(i + 1, type); } else { if(columnTypes != null) { st.setObject(i + 1, values[i], columnTypes[i]); } else { st.setObject(i + 1, values[i]); } } } st.executeUpdate(); } c.commit(); } } public int deleteAll(String tableName) throws SQLException { return delete(tableName).execute(); } public UpdateBuilder update(String tableName) throws SQLException { return new UpdateBuilder(this, tableName); } public DeleteBuilder delete(String tableName) { return new DeleteBuilder(this, tableName); } public int getRowCount(String table) throws SQLException { String sql = "select count(*) from " + quote(table); return new RowTemplate<Integer>(this) { @Override Integer readRow(ResultSet rs, String sql) throws SQLException { return rs.getInt(1); } }.execute(sql); } public String getString(String table, String column) throws SQLException { final String sql = "select " + quote(column) + " from " + quote(table); return new RowTemplate<String>(this) { @Override String readRow(ResultSet rs, String sql) throws SQLException { return rs.getString(1); } }.execute(sql); } public Object getObject(String table, String column) throws SQLException { final String sql = "select " + quote(column) + " from " + quote(table); return new RowTemplate<Object>(this) { @Override Object readRow(ResultSet rs, String sql) throws SQLException { return rs.getObject(1); } }.execute(sql); } public byte getByte(String table, String column) throws SQLException { final String sql = "select " + quote(column) + " from " + quote(table); return new RowTemplate<Byte>(this) { @Override Byte readRow(ResultSet rs, String sql) throws SQLException { return rs.getByte(1); } }.execute(sql); } public byte[] getBytes(String table, String column) throws SQLException { final String sql = "select " + quote(column) + " from " + quote(table); return new RowTemplate<byte[]>(this) { @Override byte[] readRow(ResultSet rs, String sql) throws SQLException { return rs.getBytes(1); } }.execute(sql); } public int getInt(String table, String column) throws SQLException { final String sql = "select " + quote(column) + " from " + quote(table); return new RowTemplate<Integer>(this) { @Override Integer readRow(ResultSet rs, String sql) throws SQLException { return rs.getInt(1); } }.execute(sql); } public long getLong(String table, String column) throws SQLException { final String sql = "select " + quote(column) + " from " + quote(table); return new RowTemplate<Long>(this) { @Override Long readRow(ResultSet rs, String sql) throws SQLException { return rs.getLong(1); } }.execute(sql); } public double getDouble(String table, String column) throws SQLException { final String sql = "select " + quote(column) + " from " + quote(table); return new RowTemplate<Double>(this) { @Override Double readRow(ResultSet rs, String sql) throws SQLException { return rs.getDouble(1); } }.execute(sql); } public boolean getBoolean(String table, String column) throws SQLException { final String sql = "select " + quote(column) + " from " + quote(table); return new RowTemplate<Boolean>(this) { @Override Boolean readRow(ResultSet rs, String sql) throws SQLException { return rs.getBoolean(1); } }.execute(sql); } public java.util.Date getUtilDate(String table, String column) throws SQLException { Timestamp ts = getTimestamp(table, column); return ts != null ? new java.util.Date(ts.getTime()) : null; } public java.sql.Date getSqlDate(String table, String column) throws SQLException { final String sql = "select " + quote(column) + " from " + quote(table); return new RowTemplate<java.sql.Date>(this) { @Override java.sql.Date readRow(ResultSet rs, String sql) throws SQLException { return rs.getDate(1); } }.execute(sql); } public Time getTime(String table, String column) throws SQLException { final String sql = "select " + quote(column) + " from " + quote(table); return new RowTemplate<Time>(this) { @Override Time readRow(ResultSet rs, String sql) throws SQLException { return rs.getTime(1); } }.execute(sql); } public Timestamp getTimestamp(String table, String column) throws SQLException { final String sql = "select " + quote(column) + " from " + quote(table); return new RowTemplate<Timestamp>(this) { @Override Timestamp readRow(ResultSet rs, String sql) throws SQLException { return rs.getTimestamp(1); } }.execute(sql); } public Connection getConnection() throws SQLException { Connection connection = dataSource.getConnection(); try { connection.setAutoCommit(false); } catch (SQLException e) { try { connection.close(); } catch (SQLException ignored) { } } return connection; } }