/** * Licensed to JumpMind Inc under one or more contributor * license agreements. See the NOTICE file distributed * with this work for additional information regarding * copyright ownership. JumpMind Inc licenses this file * to you under the GNU General Public License, version 3.0 (GPLv3) * (the "License"); you may not use this file except in compliance * with the License. * * You should have received a copy of the GNU General Public License, * version 3.0 (GPLv3) along with this library; if not, see * <http://www.gnu.org/licenses/>. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.jumpmind.symmetric.io; import java.sql.Connection; import java.sql.SQLException; import org.apache.commons.codec.binary.Base64; import org.apache.commons.codec.binary.Hex; import org.apache.commons.lang.StringUtils; import org.jumpmind.db.model.Column; import org.jumpmind.db.model.Table; import org.jumpmind.db.platform.DatabaseInfo; import org.jumpmind.db.platform.IDatabasePlatform; import org.jumpmind.db.sql.JdbcSqlTransaction; import org.jumpmind.db.util.BinaryEncoding; import org.jumpmind.symmetric.csv.CsvWriter; import org.jumpmind.symmetric.io.data.Batch; import org.jumpmind.symmetric.io.data.CsvData; import org.jumpmind.symmetric.io.data.CsvUtils; import org.jumpmind.symmetric.io.data.DataContext; import org.jumpmind.symmetric.io.data.DataEventType; import org.jumpmind.symmetric.io.data.writer.DataWriterStatisticConstants; import org.jumpmind.symmetric.io.data.writer.DefaultDatabaseWriter; import org.postgresql.copy.CopyIn; import org.postgresql.copy.CopyManager; import org.postgresql.core.BaseConnection; import org.springframework.jdbc.support.nativejdbc.NativeJdbcExtractor; public class PostgresBulkDatabaseWriter extends DefaultDatabaseWriter { protected NativeJdbcExtractor jdbcExtractor; protected int maxRowsBeforeFlush; protected CopyManager copyManager; protected CopyIn copyIn; protected int loadedRows = 0; protected boolean needsBinaryConversion; public PostgresBulkDatabaseWriter(IDatabasePlatform platform, NativeJdbcExtractor jdbcExtractor, int maxRowsBeforeFlush) { super(platform); this.jdbcExtractor = jdbcExtractor; this.maxRowsBeforeFlush = maxRowsBeforeFlush; } public void write(CsvData data) { DataEventType dataEventType = data.getDataEventType(); switch (dataEventType) { case INSERT: startCopy(); statistics.get(batch).increment(DataWriterStatisticConstants.STATEMENTCOUNT); statistics.get(batch).increment(DataWriterStatisticConstants.LINENUMBER); statistics.get(batch).startTimer(DataWriterStatisticConstants.DATABASEMILLIS); try { String[] parsedData = data.getParsedData(CsvData.ROW_DATA); if (needsBinaryConversion) { Column[] columns = targetTable.getColumns(); for (int i = 0; i < columns.length; i++) { if (columns[i].isOfBinaryType() && parsedData[i] != null) { if (batch.getBinaryEncoding().equals(BinaryEncoding.HEX)) { parsedData[i] = encode(Hex.decodeHex(parsedData[i].toCharArray())); } else if (batch.getBinaryEncoding().equals(BinaryEncoding.BASE64)) { parsedData[i] = encode(Base64.decodeBase64(parsedData[i].getBytes())); } } } } String formattedData = CsvUtils.escapeCsvData(parsedData, '\n', '\'', CsvWriter.ESCAPE_MODE_DOUBLED); byte[] dataToLoad = formattedData.getBytes(); copyIn.writeToCopy(dataToLoad, 0, dataToLoad.length); loadedRows++; } catch (Exception ex) { throw getPlatform().getSqlTemplate().translate(ex); } finally { statistics.get(batch).stopTimer(DataWriterStatisticConstants.DATABASEMILLIS); } break; case UPDATE: case DELETE: default: endCopy(); super.write(data); break; } if (loadedRows >= maxRowsBeforeFlush) { flush(); loadedRows = 0; } } protected void flush() { if (copyIn != null) { statistics.get(batch).startTimer(DataWriterStatisticConstants.DATABASEMILLIS); try { if (copyIn.isActive()) { copyIn.flushCopy(); } } catch (SQLException ex) { throw getPlatform().getSqlTemplate().translate(ex); } finally { statistics.get(batch).stopTimer(DataWriterStatisticConstants.DATABASEMILLIS); } } } @Override public void open(DataContext context) { super.open(context); try { JdbcSqlTransaction jdbcTransaction = (JdbcSqlTransaction) transaction; Connection conn = jdbcExtractor.getNativeConnection(jdbcTransaction.getConnection()); copyManager = new CopyManager((BaseConnection) conn); } catch (Exception ex) { throw getPlatform().getSqlTemplate().translate(ex); } } protected void startCopy() { if (copyIn == null && targetTable != null) { try { String sql = createCopyMgrSql(); if (log.isDebugEnabled()) { log.debug("starting bulk copy using: {}", sql); } copyIn = copyManager.copyIn(sql); } catch (Exception ex) { throw getPlatform().getSqlTemplate().translate(ex); } } } protected void endCopy() { if (copyIn != null) { try { flush(); } finally { try { if (copyIn.isActive()) { copyIn.endCopy(); } } catch (Exception ex) { throw getPlatform().getSqlTemplate().translate(ex); } finally { copyIn = null; } } } } @Override public boolean start(Table table) { if (super.start(table)) { /* If target table cannot be found the write method will decide * whether to ignore the write request or to log an error. No * need to report the error right now. */ if (targetTable != null) { needsBinaryConversion = false; if (!batch.getBinaryEncoding().equals(BinaryEncoding.NONE)) { for (Column column : targetTable.getColumns()) { if (column.isOfBinaryType()) { needsBinaryConversion = true; break; } } } } return true; } else { return false; } } @Override public void end(Table table) { try { endCopy(); } finally { super.end(table); } } @Override public void end(Batch batch, boolean inError) { if (inError && copyIn != null) { try { copyIn.cancelCopy(); } catch (SQLException e) { } finally { copyIn = null; } } super.end(batch, inError); } private String createCopyMgrSql() { StringBuilder sql = new StringBuilder("COPY "); DatabaseInfo dbInfo = platform.getDatabaseInfo(); String quote = dbInfo.getDelimiterToken(); String catalogSeparator = dbInfo.getCatalogSeparator(); String schemaSeparator = dbInfo.getSchemaSeparator(); sql.append(targetTable.getQualifiedTableName(quote, catalogSeparator, schemaSeparator)); sql.append("("); Column[] columns = targetTable.getColumns(); for (Column column : columns) { String columnName = column.getName(); if (StringUtils.isNotBlank(columnName)) { sql.append(quote); sql.append(columnName); sql.append(quote); sql.append(","); } } sql.replace(sql.length() - 1, sql.length(), ")"); sql.append("FROM STDIN with delimiter ',' csv quote ''''"); return sql.toString(); } protected String encode(byte[] byteData) { StringBuilder sb = new StringBuilder(); for (byte b : byteData) { int i = b & 0xff; if (i >= 0 && i <= 7) { sb.append("\\00").append(Integer.toString(i, 8)); } else if (i >= 8 && i <= 31) { sb.append("\\0").append(Integer.toString(i, 8)); } else if (i == 92 || i >= 127) { sb.append("\\").append(Integer.toString(i, 8)); } else { sb.append(Character.toChars(i)); } } return sb.toString(); } }