/*
* ModeShape (http://www.modeshape.org)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.modeshape.persistence.relational;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import org.modeshape.common.annotation.NotThreadSafe;
import org.modeshape.common.logging.Logger;
import org.modeshape.schematic.document.Bson;
import org.modeshape.schematic.document.Document;
/**
* Default implementation for the {@link Statements} interface which applies to all databases.
*
* @author Horia Chiorean (hchiorea@redhat.com)
* @since 5.0
*/
public class DefaultStatements implements Statements {

    /** Default maximum number of bind parameters placed in one generated multiple-selection segment. */
    protected static final int DEFAULT_MAX_STATEMENT_PARAM_COUNT = 1000;

    /** Marker in configured statements that is substituted with generated parameter lists / selection clauses. */
    private static final String PLACEHOLDER_STRING = "#";

    protected final Logger logger = Logger.getLogger(getClass());

    private final Map<String, String> statements;
    private final RelationalDbConfig config;

    /**
     * Creates a new instance.
     *
     * @param config the relational DB configuration (table name, fetch size, compression flag)
     * @param statements the SQL statements keyed by the statement ids declared in {@link Statements}
     */
    protected DefaultStatements( RelationalDbConfig config, Map<String, String> statements ) {
        this.statements = statements;
        this.config = config;
    }

    /**
     * Executes the configured {@code CREATE_TABLE} statement and logs whether the driver reported affected rows.
     * Any {@link SQLException} is delegated to {@link #processSQLException(String, SQLException)}.
     *
     * @return always {@code null}
     */
    @Override
    public Void createTable( Connection connection ) throws SQLException {
        logTableInfo("Creating table {0}...");
        try (PreparedStatement createStmt = connection.prepareStatement(statements.get(CREATE_TABLE))) {
            if (createStmt.executeUpdate() > 0) {
                logTableInfo("Table {0} created");
            } else {
                logTableInfo("Table {0} already exists");
            }
        } catch (SQLException e) {
            processSQLException(CREATE_TABLE, e);
        }
        return null;
    }

    /**
     * Executes the configured {@code DELETE_TABLE} statement and logs whether the driver reported affected rows.
     * Any {@link SQLException} is delegated to {@link #processSQLException(String, SQLException)}.
     *
     * @return always {@code null}
     */
    @Override
    public Void dropTable( Connection connection ) throws SQLException {
        logTableInfo("Dropping table {0}...");
        try (PreparedStatement dropStmt = connection.prepareStatement(statements.get(DELETE_TABLE))) {
            if (dropStmt.executeUpdate() > 0) {
                logTableInfo("Table {0} dropped");
            } else {
                logTableInfo("Table {0} does not exist");
            }
        } catch (SQLException e) {
            processSQLException(DELETE_TABLE, e);
        }
        return null;
    }

    /**
     * Hook for handling a failure of the statement identified by {@code statementId}. The default implementation
     * rethrows {@code e} unchanged; subclasses may override it to tolerate database-specific errors.
     *
     * @param statementId the id of the statement that failed
     * @param e the failure
     * @throws SQLException by default, always ({@code e} itself)
     */
    protected void processSQLException(String statementId, SQLException e) throws SQLException {
        // by default we just rethrow the exception as-is, but certain subclasses may want different handling
        throw e;
    }

    /**
     * Returns the value of the first column of every row produced by the {@code GET_ALL_IDS} statement, using the
     * configured fetch size.
     */
    @Override
    public List<String> getAllIds(Connection connection) throws SQLException {
        logTableInfo("Returning all ids from {0}");
        try (PreparedStatement ps = connection.prepareStatement(statements.get(GET_ALL_IDS))) {
            List<String> result = new ArrayList<>();
            ps.setFetchSize(config.fetchSize());
            try (ResultSet rs = ps.executeQuery()) {
                while (rs.next()) {
                    result.add(rs.getString(1));
                }
            }
            return result;
        }
    }

