/*******************************************************************************
 * Copyright 2016 Observational Health Data Sciences and Informatics
 * 
 * This file is part of WhiteRabbit
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package org.ohdsi.databases;

import java.sql.BatchUpdateException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Set;

import org.ohdsi.utilities.SimpleCounter;
import org.ohdsi.utilities.StringUtilities;
import org.ohdsi.utilities.files.Row;
import org.ohdsi.utilities.files.WriteCSVFileWithHeader;

/**
 * Thin convenience wrapper around a single JDBC {@link Connection}. It offers batch execution of
 * SQL scripts, iteration over query results as {@link Row} objects, and bulk insertion of rows
 * (optionally creating the target table first).
 */
public class RichConnection {
    /** Number of rows that {@link #insertIntoTable} buffers before sending them to the database. */
    public static int INSERT_BATCH_SIZE = 100000;

    private Connection connection;
    private boolean verbose = false;
    private static DecimalFormat decimalFormat = new DecimalFormat("#.#");
    private DbType dbType;

    /**
     * Opens a connection through {@code DBConnector.connect} and remembers the database type, which
     * selects the SQL dialect used by the other methods.
     *
     * @param server   passed through to {@code DBConnector.connect}
     * @param domain   passed through to {@code DBConnector.connect}
     * @param user     passed through to {@code DBConnector.connect}
     * @param password passed through to {@code DBConnector.connect}
     * @param dbType   the type of database being connected to
     */
    public RichConnection(String server, String domain, String user, String password, DbType dbType) {
        this.connection = DBConnector.connect(server, domain, user, password, dbType);
        this.dbType = dbType;
    }

    /**
     * Execute the given SQL statement(s). The text is split on {@code ';'} and the pieces are run
     * as one JDBC batch. Note that the split is purely textual, so a semicolon inside a string
     * literal also splits the statement.
     * <p>
     * A {@link SQLException} is reported on {@code System.err} (together with the SQL) and is not
     * rethrown.
     *
     * @param sql one or more statements separated by semicolons; {@code null} or empty is a no-op
     */
    public void execute(String sql) {
        if (sql == null || sql.length() == 0) {
            return;
        }
        Statement statement = null;
        try {
            statement = connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
            for (String subQuery : sql.split(";")) {
                // Whitespace left over after the last ';' is not a statement; sending it to the
                // driver as a batch entry can make the whole batch fail.
                if (subQuery.trim().length() == 0) {
                    continue;
                }
                if (verbose) {
                    System.out.println("Adding query to batch: " + abbreviate(subQuery));
                }
                statement.addBatch(subQuery);
            }
            long start = System.currentTimeMillis();
            if (verbose) {
                System.out.println("Executing batch");
            }
            statement.executeBatch();
            if (verbose) {
                outputQueryStats(statement, System.currentTimeMillis() - start);
            }
        } catch (SQLException e) {
            System.err.println(sql);
            e.printStackTrace();
        } finally {
            if (statement != null) {
                try {
                    statement.close();
                } catch (SQLException e) {
                    System.err.println(e.getMessage());
                }
            }
        }
    }

    /** Collapses newlines/tabs to spaces and truncates the SQL to 100 characters for log output. */
    private static String abbreviate(String sql) {
        String abbrSQL = sql.replace('\n', ' ').replace('\t', ' ').trim();
        if (abbrSQL.length() > 100) {
            abbrSQL = abbrSQL.substring(0, 100).trim() + "...";
        }
        return abbrSQL;
    }

    /** Prints the first server warning (if any) and the elapsed time of a finished statement. */
    private void outputQueryStats(Statement statement, long ms) throws SQLException {
        Throwable warning = statement.getWarnings();
        if (warning != null) {
            System.out.println("- SERVER: " + warning.getMessage());
        }
        System.out.println("- Query completed in " + formatDuration(ms));
    }

    /** Renders a duration in the largest unit (ms, seconds, minutes, hours) that keeps it >= 1. */
    private static String formatDuration(long ms) {
        if (ms < 1000) {
            return ms + " ms";
        }
        if (ms < 60000) {
            return decimalFormat.format(ms / 1000d) + " seconds";
        }
        if (ms < 3600000) {
            return decimalFormat.format(ms / 60000d) + " minutes";
        }
        return decimalFormat.format(ms / 3600000d) + " hours";
    }

