package io.vivarium.db;

import java.lang.reflect.Array;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;

import com.google.common.base.Joiner;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;

import io.vivarium.util.Reflection;
import io.vivarium.util.UUID;
import io.vivarium.util.Version;

/**
 * Static helpers that build SQL text and run it over a JDBC {@link Connection} to a PostgreSQL database.
 *
 * <p>
 * Table names, column names and {@link WhereCondition} text are concatenated into the SQL verbatim, so they must
 * come from trusted code. Values are rendered through {@link #toSqlString(Object)}.
 */
public class DatabaseUtils {
    /** Maximum number of characters of a SQL statement echoed to standard output. */
    private static final int MAX_LOGGED_SQL_LENGTH = 500;

    /**
     * Opens a connection to the named PostgreSQL database on {@code localhost}.
     *
     * <p>
     * The connection is created with the driver property {@code stringtype=unspecified} and with auto-commit
     * disabled, so the caller is responsible for committing.
     *
     * @param databaseName
     *            The database name appended to the JDBC url.
     * @param username
     *            The user to connect as.
     * @param password
     *            The password for that user.
     * @return A new connection with auto-commit turned off.
     * @throws SQLException
     *             If the connection cannot be established or configured.
     */
    public static Connection createDatabaseConnection(String databaseName, String username, String password)
            throws SQLException {
        String url = "jdbc:postgresql://localhost/" + databaseName;
        Properties props = new Properties();
        props.setProperty("user", username);
        props.setProperty("password", password);
        props.setProperty("stringtype", "unspecified");
        Connection dbConnection = DriverManager.getConnection(url, props);
        dbConnection.setAutoCommit(false);
        return dbConnection;
    }

    /**
     * Runs {@code SELECT * FROM tableName [WHERE condition]} and returns every row as a map from column name to
     * value.
     *
     * @param connection
     *            The active connection to query.
     * @param tableName
     *            The table to select from; inserted into the SQL verbatim.
     * @param condition
     *            Optional condition whose {@code toString()} is used as the WHERE clause.
     * @return One map per row, in result-set order, keyed by the column names reported by the result-set metadata.
     * @throws SQLException
     *             If the query fails.
     */
    public static List<Map<String, Object>> select(Connection connection, String tableName,
            Optional<WhereCondition> condition) throws SQLException {
        List<Map<String, Object>> results = new LinkedList<>();
        try (Statement queryStatement = connection.createStatement()) {
            // Build the select query
            StringBuilder selectStringBuilder = new StringBuilder();
            selectStringBuilder.append("SELECT * FROM ");
            selectStringBuilder.append(tableName);
            if (condition.isPresent()) {
                selectStringBuilder.append(" WHERE ");
                selectStringBuilder.append(condition.get().toString());
            }

            // Execute the select
            try (ResultSet resultSet = queryStatement.executeQuery(selectStringBuilder.toString())) {
                // Build column list (JDBC column indexes are 1-based)
                ResultSetMetaData resultMetaData = resultSet.getMetaData();
                int columnCount = resultMetaData.getColumnCount();
                List<String> columnNames = new LinkedList<>();
                for (int i = 1; i <= columnCount; i++) {
                    columnNames.add(resultMetaData.getColumnName(i));
                }

                // Build relation objects
                while (resultSet.next()) {
                    Map<String, Object> relation = new HashMap<>();
                    for (String columnName : columnNames) {
                        relation.put(columnName, resultSet.getObject(columnName));
                    }
                    results.add(relation);
                }
            }
        }
        return results;
    }