    /**
     * Loads a single document by id.
     *
     * @return the document decoded from the first column of the first row, or {@code null} if no row matched
     */
    @Override
    public Document getById( Connection connection, String id ) throws SQLException {
        if (logger.isDebugEnabled()) {
            logger.debug("Searching for entry by id {0} in {1}", id, tableName());
        }
        try (PreparedStatement ps = connection.prepareStatement(statements.get(GET_BY_ID))) {
            ps.setString(1, id);
            try (ResultSet rs = ps.executeQuery()) {
                if (!rs.next()) {
                    return null;
                }
                return readDocument(rs.getBinaryStream(1));
            }
        }
    }

    /**
     * Loads the documents for the given ids in a single query and maps each through {@code parser}.
     * Ids without a matching row contribute nothing; results are in whatever order the database returns them.
     *
     * @return a new mutable list (empty when {@code ids} is empty, without touching the database)
     */
    @Override
    public <R> List<R> load(Connection connection, Collection<String> ids, Function<Document, R> parser) throws SQLException {
        if (logger.isDebugEnabled()) {
            logger.debug("Loading ids {0} from {1}", ids.toString(), tableName());
        }
        if (ids.isEmpty()) {
            return new ArrayList<>();
        }
        String formattedStatement = formatStatementWithMultipleParams(statements.get(GET_MULTIPLE), ids.size());
        try (PreparedStatement ps = connection.prepareStatement(formattedStatement)) {
            int paramIdx = 1;
            for (String id : ids) {
                ps.setString(paramIdx++, id);
            }
            try (ResultSet rs = ps.executeQuery()) {
                List<R> results = new ArrayList<>();
                while (rs.next()) {
                    Document document = readDocument(rs.getBinaryStream(1));
                    results.add(parser.apply(document));
                }
                return results;
            }
        }
    }

    /**
     * Expands every {@value #PLACEHOLDER_STRING} in {@code statement} into {@code paramCount} bind markers.
     * The markers are split into segments of at most {@link #maxStatementParamCount()} each. Every segment is a
     * copy of the {@code MULTIPLE_SELECTION} clause with its own placeholder replaced by {@code ?,?,...}, and the
     * segments are joined with {@code " OR "}.
     * <p>
     * Literal (non-regex) replacement is used throughout so that characters such as {@code $} or {@code \} in the
     * configured clause cannot be misinterpreted as regex replacement syntax.
     */
    private String formatStatementWithMultipleParams(String statement, int paramCount) {
        String multipleSelectionClause = statements.get(MULTIPLE_SELECTION);
        int maxStatementParamCount = maxStatementParamCount();
        int fullSegmentCount = paramCount / maxStatementParamCount;
        int lastSegmentSize = paramCount % maxStatementParamCount;

        List<String> segments = new ArrayList<>(fullSegmentCount + 1);
        if (fullSegmentCount > 0) {
            // every full segment is identical, so build it once and reuse it
            String fullSegment = multipleSelectionClause.replace(PLACEHOLDER_STRING, bindMarkers(maxStatementParamCount));
            for (int i = 0; i < fullSegmentCount; i++) {
                segments.add(fullSegment);
            }
        }
        if (lastSegmentSize > 0) {
            segments.add(multipleSelectionClause.replace(PLACEHOLDER_STRING, bindMarkers(lastSegmentSize)));
        }
        return statement.replace(PLACEHOLDER_STRING, String.join(" OR ", segments));
    }

    /** Returns {@code count} JDBC bind markers separated by commas, e.g. {@code "?,?,?"} for 3. */
    private static String bindMarkers(int count) {
        return IntStream.range(0, count)
                        .mapToObj(nr -> "?")
                        .collect(Collectors.joining(","));
    }

