/** * Licensed to JumpMind Inc under one or more contributor * license agreements. See the NOTICE file distributed * with this work for additional information regarding * copyright ownership. JumpMind Inc licenses this file * to you under the GNU General Public License, version 3.0 (GPLv3) * (the "License"); you may not use this file except in compliance * with the License. * * You should have received a copy of the GNU General Public License, * version 3.0 (GPLv3) along with this library; if not, see * <http://www.gnu.org/licenses/>. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.jumpmind.symmetric.io; import java.math.BigDecimal; import java.text.DecimalFormat; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Date; import java.util.List; import java.util.Map; import org.apache.commons.lang.StringUtils; import org.jumpmind.db.model.Table; import org.jumpmind.db.platform.AbstractDatabasePlatform; import org.jumpmind.db.platform.DatabaseNamesConstants; import org.jumpmind.db.platform.IDatabasePlatform; import org.jumpmind.db.platform.sqlite.SqliteDatabasePlatform; import org.jumpmind.db.util.BinaryEncoding; import org.jumpmind.symmetric.io.data.Batch; import org.jumpmind.symmetric.io.data.Batch.BatchType; import org.jumpmind.symmetric.io.data.writer.DataWriterStatisticConstants; import org.jumpmind.symmetric.io.data.writer.DatabaseWriterSettings; import org.jumpmind.symmetric.io.data.writer.DefaultDatabaseWriter; import org.jumpmind.symmetric.io.data.writer.IgnoreBatchException; import org.jumpmind.symmetric.io.data.CsvData; import org.jumpmind.symmetric.io.data.DataContext; import org.jumpmind.symmetric.io.data.IDataWriter; import org.jumpmind.util.Statistics; import org.junit.Assert; abstract public class AbstractWriterTest { protected static IDatabasePlatform platform; protected boolean errorExpected = true; protected final static String TEST_TABLE = "test_dataloader_table"; protected final static String[] TEST_KEYS = { "id" }; protected final static String[] TEST_COLUMNS = { "id", "string_value", "string_required_value", "char_value", "char_required_value", "date_value", "time_value", "boolean_value", "integer_value", "decimal_value", "double_value" }; protected static long batchId = 10000; protected static long sequenceId = 10000; protected DatabaseWriterSettings writerSettings = new DatabaseWriterSettings(); protected IDataWriter lastDataWriterUsed; protected synchronized long getNextBatchId() { return ++batchId; } protected synchronized long getBatchId() { return batchId; } protected synchronized String getNextId() { return String.valueOf(++sequenceId); } protected synchronized String getId() { return String.valueOf(sequenceId); } protected Table buildSourceTable(String tableName, String[] keyNames, String[] columnNames) { return Table.buildTable(tableName, keyNames, columnNames); } protected void writeData(CsvData data, String[] expectedValues) { writeData(data, expectedValues, TEST_COLUMNS); } protected String getTestTable() { return TEST_TABLE; } protected void writeData(CsvData... data) { Table table = buildSourceTable(TEST_TABLE, TEST_KEYS, TEST_COLUMNS); writeData(new TableCsvData(table, data)); } protected void writeData(CsvData data, String[] expectedValues, String[] columnNames) { writeData(data, expectedValues, getTestTable(), TEST_KEYS, columnNames); } protected void writeData(CsvData data, String[] expectedValues, String tableName, String[] keyNames, String[] columnNames) { Table table = buildSourceTable(tableName, keyNames, columnNames); writeData(new TableCsvData(table, data)); String[] pkData = data.getParsedData(CsvData.ROW_DATA); if (pkData == null) { pkData = data.getParsedData(CsvData.PK_DATA); } assertTestTableEquals(pkData[0], expectedValues); } protected long writeData(TableCsvData... datas) { return writeData(new DefaultDatabaseWriter(platform, writerSettings), datas); } protected long writeData(IDataWriter writer, TableCsvData... datas) { this.lastDataWriterUsed = writer; DataContext context = new DataContext(); writer.open(context); try { for (TableCsvData tableCsvData : datas) { Batch batch = new Batch(BatchType.LOAD, getNextBatchId(), "default", BinaryEncoding.BASE64, "00000", "00001", false); try { writer.start(batch); if (writer.start(tableCsvData.table)) { for (CsvData d : tableCsvData.data) { writer.write(d); } writer.end(tableCsvData.table); } writer.end(batch, false); } catch (IgnoreBatchException ex) { writer.end(batch, false); } catch (Exception ex) { writer.end(batch, true); if (!isErrorExpected()) { if (ex instanceof RuntimeException) { throw (RuntimeException) ex; } else { throw new RuntimeException(ex); } } } } } finally { writer.close(); } long statementCount = 0; Collection<Statistics> stats = writer.getStatistics().values(); for (Statistics statistics : stats) { statementCount += statistics.get(DataWriterStatisticConstants.STATEMENTCOUNT); } return statementCount; } protected void assertTestTableEquals(String testTableId, String[] expectedValues) { String sql = "select " + getSelect(TEST_COLUMNS) + " from " + getTestTable() + " where " + getWhere(TEST_KEYS); Map<String, Object> results = platform.getSqlTemplate().queryForMap(sql, new Long(testTableId)); if (expectedValues != null) { expectedValues[1] = translateExpectedString(expectedValues[1], false); expectedValues[2] = translateExpectedString(expectedValues[2], true); expectedValues[3] = translateExpectedCharString(expectedValues[3], 50, false); expectedValues[4] = translateExpectedCharString(expectedValues[4], 50, true); } assertEquals(TEST_COLUMNS, expectedValues, results); } protected String getSelect(String[] columns) { StringBuilder str = new StringBuilder(); for (int i = 0; i < columns.length; i++) { str.append(columns[i]).append(i + 1 < columns.length ? ", " : ""); } return str.toString(); } protected String getWhere(String[] columns) { StringBuilder str = new StringBuilder(); for (int i = 0; i < columns.length; i++) { str.append(columns[i]).append(" = ?").append(i + 1 < columns.length ? "," : ""); } return str.toString(); } protected String translateExpectedString(String value, boolean isRequired) { if (isRequired && (value == null || (value.equals("") && platform.getDatabaseInfo() .isEmptyStringNulled()))) { return AbstractDatabasePlatform.REQUIRED_FIELD_NULL_SUBSTITUTE; } else if (value != null && value.equals("") && platform.getDatabaseInfo().isEmptyStringNulled()) { return null; } return value; } protected String translateExpectedCharString(String value, int size, boolean isRequired) { if (isRequired && value == null) { if (!platform.getDatabaseInfo().isRequiredCharColumnEmptyStringSameAsNull() || platform.getDatabaseInfo().isEmptyStringNulled()) { value = AbstractDatabasePlatform.REQUIRED_FIELD_NULL_SUBSTITUTE; } } if (value != null && ((StringUtils.isBlank(value) && platform.getDatabaseInfo() .isBlankCharColumnSpacePadded()) || (StringUtils.isNotBlank(value) && platform .getDatabaseInfo().isNonBlankCharColumnSpacePadded()))) { return StringUtils.rightPad(value, size); } else if (value != null && platform.getDatabaseInfo().isCharColumnSpaceTrimmed()) { return value.replaceFirst(" *$", ""); } return value; } protected void assertEquals(String[] name, String[] expected, Map<String, Object> results) { if (expected == null) { Assert.assertNull("Expected empty results. " + printDatabase(), results); } else { Assert.assertNotNull(String.format("Did not find the expected row: %s.", Arrays.toString(expected)), results); for (int i = 0; i < expected.length; i++) { Object resultObj = results.get(name[i]); String resultValue = null; char decimal = ((DecimalFormat) DecimalFormat.getInstance()) .getDecimalFormatSymbols().getDecimalSeparator(); if ((resultObj instanceof Double || resultObj instanceof BigDecimal) && expected[i].indexOf(decimal) != -1) { DecimalFormat df = new DecimalFormat("0.00####################################"); resultValue = df.format(resultObj); } else if (resultObj instanceof Date) { SimpleDateFormat df = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.000"); resultValue = df.format(resultObj); } else if (resultObj instanceof Boolean) { resultValue = ((Boolean) resultObj) ? "1" : "0"; } else if (resultObj instanceof Double) { resultValue = resultObj.toString(); if (platform instanceof SqliteDatabasePlatform) { expected[i] = new Double(expected[i]).toString(); } } else if (resultObj != null) { resultValue = resultObj.toString(); } Assert.assertEquals(name[i] + ". " + printDatabase(), expected[i], resultValue); } } } protected String printDatabase() { return " The database we are testing against is " + platform.getName() + "."; } protected boolean isOracle() { return DatabaseNamesConstants.ORACLE.equals(platform.getName()); } public void setErrorExpected(boolean errorExpected) { this.errorExpected = errorExpected; } public boolean isErrorExpected() { return errorExpected; } public Map<String,Object> queryForRow(String id) { return platform.getSqlTemplate().queryForMap("select * from " + TEST_TABLE + " where id=?", new Integer(id)); } protected class TableCsvData { Table table; List<CsvData> data; public TableCsvData(Table table, CsvData... csvDatas) { this.table = table; this.data = new ArrayList<CsvData>(); for (CsvData csvData : csvDatas) { this.data.add(csvData); } } public TableCsvData(Table table, List<CsvData> data) { this.table = table; this.data = data; } } protected long countRows(String tableName) { return platform.getSqlTemplate().queryForInt(String.format("select count(*) from %s", tableName)); } }