    /**
     * Updates the row identified by the key columns, then inserts it if no such row exists.
     *
     * <p>
     * Two statements are executed on the connection: an {@code UPDATE ... WHERE keys} followed by an
     * {@code INSERT ... SELECT ... WHERE NOT EXISTS (...)}. Both statements are echoed (truncated) to standard
     * output before they run. The UPDATE is skipped when the relation contains no non-key columns, because there is
     * nothing to set.
     *
     * @param connection
     *            The active connection to the database to update.
     * @param tableName
     *            The table to write to; inserted into the SQL verbatim.
     * @param relation
     *            Column name to value for the row. Values are rendered with {@link #toSqlString(Object)}.
     * @param keyColumns
     *            The columns that identify the row. All of them must match for a row to count as existing.
     * @throws SQLException
     *             If either statement fails.
     */
    public static void upsert(Connection connection, String tableName, Map<String, Object> relation,
            List<String> keyColumns) throws SQLException {
        try (Statement sqlStatement = connection.createStatement()) {
            // Build lists for all columns and non-key columns for streaming over while we build the SQL statements.
            List<String> allColumns = new LinkedList<>(relation.keySet());
            List<String> nonKeyColumns = new LinkedList<>();
            for (String columnName : allColumns) {
                if (!keyColumns.contains(columnName)) {
                    nonKeyColumns.add(columnName);
                }
            }

            // The values to actually use in the SQL, already quoted/escaped as SQL literals.
            Map<String, String> sqlStrings = new HashMap<>();
            for (String columnName : allColumns) {
                sqlStrings.put(columnName, toSqlString(relation.get(columnName)));
            }

            // Key predicate shared by both statements. The parts must be combined with AND: a comma-separated
            // list is not a valid boolean expression as soon as there is more than one key column.
            String keyPredicate = Joiner.on(" AND ")
                    .join(keyColumns.stream().map(i -> String.format("%s=%s", i, sqlStrings.get(i))).iterator());

            // Builds an update string in the form
            // UPDATE table
            // SET key1=value1, key2=value2
            // WHERE id_key = id_value;
            // Left null when there are no non-key columns, since "UPDATE t SET  WHERE ..." is a syntax error.
            String updateSql = null;
            if (!nonKeyColumns.isEmpty()) {
                StringBuilder updateStringBuilder = new StringBuilder();
                updateStringBuilder.append("UPDATE ");
                updateStringBuilder.append(tableName);
                updateStringBuilder.append(" SET ");
                updateStringBuilder.append(Joiner.on(", ").join(
                        nonKeyColumns.stream().map(i -> String.format("%s=%s", i, sqlStrings.get(i))).iterator()));
                updateStringBuilder.append(" WHERE ");
                updateStringBuilder.append(keyPredicate);
                updateStringBuilder.append(";");
                updateSql = updateStringBuilder.toString();
            }

            // Builds an insert string in the form
            // INSERT INTO table (key1, key2, id_key)
            // SELECT value1, value2, id_value
            // WHERE NOT EXISTS (SELECT 1 FROM table WHERE id_key = id_value);
            StringBuilder insertStringBuilder = new StringBuilder();
            insertStringBuilder.append("INSERT INTO ");
            insertStringBuilder.append(tableName);
            insertStringBuilder.append(" (");
            insertStringBuilder.append(Joiner.on(", ").join(allColumns));
            insertStringBuilder.append(") SELECT ");
            insertStringBuilder.append(Joiner.on(", ").join(allColumns.stream().map(sqlStrings::get).iterator()));
            insertStringBuilder.append(" WHERE NOT EXISTS (SELECT 1 FROM ");
            insertStringBuilder.append(tableName);
            insertStringBuilder.append(" WHERE ");
            insertStringBuilder.append(keyPredicate);
            insertStringBuilder.append(");");
            String insertSql = insertStringBuilder.toString();

            if (updateSql != null) {
                System.out.println(truncateForLog(updateSql));
            }
            System.out.println(truncateForLog(insertSql));

            // Run the upsert statements.
            if (updateSql != null) {
                sqlStatement.execute(updateSql);
            }
            sqlStatement.execute(insertSql);
        }
    }