    /**
     * Executes the {@code LOCK_CONTENT} statement for the given ids.
     *
     * @return {@code false} if {@code ids} is empty or if executing the lock query throws an {@link SQLException};
     *         {@code true} once the query has executed. Failures while preparing or binding the statement propagate.
     */
    @Override
    public boolean lockForWriting( Connection connection, List<String> ids ) throws SQLException {
        if (logger.isDebugEnabled()) {
            logger.debug("Attempting to lock ids {0} from {1}", ids.toString(), tableName());
        }
        if (ids.isEmpty()) {
            return false;
        }
        String formattedStatement = formatStatementWithMultipleParams(statements.get(LOCK_CONTENT), ids.size());
        try (PreparedStatement ps = connection.prepareStatement(formattedStatement)) {
            int paramIdx = 1;
            for (String id : ids) {
                ps.setString(paramIdx++, id);
            }
            try (ResultSet ignored = ps.executeQuery()) {
                // any failed lock should result in a timeout being eventually thrown by the DB
                // ModeShape will frequently try to lock new nodes before inserting them, so it's important that this method
                // returns 'true' for those nodes
                logger.debug("successfully locked ids");
                return true;
            } catch (SQLException e) {
                logger.debug(e, " cannot lock ids");
                return false;
            }
        }
    }

    /** Returns a new batch updater bound to {@code connection}. */
    @Override
    public DefaultBatchUpdate batchUpdate( Connection connection ) {
        return new DefaultBatchUpdate(connection);
    }

    /**
     * Checks, via the {@code CONTENT_EXISTS} statement, whether any row exists for {@code id}.
     */
    @Override
    public boolean exists( Connection connection, String id ) throws SQLException {
        if (logger.isDebugEnabled()) {
            logger.debug("Checking if the content with ID {0} exists in {1}", id, tableName());
        }
        try (PreparedStatement ps = connection.prepareStatement(statements.get(CONTENT_EXISTS))) {
            ps.setString(1, id);
            // close the result set explicitly rather than relying on the statement close to do it
            try (ResultSet rs = ps.executeQuery()) {
                return rs.next();
            }
        }
    }

    /**
     * Executes the {@code REMOVE_ALL_CONTENT} statement.
     *
     * @return always {@code null}
     */
    @Override
    public Void removeAll( Connection connection ) throws SQLException {
        logTableInfo("Removing all content from {0}");
        try (PreparedStatement ps = connection.prepareStatement(statements.get(REMOVE_ALL_CONTENT))) {
            ps.executeUpdate();
        }
        return null;
    }

    /**
     * Maximum number of bind parameters per generated multiple-selection segment; subclasses may override for
     * databases with different limits.
     */
    protected int maxStatementParamCount() {
        return DEFAULT_MAX_STATEMENT_PARAM_COUNT;
    }

    /** Logs {@code message} at debug level with the table name as parameter {@code {0}}. */
    protected void logTableInfo( String message ) {
        if (logger.isDebugEnabled()) {
            logger.debug(message, tableName());
        }
    }

    /** Returns the table name from the configuration. */
    protected String tableName() {
        return config.tableName();
    }

    /**
     * Decodes a BSON document from {@code is}, gunzipping first when compression is enabled in the configuration.
     * The stream is closed before returning.
     *
     * @throws RelationalProviderException wrapping any {@link IOException}
     */
    protected Document readDocument(InputStream is) {
        try (InputStream contentStream = config.compress() ? new GZIPInputStream(is) : is) {
            return Bson.read(contentStream);
        } catch (IOException e) {
            throw new RelationalProviderException(e);
        }
    }

    /**
     * Encodes {@code content} as BSON, gzipping it when compression is enabled in the configuration.
     *
     * @throws RelationalProviderException wrapping any {@link IOException}
     */
    protected byte[] writeDocument(Document content) {
        try {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            // the wrapping stream must be closed before reading bos so that the GZIP trailer is flushed
            try (OutputStream out = config.compress() ? new GZIPOutputStream(bos) : bos) {
                Bson.write(content, out);
            }
            return bos.toByteArray();
        } catch (IOException e) {
            throw new RelationalProviderException(e);
        }
    }

