package org.jentrata.ebms.messaging.internal.sql;
import org.apache.commons.io.IOUtils;
import org.jentrata.ebms.MessageStatusType;
import org.jentrata.ebms.MessageType;
import org.jentrata.ebms.messaging.Message;
import org.jentrata.ebms.messaging.MessageStoreException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.sql.DataSource;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Timestamp;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* TODO: add description
*
* @author aaronwalker
*/
public abstract class AbstractRepositoryManager implements RepositoryManager {
protected final Logger LOG = LoggerFactory.getLogger(getClass());
protected DataSource dataSource;
public AbstractRepositoryManager(DataSource dataSource) {
this.dataSource = dataSource;
}
@Override
public void createTablesIfNotExists() {
try(Connection connection = dataSource.getConnection()) {
try(Statement stmt = connection.createStatement()) {
for(String sql : getCreateSQL()) {
int result = stmt.executeUpdate(sql);
if(result > 0) {
LOG.info("Message Store tables successfully created");
}
}
}
} catch (Exception e) {
throw new MessageStoreException("unable to create/check database tables:" + e,e);
}
}
@Override
public boolean isDuplicate(String messageId, String messageDirection) {
Map<String,Object> fields = new HashMap<>();
fields.put("message_id",messageId);
fields.put("message_box",messageDirection);
List<Message> messages = selectMessageBy(fields);
return !messages.isEmpty();
}
@Override
public void insertIntoRepository(String messageId, String contentType, String messageDirection, long contentLength, InputStream content, String duplicateMessageId) {
try(Connection connection = dataSource.getConnection()) {
String sql = getInsertSQL();
try(PreparedStatement stmt = connection.prepareStatement(sql)) {
stmt.setString(1,messageId);
stmt.setString(2,contentType);
stmt.setTimestamp(3, new Timestamp(new Date().getTime()));
stmt.setString(4,messageDirection);
ByteArrayOutputStream bos = new ByteArrayOutputStream();
IOUtils.copy(content,bos);
stmt.setBytes(5, bos.toByteArray());
int result = stmt.executeUpdate();
if(result != 1) {
throw new MessageStoreException("failed to write message to store");
}
}
} catch (SQLException|IOException e) {
throw new MessageStoreException("failed to write message to store:" + e,e);
}
}
@Override
public void insertMessage(String messageId, String messageDirection, MessageType messageType, String cpaId, String conversationId, String refMessageID) {
try(Connection connection = dataSource.getConnection()) {
String sql = getMessageInsertSQL();
try(PreparedStatement stmt = connection.prepareStatement(sql)) {
stmt.setString(1,messageId);
stmt.setString(2,messageDirection);
stmt.setString(3,messageType.name());
stmt.setString(4,cpaId);
stmt.setString(5,conversationId);
stmt.setString(6,refMessageID);
stmt.setTimestamp(7, new Timestamp(new Date().getTime()));
int result = stmt.executeUpdate();
if(result != 1) {
throw new MessageStoreException("failed to insert message " + messageId);
}
}
} catch (SQLException ex) {
throw new MessageStoreException("failed to insert message " + messageId);
}
}
@Override
public void updateMessage(String messageId, String messageDirection, MessageStatusType status, String statusDescription) {
try(Connection connection = dataSource.getConnection()) {
String sql = getMessageUpdateSQL();
try(PreparedStatement stmt = connection.prepareStatement(sql)) {
stmt.setString(1,status.name());
stmt.setString(2,statusDescription);
stmt.setString(3,messageId);
stmt.setString(4,messageDirection);
int result = stmt.executeUpdate();
if(result != 1) {
LOG.warn("failed to update message " + messageId + " to status " + status);
}
}
} catch (SQLException ex) {
LOG.warn("failed to update message " + messageId + " to status " + status);
LOG.debug("",ex);
}
}
@Override
public List<Message> selectMessageBy(String columnName, String value) {
try(Connection connection = dataSource.getConnection()) {
Map<String,Object> fields = new HashMap<>();
fields.put(columnName,value);
String sql = getMessageSelectSQL(fields);
try(PreparedStatement stmt = connection.prepareStatement(sql)) {
stmt.setString(1,value);
ResultSet result = stmt.executeQuery();
return JDBCMessageMapper.getMessage(result);
}
} catch (SQLException ex) {
LOG.warn("failed to get message from repository:" + ex);
LOG.debug("",ex);
}
return Collections.emptyList();
}
@Override
public List<Message> selectMessageBy(Map<String, Object> fields) {
try(Connection connection = dataSource.getConnection()) {
String sql = getMessageSelectSQL(fields);
try(PreparedStatement stmt = connection.prepareStatement(sql)) {
int i = 1;
for(Map.Entry<String,Object> entry : fields.entrySet()) {
if(!entry.getKey().startsWith("orderBy") && !entry.getKey().equals("maxResults")) {
stmt.setObject(i++,entry.getValue());
} else if(entry.getKey().equals("maxResults")) {
int max = (int) entry.getValue();
stmt.setMaxRows(max);
}
}
ResultSet result = stmt.executeQuery();
return JDBCMessageMapper.getMessage(result);
}
} catch (SQLException ex) {
LOG.warn("failed to get message from repository:" + ex);
LOG.debug("",ex);
}
return Collections.emptyList();
}
@Override
public InputStream selectRepositoryBy(String columnName, String value) {
try(Connection connection = dataSource.getConnection()) {
String sql = getRepositorySelectSQL(columnName);
try(PreparedStatement stmt = connection.prepareStatement(sql)) {
stmt.setString(1,value);
ResultSet result = stmt.executeQuery();
if(result.next()) {
return result.getBinaryStream("content");
}
}
} catch (SQLException ex) {
LOG.warn("failed to get payload from repository:" + ex);
LOG.debug("",ex);
}
return null;
}
protected String getRepositorySelectSQL(String columnName) {
return "SELECT * from REPOSITORY WHERE " + columnName + "=?";
}
protected String getMessageSelectSQL(Map<String,Object> fields) {
StringBuilder sql = new StringBuilder("SELECT * FROM MESSAGE WHERE ");
String orderBy = "";
boolean first = true;
for(String column : fields.keySet()) {
switch (column) {
case "orderByAsc":
orderBy = " ORDER BY " + fields.get(column);
break;
case "orderByDesc":
orderBy = " ORDER BY " + fields.get(column) + " DESC";
break;
case "maxResults":
break;
default:
if(!first) {
sql.append(" AND ");
}
sql.append(column);
sql.append("=?");
first = false;
}
}
sql.append(orderBy);
return sql.toString();
}
protected String getMessageInsertSQL() {
return "INSERT INTO MESSAGE (message_id, message_box, message_type, cpa_id, conv_id, ref_to_message_id, time_stamp) VALUES (?,?,?,?,?,?,?)";
}
protected String getMessageUpdateSQL() {
return "UPDATE MESSAGE SET status=?, status_description=? where message_id=? and message_box=?";
}
protected String getInsertSQL() {
return "INSERT INTO repository (message_id,content_type,time_stamp,message_box,content) VALUES(?,?,?,?,?)";
}
protected abstract String [] getCreateSQL();
}