    /**
     * Query the database using the provided SQL statement. The query is not run here; it is run
     * each time {@link QueryResult#iterator()} is called.
     *
     * @param sql the query to run
     * @return a lazily evaluated, re-iterable result
     */
    public QueryResult query(String sql) {
        return new QueryResult(sql);
    }

    /**
     * Switch the database (or schema) to use, with the statement appropriate for the database type.
     * Nothing is executed for MS Access.
     *
     * @param database name of the database/schema; {@code null} is a no-op
     */
    public void use(String database) {
        if (database == null) {
            return;
        }
        if (dbType == DbType.ORACLE) {
            execute("ALTER SESSION SET current_schema = " + database);
        } else if (dbType == DbType.POSTGRESQL || dbType == DbType.REDSHIFT) {
            execute("SET search_path TO " + database);
        } else if (dbType == DbType.MSACCESS) {
            // Deliberately nothing: no statement is issued for MS Access.
        } else {
            execute("USE " + database);
        }
    }

    /**
     * Lists the tables in the given database/schema.
     *
     * @param database name of the database/schema to inspect (ignored for MS Access)
     * @return the first column of the dialect-specific catalog query, one entry per table
     * @throws RuntimeException if the database type has no catalog query defined here
     */
    public List<String> getTableNames(String database) {
        String query;
        if (dbType == DbType.MYSQL) {
            query = "SHOW TABLES IN " + database;
        } else if (dbType == DbType.MSSQL) {
            query = "SELECT name FROM " + database + ".sys.tables ORDER BY name";
        } else if (dbType == DbType.ORACLE) {
            query = "SELECT table_name FROM all_tables WHERE owner='" + database.toUpperCase() + "'";
        } else if (dbType == DbType.POSTGRESQL || dbType == DbType.REDSHIFT) {
            query = "SELECT table_name FROM information_schema.tables WHERE table_schema = '" + database.toLowerCase()
                    + "' ORDER BY table_name";
        } else if (dbType == DbType.MSACCESS) {
            query = "SELECT Name FROM sys.MSysObjects WHERE Type=1 AND Flags=0;";
        } else {
            // Previously fell through with a null query and failed later with a NullPointerException.
            throw new RuntimeException("DB type not supported");
        }

        List<String> names = new ArrayList<String>();
        for (Row row : query(query)) {
            names.add(row.get(row.getFieldNames().get(0)));
        }
        return names;
    }

    /**
     * Lists the column names of a table. Only implemented for MS SQL Server and MySQL.
     *
     * @param table name of the table
     * @return the column names
     * @throws RuntimeException if the database type is neither MSSQL nor MYSQL
     */
    public List<String> getFieldNames(String table) {
        List<String> names = new ArrayList<String>();
        if (dbType == DbType.MSSQL) {
            for (Row row : query("SELECT name FROM syscolumns WHERE id=OBJECT_ID('" + table + "')")) {
                names.add(row.get("name"));
            }
        } else if (dbType == DbType.MYSQL) {
            for (Row row : query("SHOW COLUMNS FROM " + table)) {
                names.add(row.get("COLUMN_NAME"));
            }
        } else {
            throw new RuntimeException("DB type not supported");
        }
        return names;
    }