    /**
     * {@link BatchUpdate} implementation that issues JDBC batches on a single connection.
     */
    @NotThreadSafe
    protected class DefaultBatchUpdate implements BatchUpdate {

        private final Connection connection;

        protected DefaultBatchUpdate( Connection connection ) {
            this.connection = connection;
        }

        /**
         * Inserts all documents with one JDBC batch of the {@code INSERT_CONTENT} statement. Does nothing for an
         * empty map. The prepared statement is closed after the batch executes.
         */
        @Override
        public void insert( Map<String, Document> documentsById ) throws SQLException {
            if (documentsById.isEmpty()) {
                return;
            }
            String sql = statements.get(INSERT_CONTENT);
            try (PreparedStatement insert = connection.prepareStatement(sql)) {
                documentsById.forEach(( id, document ) -> {
                    if (logger.isDebugEnabled()) {
                        // quote the id: '$' or '\' in it would otherwise be parsed as replacement syntax and throw
                        logger.debug("adding batch statement: {0}", sql.replaceFirst("\\?", Matcher.quoteReplacement(id)));
                    }
                    insertDocument(insert, id, document);
                });
                insert.executeBatch();
            }
        }

        /**
         * Binds {@code id} (param 1) and the encoded document (param 2) and adds them to the batch.
         *
         * @throws RelationalProviderException wrapping any {@link SQLException}
         */
        protected void insertDocument(PreparedStatement statement, String id, Document document) {
            try {
                statement.setString(1, id);
                byte[] content = writeDocument(document);
                statement.setBytes(2, content);
                statement.addBatch();
            } catch (SQLException e) {
                throw new RelationalProviderException(e);
            }
        }

        /**
         * Updates all documents with one JDBC batch of the {@code UPDATE_CONTENT} statement. Does nothing for an
         * empty map. The prepared statement is closed after the batch executes.
         */
        @Override
        public void update( Map<String, Document> documentsById ) throws SQLException {
            if (documentsById.isEmpty()) {
                return;
            }
            String sql = statements.get(UPDATE_CONTENT);
            try (PreparedStatement update = connection.prepareStatement(sql)) {
                documentsById.forEach(( id, document ) -> {
                    if (logger.isDebugEnabled()) {
                        // quote the replacement: '$' or '\' in the id would otherwise break the regex substitution
                        logger.debug("adding batch statement: {0}",
                                     sql.replaceFirst(" ID.*=.*\\?", Matcher.quoteReplacement(" ID = " + id)));
                    }
                    updateDocument(update, id, document);
                });
                update.executeBatch();
            }
        }

        /**
         * Binds the encoded document (param 1) and {@code id} (param 2) and adds them to the batch.
         *
         * @throws RelationalProviderException wrapping any {@link SQLException}
         */
        protected void updateDocument(PreparedStatement statement, String id, Document document) {
            try {
                byte[] content = writeDocument(document);
                statement.setBytes(1, content);
                statement.setString(2, id);
                statement.addBatch();
            } catch (SQLException e) {
                throw new RelationalProviderException(e);
            }
        }

        /**
         * Deletes the rows for the given ids using the {@code REMOVE_CONTENT} statement expanded to
         * {@code ids.size()} parameters. Does nothing for an empty list.
         */
        @Override
        public void remove( List<String> ids ) throws SQLException {
            if (ids.isEmpty()) {
                return;
            }
            String formattedStatement = formatStatementWithMultipleParams(statements.get(REMOVE_CONTENT), ids.size());
            if (logger.isDebugEnabled()) {
                logger.debug("running statement: {0}", formattedStatement);
            }
            try (PreparedStatement remove = connection.prepareStatement(formattedStatement)) {
                int paramIdx = 1;
                for (String id : ids) {
                    remove.setString(paramIdx++, id);
                }
                remove.executeUpdate();
            }
        }
    }
}