    /**
     * Renders a Java value as SQL literal text suitable for concatenation into a statement.
     *
     * <ul>
     * <li>{@code null} and empty {@link Optional}s become {@code null}.</li>
     * <li>Present {@link Optional}s are unwrapped and rendered recursively.</li>
     * <li>{@link Version} objects are converted with {@code toArray()} and rendered as an array.</li>
     * <li>Values for which {@code Reflection.isPrimitive} is true are rendered with {@code toString()}.</li>
     * <li>Arrays become a quoted <code>'{a, b, c}'</code> literal, each element rendered recursively.</li>
     * <li>{@link UUID}s become a quoted string with a {@code ::uuid} cast.</li>
     * <li>Everything else becomes a single-quoted string of its {@code toString()}.</li>
     * </ul>
     *
     * Single quotes embedded in quoted values are doubled so that the value cannot terminate the literal early.
     *
     * @param object
     *            The value to render; may be {@code null}.
     * @return The SQL text for the value; never {@code null}.
     */
    static String toSqlString(Object object) {
        // Quick exit for nulls
        if (object == null) {
            return "null";
        }

        // Quick recur for Optionals
        if (object.getClass() == Optional.class) {
            return toSqlString(((Optional<?>) object).orElse(null));
        }

        // Type conversion if required
        if (object.getClass() == Version.class) {
            // DB stores Version objects as arrays, so convert this to an array for easy encoding.
            object = ((Version) object).toArray();
        }

        // Generate the string
        if (Reflection.isPrimitive(object.getClass())) {
            return object.toString();
        } else if (object.getClass().isArray()) {
            // NOTE(review): elements are rendered with toSqlString, so elements that are not rendered bare
            // (e.g. Strings) carry their own single quotes inside this quoted literal. That looks valid only
            // for arrays of primitive / numeric elements - confirm before passing String arrays.
            List<String> elements = new LinkedList<>();
            int length = Array.getLength(object);
            for (int i = 0; i < length; i++) {
                elements.add(toSqlString(Array.get(object, i)));
            }
            return "'{" + Joiner.on(", ").join(elements) + "}'";
        } else if (object.getClass() == UUID.class) {
            return quote(object.toString()) + "::uuid";
        } else {
            return quote(object.toString());
        }
    }

    /**
     * Upserts every relation in the list, one at a time and in list order, using
     * {@link #upsert(Connection, String, Map, List)}.
     *
     * @param connection
     *            The active connection to the database to update.
     * @param tableName
     *            The table to write to.
     * @param relations
     *            The rows to upsert.
     * @param keyColumns
     *            The columns that identify a row.
     * @throws SQLException
     *             If any statement fails; relations after the failing one are not attempted.
     */
    public static void upsert(Connection connection, String tableName, List<Map<String, Object>> relations,
            List<String> keyColumns) throws SQLException {
        for (Map<String, Object> relation : relations) {
            upsert(connection, tableName, relation, keyColumns);
        }
    }