    /**
     * Returns the JDBC column metadata of a table in an MS Access database.
     *
     * @param table name of the table
     * @return the result of {@link DatabaseMetaData#getColumns}; the caller is responsible for closing it
     * @throws RuntimeException if the database is not MS Access, or if the metadata lookup fails
     */
    public ResultSet getMsAccessFieldNames(String table) {
        if (dbType != DbType.MSACCESS) {
            throw new RuntimeException("DB is not of type MS Access");
        }
        try {
            DatabaseMetaData metadata = connection.getMetaData();
            return metadata.getColumns(null, null, table, null);
        } catch (SQLException e) {
            // Same message as before, but keep the cause so the stack trace is not lost.
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Returns the row count of the specified table. For MS SQL Server and MS Access the table name
     * is wrapped in square brackets.
     *
     * @param tableName name of the table to count
     * @return the number of rows reported by {@code SELECT COUNT(*)}
     * @throws RuntimeException if the query fails or its result cannot be parsed as a number
     */
    public long getTableSize(String tableName) {
        QueryResult qr;
        if (dbType == DbType.MSSQL || dbType == DbType.MSACCESS) {
            qr = query("SELECT COUNT(*) FROM [" + tableName + "];");
        } else {
            qr = query("SELECT COUNT(*) FROM " + tableName + ";");
        }
        try {
            // Iterate the QueryResult that is closed below, so that its cursor is actually released
            // and the dialect-specific statement above is the one that runs.
            return Long.parseLong(qr.iterator().next().getCells().get(0));
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            qr.close();
        }
    }

    /**
     * Close the connection to the database. A failure is printed and not rethrown.
     */
    public void close() {
        try {
            connection.close();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    /**
     * @return whether queries and timings are echoed to {@code System.out}
     */
    public boolean isVerbose() {
        return verbose;
    }

    /**
     * @param verbose if true, queries and timings are echoed to {@code System.out}
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * A query that can be iterated. Every call to {@link #iterator()} runs the SQL again and is
     * tracked so that {@link #close()} can release all cursors that were opened.
     */
    public class QueryResult implements Iterable<Row> {
        private String sql;

        private List<DBRowIterator> iterators = new ArrayList<DBRowIterator>();

        public QueryResult(String sql) {
            this.sql = sql;
        }

        /** Runs the query and returns an iterator over its rows. */
        @Override
        public Iterator<Row> iterator() {
            DBRowIterator iterator = new DBRowIterator(sql);
            iterators.add(iterator);
            return iterator;
        }

        /** Closes every iterator handed out so far; safe to call more than once. */
        public void close() {
            for (DBRowIterator iterator : iterators) {
                iterator.close();
            }
        }
    }

    /**
     * Writes the results of a query to the specified file in CSV format.
     *
     * @param queryResult the rows to write
     * @param filename    path of the file to write to
     */
    public void writeToFile(QueryResult queryResult, String filename) {
        WriteCSVFileWithHeader out = new WriteCSVFileWithHeader(filename);
        try {
            for (Row row : queryResult) {
                out.write(row);
            }
        } finally {
            // Close the file even when the query fails half-way through.
            out.close();
        }
    }

    /**
     * Inserts the rows into a table in the database, in batches of {@link #INSERT_BATCH_SIZE}.
     *
     * @param iterator source of the rows to insert
     * @param table    name of the target table
     * @param create   if true, the data format is determined based on the first batch of rows and used to create
     *                 the table structure
     */
    public void insertIntoTable(Iterator<Row> iterator, String table, boolean create) {
        List<Row> batch = new ArrayList<Row>(INSERT_BATCH_SIZE);
        boolean first = true;
        SimpleCounter counter = new SimpleCounter(1000000, true);
        while (iterator.hasNext()) {
            if (batch.size() == INSERT_BATCH_SIZE) {
                if (first && create) {
                    createTable(table, batch);
                }
                insert(table, batch);
                batch.clear();
                first = false;
            }
            batch.add(iterator.next());
            counter.count();
        }
        // Flush the final, partially filled batch.
        if (batch.size() != 0) {
            if (first && create) {
                createTable(table, batch);
            }
            insert(table, batch);
        }
    }

    /**
     * Inserts one batch of rows with a single prepared statement inside a transaction. The column
     * list is taken from the first row. Empty strings and nulls are stored as SQL NULL.
     * <p>
     * A {@link SQLException} is printed and not rethrown.
     */
    private void insert(String tableName, List<Row> rows) {
        // Keep the original field names for reading the rows and a separate, sanitized list for
        // the SQL text; looking a row up by the sanitized name fails whenever the two differ.
        List<String> fieldNames = rows.get(0).getFieldNames();
        List<String> sqlColumns = new ArrayList<String>(fieldNames.size());
        for (String fieldName : fieldNames) {
            sqlColumns.add(columnNameToSqlName(fieldName));
        }

        StringBuilder sql = new StringBuilder("INSERT INTO ").append(tableName);
        sql.append(" (").append(StringUtilities.join(sqlColumns, ",")).append(")");
        sql.append(" VALUES (?");
        for (int i = 1; i < sqlColumns.size(); i++) {
            sql.append(",?");
        }
        sql.append(")");

        PreparedStatement statement = null;
        try {
            connection.setAutoCommit(false);
            statement = connection.prepareStatement(sql.toString());
            for (Row row : rows) {
                for (int i = 0; i < fieldNames.size(); i++) {
                    String value = row.get(fieldNames.get(i));
                    if (value != null && value.length() == 0) {
                        value = null;
                    }
                    bindValue(statement, i + 1, value);
                }
                statement.addBatch();
            }
            statement.executeBatch();
            connection.commit();
            connection.clearWarnings();
        } catch (SQLException e) {
            e.printStackTrace();
            if (e instanceof BatchUpdateException) {
                SQLException next = e.getNextException();
                if (next != null) {
                    System.err.println(next.getMessage());
                }
            }
        } finally {
            if (statement != null) {
                try {
                    statement.close();
                } catch (SQLException e) {
                    System.err.println(e.getMessage());
                }
            }
            try {
                // Always leave the connection in auto-commit mode, also after a failure. Per the
                // JDBC contract this commits whatever the driver accepted before the failure,
                // which keeps the insert best-effort.
                connection.setAutoCommit(true);
            } catch (SQLException e) {
                System.err.println(e.getMessage());
            }
        }
    }

    /** Binds a single (possibly null) value using the parameter type the dialect requires. */
    private void bindValue(PreparedStatement statement, int index, String value) throws SQLException {
        if (dbType == DbType.POSTGRESQL || dbType == DbType.REDSHIFT) {
            // PostgreSQL does not allow unspecified types
            statement.setObject(index, value, Types.OTHER);
        } else if (dbType == DbType.ORACLE && isDate(value)) {
            statement.setDate(index, java.sql.Date.valueOf(value));
        } else {
            statement.setString(index, value);
        }
    }

    /**
     * Cheap syntactic check for a {@code yyyy-MM-dd} date with year 1700-2200, month 1-12 and day
     * 1-31. It does not check that the day exists in the given month.
     */
    private static boolean isDate(String string) {
        if (string == null || string.length() != 10 || string.charAt(4) != '-' || string.charAt(7) != '-') {
            return false;
        }
        try {
            int year = Integer.parseInt(string.substring(0, 4));
            if (year < 1700 || year > 2200) {
                return false;
            }
            int month = Integer.parseInt(string.substring(5, 7));
            if (month < 1 || month > 12) {
                return false;
            }
            int day = Integer.parseInt(string.substring(8, 10));
            return day >= 1 && day <= 31;
        } catch (Exception e) {
            return false;
        }
    }

    /**
     * Creates a table whose columns are derived from the given rows: a column is numeric when every
     * value passes {@code StringUtilities.isInteger}, and its length is the longest value seen.
     *
     * @return the (original, unsanitized) names of the fields that were found to be numeric
     */
    private Set<String> createTable(String tableName, List<Row> rows) {
        Row firstRow = rows.get(0);
        List<FieldInfo> fields = new ArrayList<FieldInfo>();
        for (String field : firstRow.getFieldNames()) {
            fields.add(new FieldInfo(field));
        }

        for (Row row : rows) {
            for (FieldInfo fieldInfo : fields) {
                String value = row.get(fieldInfo.name);
                if (value == null) {
                    // Treat a missing value like an empty one instead of failing on value.length().
                    value = "";
                }
                if (fieldInfo.isNumeric && !StringUtilities.isInteger(value)) {
                    fieldInfo.isNumeric = false;
                }
                if (value.length() > fieldInfo.maxLength) {
                    fieldInfo.maxLength = value.length();
                }
            }
        }

        Set<String> numericFields = new HashSet<String>();
        StringBuilder sql = new StringBuilder();
        sql.append("CREATE TABLE " + tableName + " (\n");
        boolean firstColumn = true;
        for (FieldInfo fieldInfo : fields) {
            // Separate the column definitions with commas; a comma after the last definition is a
            // syntax error in MySQL.
            if (!firstColumn) {
                sql.append(",\n");
            }
            firstColumn = false;
            sql.append(" " + fieldInfo.toString());
            if (fieldInfo.isNumeric) {
                numericFields.add(fieldInfo.name);
            }
        }
        sql.append("\n);");
        execute(sql.toString());
        return numericFields;
    }

    /** Replaces spaces, hyphens and commas by underscores and collapses runs of underscores. */
    private String columnNameToSqlName(String name) {
        return name.replaceAll(" ", "_").replace("-", "_").replace(",", "_").replaceAll("_+", "_");
    }

    /** Column statistics gathered by {@code createTable}; {@link #toString()} yields the DDL fragment. */
    private class FieldInfo {
        public String name;
        public boolean isNumeric = true;
        public int maxLength = 0;

        public FieldInfo(String name) {
            this.name = name;
        }

        /**
         * @return the column definition ("name type") for MySQL or MS SQL Server
         * @throws RuntimeException for any other database type
         */
        @Override
        public String toString() {
            String sqlName = columnNameToSqlName(name);
            if (dbType == DbType.MYSQL) {
                if (isNumeric) {
                    return sqlName + " int(" + maxLength + ")";
                }
                if (maxLength > 255) {
                    return sqlName + " text";
                }
                return sqlName + " varchar(255)";
            }
            if (dbType == DbType.MSSQL) {
                if (isNumeric) {
                    return sqlName + (maxLength < 10 ? " int" : " bigint");
                }
                if (maxLength > 255) {
                    return sqlName + " varchar(max)";
                }
                return sqlName + " varchar(255)";
            }
            throw new RuntimeException("Create table syntax not specified for type " + dbType);
        }
    }

    /**
     * Iterator over the rows of a query. The query is executed in the constructor; the statement
     * and result set are released as soon as the last row has been returned, or when
     * {@link #close()} is called.
     */
    private class DBRowIterator implements Iterator<Row> {

        private Statement statement;

        private ResultSet resultSet;

        private boolean hasNext;

        private Set<String> columnNames = new HashSet<String>();

        /**
         * @param sql the query to run; a single trailing semicolon is removed
         * @throws RuntimeException wrapping the {@link SQLException} if the query fails
         */
        public DBRowIterator(String sql) {
            try {
                // Strings are immutable: the trimmed value has to be assigned for the trailing
                // semicolon check below to see it.
                sql = sql.trim();
                if (sql.endsWith(";")) {
                    sql = sql.substring(0, sql.length() - 1);
                }
                if (verbose) {
                    System.out.println("Executing query: " + abbreviate(sql));
                }
                long start = System.currentTimeMillis();
                statement = connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
                resultSet = statement.executeQuery(sql);
                hasNext = resultSet.next();
                if (verbose) {
                    outputQueryStats(statement, System.currentTimeMillis() - start);
                }
                if (!hasNext) {
                    // Empty result: nothing will ever be read, so release the cursor right away.
                    close();
                }
            } catch (SQLException e) {
                close();
                System.err.println(sql);
                System.err.println(e.getMessage());
                throw new RuntimeException(e);
            }
        }

        /** Releases the result set and its statement; safe to call more than once. */
        public void close() {
            hasNext = false;
            if (resultSet != null) {
                try {
                    resultSet.close();
                } catch (SQLException e) {
                    e.printStackTrace();
                }
                resultSet = null;
            }
            if (statement != null) {
                try {
                    statement.close();
                } catch (SQLException e) {
                    e.printStackTrace();
                }
                statement = null;
            }
        }

        @Override
        public boolean hasNext() {
            return hasNext;
        }

        /**
         * Returns the current row. SQL NULL becomes the empty string, the text " 00:00:00" is
         * removed from every value, and when several columns share a name only the first is kept.
         *
         * @throws NoSuchElementException if there are no more rows
         * @throws RuntimeException       wrapping the {@link SQLException} if reading fails
         */
        @Override
        public Row next() {
            if (!hasNext) {
                throw new NoSuchElementException();
            }
            try {
                Row row = new Row();
                ResultSetMetaData metaData = resultSet.getMetaData();
                int columnCount = metaData.getColumnCount();
                columnNames.clear();
                for (int i = 1; i <= columnCount; i++) {
                    String columnName = metaData.getColumnName(i);
                    if (columnNames.add(columnName)) {
                        String value = resultSet.getString(i);
                        if (value == null) {
                            value = "";
                        }
                        row.add(columnName, value.replace(" 00:00:00", ""));
                    }
                }
                hasNext = resultSet.next();
                if (!hasNext) {
                    close();
                }
                return row;
            } catch (SQLException e) {
                e.printStackTrace();
                throw new RuntimeException(e);
            }
        }

        /** Removal is not supported; this is a silent no-op. */
        @Override
        public void remove() {
        }
    }
}