    /**
     * Inserts and deletes entries from a junction table until a select against the provided keyColumn returns
     * exactly the same entries as the provided relations.
     *
     * <p>
     * A DELETE is always executed; an INSERT is executed only when {@code relations} is non-empty. Each statement
     * is echoed (truncated) to standard output before it runs.
     *
     * @param connection
     *            The active connection to the database to update
     * @param tableName
     *            The name of the table. The named table should probably be a junction table.
     * @param relations
     *            A list of relations that should be in the junction table after completion. All relations are
     *            assumed to define the same columns as the first one.
     * @param keyColumn
     *            The column name to select against.
     * @param keyValue
     *            The value of {@code keyColumn} that identifies the rows to synchronize.
     * @throws SQLException
     *             If either statement fails.
     */
    public static void updateJunctionTable(Connection connection, String tableName,
            List<Map<String, Object>> relations, String keyColumn, Object keyValue) throws SQLException {
        boolean deleteOnly = relations.isEmpty();

        // Build lists for all columns and non-key columns for streaming over while we build the SQL statements.
        List<String> allColumns = new LinkedList<>();
        List<String> nonKeyColumns = new LinkedList<>();
        if (!deleteOnly) {
            // Assumes that all relations define the same keys.
            Map<String, Object> sampleRelation = relations.get(0);
            allColumns.addAll(sampleRelation.keySet());
            for (String columnName : sampleRelation.keySet()) {
                if (!keyColumn.equals(columnName)) {
                    nonKeyColumns.add(columnName);
                }
            }
        }

        try (Statement sqlStatement = connection.createStatement()) {
            // Builds a delete string in the form
            // DELETE FROM tableName
            // WHERE keyColumn = keyValue
            // AND column1 NOT IN(relations.get(0).get(column1), relations.get(1).get(column1));
            StringBuilder deleteStringBuilder = new StringBuilder();
            deleteStringBuilder.append("DELETE FROM ");
            deleteStringBuilder.append(tableName);
            deleteStringBuilder.append(" WHERE ");
            deleteStringBuilder.append(keyColumn);
            deleteStringBuilder.append('=');
            deleteStringBuilder.append(toSqlString(keyValue));
            for (String nonKeyColumn : nonKeyColumns) {
                deleteStringBuilder.append(" AND ");
                deleteStringBuilder.append(nonKeyColumn);
                deleteStringBuilder.append(" NOT IN(");
                deleteStringBuilder.append(Joiner.on(", ")
                        .join(relations.stream().map(i -> toSqlString(i.get(nonKeyColumn))).iterator()));
                deleteStringBuilder.append(")");
            }
            String deleteSql = deleteStringBuilder.toString();

            // Run the delete & insert statements.
            System.out.println(truncateForLog(deleteSql));
            sqlStatement.execute(deleteSql);

            if (!deleteOnly) {
                // Builds an insert string in the form
                // INSERT INTO tableName (column1, column2)
                // SELECT v1, v2 WHERE NOT EXISTS (SELECT 1 FROM tableName WHERE column1=v1 AND column2=v2)
                // UNION SELECT ... ;
                StringBuilder insertStringBuilder = new StringBuilder();
                insertStringBuilder.append("INSERT INTO ");
                insertStringBuilder.append(tableName);
                insertStringBuilder.append(" (");
                insertStringBuilder.append(Joiner.on(", ").join(allColumns));
                insertStringBuilder.append(") ");
                boolean firstSelect = true;
                for (Map<String, Object> relation : relations) {
                    if (firstSelect) {
                        firstSelect = false;
                    } else {
                        insertStringBuilder.append(" UNION ");
                    }
                    insertStringBuilder.append("SELECT ");
                    insertStringBuilder.append(Joiner.on(", ")
                            .join(allColumns.stream().map(i -> toSqlString(relation.get(i))).iterator()));
                    insertStringBuilder.append(" WHERE NOT EXISTS (SELECT 1 FROM ");
                    insertStringBuilder.append(tableName);
                    insertStringBuilder.append(" WHERE ");
                    insertStringBuilder.append(Joiner.on(" AND ").join(allColumns.stream()
                            .map(i -> String.format("%s=%s", i, toSqlString(relation.get(i)))).iterator()));
                    insertStringBuilder.append(')');
                }
                insertStringBuilder.append(';');
                String insertSql = insertStringBuilder.toString();

                System.out.println(truncateForLog(insertSql));
                sqlStatement.execute(insertSql);
            }
        }
    }

    /**
     * Converts a SQL array of boxed numbers into a Java primitive array.
     *
     * @param array
     *            The SQL array; its {@code getArray()} result is cast to {@code Long[]} or {@code Integer[]}.
     * @param clazz
     *            The primitive element type wanted: {@code long.class} or {@code int.class}.
     * @return A {@code long[]} or {@code int[]} holding the array's elements.
     * @throws SQLException
     *             If the array contents cannot be read.
     * @throws IllegalArgumentException
     *             If {@code clazz} is neither {@code long.class} nor {@code int.class}.
     */
    public static Object toPrimitiveArray(java.sql.Array array, Class<?> clazz) throws SQLException {
        if (clazz == long.class) {
            Long[] objectArray = (Long[]) array.getArray();
            return Longs.toArray(Arrays.asList(objectArray));
        } else if (clazz == int.class) {
            Integer[] objectArray = (Integer[]) array.getArray();
            return Ints.toArray(Arrays.asList(objectArray));
        } else {
            throw new IllegalArgumentException("Unable to convert to " + clazz + " array");
        }
    }

    /** Wraps text in single quotes, doubling any embedded single quote so it stays inside the literal. */
    private static String quote(String text) {
        return '\'' + text.replace("'", "''") + '\'';
    }

    /** Returns at most the first {@link #MAX_LOGGED_SQL_LENGTH} characters of a statement for console output. */
    private static String truncateForLog(String sql) {
        return sql.substring(0, Math.min(sql.length(), MAX_LOGGED_SQL_LENGTH));
    }
}