package org.mariadb.jdbc; import org.junit.Assume; import org.junit.BeforeClass; import org.junit.Test; import java.sql.*; import static org.junit.Assert.*; public class BasicBatchTest extends BaseTest { /** * Tables initialisation. * * @throws SQLException exception */ @BeforeClass() public static void initClass() throws SQLException { createTable("test_batch", "id int not null primary key auto_increment, test varchar(10)"); createTable("test_batch2", "id int not null primary key auto_increment, test varchar(10)"); createTable("test_batch3", "id int not null primary key auto_increment, test varchar(10)"); createTable("batchUpdateException", "i int,PRIMARY KEY (i)"); createTable("batchPrepareUpdateException", "i int,PRIMARY KEY (i)"); createTable("rewritetest", "id int not null primary key, a varchar(10), b int", "engine=innodb"); createTable("rewritetest2", "id int not null primary key, a varchar(10), b int", "engine=innodb"); createTable("bug501452", "id int not null primary key, value varchar(20)"); } @Test public void batchTest() throws SQLException { Assume.assumeFalse(sharedIsRewrite()); PreparedStatement ps = sharedConnection.prepareStatement("insert into test_batch values (null, ?)", Statement.RETURN_GENERATED_KEYS); ps.setString(1, "aaa"); ps.addBatch(); ps.setString(1, "bbb"); ps.addBatch(); ps.setString(1, "ccc"); ps.addBatch(); int[] batchResult = ps.executeBatch(); ResultSet rs1 = ps.getGeneratedKeys(); for (int count = 1; count <= 3; count++) { assertTrue(rs1.next()); assertTrue(String.valueOf(count).equalsIgnoreCase(rs1.getString(1))); } for (int unitInsertNumber : batchResult) { assertEquals(1, unitInsertNumber); } ps.setString(1, "aaa"); ps.addBatch(); ps.setString(1, "bbb"); ps.addBatch(); ps.setString(1, "ccc"); ps.addBatch(); batchResult = ps.executeBatch(); for (int unitInsertNumber : batchResult) { assertEquals(1, unitInsertNumber); } final ResultSet rs = sharedConnection.createStatement().executeQuery("select * from test_batch"); ps.executeQuery("SELECT 1"); rs1 = ps.getGeneratedKeys(); assertFalse(rs1.next()); assertEquals(true, rs.next()); assertEquals("aaa", rs.getString(2)); assertEquals(true, rs.next()); assertEquals("bbb", rs.getString(2)); assertEquals(true, rs.next()); assertEquals("ccc", rs.getString(2)); } @Test public void batchTestStmtUsingPipeline() throws SQLException { batchTestStmt(sharedConnection); } @Test public void batchTestStmtWithoutPipeline() throws SQLException { try (Connection connection = setConnection("&useBatchMultiSend=false")) { batchTestStmt(connection); } } private void batchTestStmt(Connection connection) throws SQLException { Statement stmt = connection.createStatement(); stmt.execute("truncate test_batch2"); stmt.addBatch("insert into test_batch2 values (null, 'hej1')"); stmt.addBatch("insert into test_batch2 values (null, 'hej2')"); stmt.addBatch("insert into test_batch2 values (null, 'hej3')"); stmt.addBatch("insert into test_batch2 values (null, 'hej4')"); int[] inserts = stmt.executeBatch(); assertEquals(4, inserts.length); assertEquals(1, inserts[0]); assertEquals(1, inserts[1]); assertEquals(1, inserts[2]); assertEquals(1, inserts[3]); ResultSet rs = sharedConnection.createStatement().executeQuery("select * from test_batch2"); for (int i = 1; i <= 4; i++) { assertEquals(true, rs.next()); assertEquals(i, rs.getInt(1)); assertEquals("hej" + i, rs.getString(2)); } assertEquals(false, rs.next()); } @Test public void batchUpdateException() throws Exception { Statement st = sharedConnection.createStatement(); st.addBatch("insert into batchUpdateException values(1)"); st.addBatch("insert into batchUpdateException values(2)"); st.addBatch("insert into batchUpdateException values(1)"); // will fail, duplicate primary key st.addBatch("insert into batchUpdateException values(3)"); try { st.executeBatch(); fail("exception should be throw above"); } catch (BatchUpdateException bue) { int[] updateCounts = bue.getUpdateCounts(); assertEquals(4, updateCounts.length); if (sharedIsRewrite()) { assertEquals(1, updateCounts[0]); assertEquals(1, updateCounts[1]); assertEquals(Statement.EXECUTE_FAILED, updateCounts[2]); assertEquals(Statement.EXECUTE_FAILED, updateCounts[3]); } else { //prepare or allowMultiQueries options assertEquals(1, updateCounts[0]); assertEquals(1, updateCounts[1]); assertEquals(Statement.EXECUTE_FAILED, updateCounts[2]); assertEquals(1, updateCounts[3]); } assertTrue(bue.getCause() instanceof SQLIntegrityConstraintViolationException); } } @Test public void batchPrepareUpdateException() throws Exception { PreparedStatement st = sharedConnection.prepareStatement("insert into batchPrepareUpdateException values(?)"); st.setInt(1, 1); st.addBatch(); st.setInt(1, 2); st.addBatch(); st.setInt(1, 1); // will fail, duplicate primary key st.addBatch(); st.setInt(1, 3); st.addBatch(); try { st.executeBatch(); fail("exception should be throw above"); } catch (BatchUpdateException bue) { int[] updateCounts = bue.getUpdateCounts(); assertEquals(4, updateCounts.length); if (sharedIsRewrite()) { assertEquals(Statement.EXECUTE_FAILED, updateCounts[0]); assertEquals(Statement.EXECUTE_FAILED, updateCounts[1]); assertEquals(Statement.EXECUTE_FAILED, updateCounts[2]); assertEquals(Statement.EXECUTE_FAILED, updateCounts[3]); } else { //prepare or allowMultiQueries options assertEquals(1, updateCounts[0]); assertEquals(1, updateCounts[1]); assertEquals(Statement.EXECUTE_FAILED, updateCounts[2]); assertEquals(1, updateCounts[3]); } assertTrue(bue.getCause() instanceof SQLIntegrityConstraintViolationException); } } @Test public void testBatchLoop() throws SQLException { PreparedStatement ps = sharedConnection.prepareStatement("insert into rewritetest values (?,?,?)"); for (int i = 0; i < 10; i++) { ps.setInt(1, i); ps.setString(2, "bbb" + i); ps.setInt(3, 30 + i); ps.addBatch(); } ps.executeBatch(); ResultSet rs = sharedConnection.createStatement().executeQuery("select * from rewritetest"); int counter = 0; while (rs.next()) { assertEquals(counter++, rs.getInt("id")); } assertEquals(10, counter); } @Test public void testBatchLoopWithDupKey() throws SQLException { PreparedStatement ps = sharedConnection.prepareStatement( "insert into rewritetest2 values (?,?,?) on duplicate key update a=values(a)"); for (int i = 0; i < 2; i++) { ps.setInt(1, 0); ps.setString(2, "bbb" + i); ps.setInt(3, 30 + i); ps.addBatch(); } ps.executeBatch(); ResultSet rs = sharedConnection.createStatement().executeQuery("select * from rewritetest2"); int counter = 0; while (rs.next()) { assertEquals(counter++, rs.getInt("id")); } assertEquals(1, counter); } @Test public void testBug501452() throws SQLException { PreparedStatement ps = sharedConnection.prepareStatement("insert into bug501452 (id,value) values (?,?)"); ps.setObject(1, 1); ps.setObject(2, "value for 1"); ps.addBatch(); ps.executeBatch(); ps.setObject(1, 2); ps.setObject(2, "value for 2"); ps.addBatch(); ps.executeBatch(); } @Test public void testMultipleStatementBatch() throws SQLException { try (Connection connection = setConnection("&sessionVariables=auto_increment_increment=2&allowMultiQueries=true")) { Statement stmt = connection.createStatement(); stmt.addBatch("INSERT INTO test_batch3(test) value ('a')"); stmt.addBatch("INSERT INTO test_batch3(test) value ('b')"); stmt.addBatch("INSERT INTO test_batch3(test) value ('a'), ('e')"); stmt.addBatch("UPDATE test_batch3 set test='c' WHERE test = 'a'"); stmt.addBatch("UPDATE test_batch3 set test='d' WHERE test = 'b'"); stmt.addBatch("INSERT INTO test_batch3(test) value ('e')"); int[] updateCount = stmt.executeBatch(); assertEquals(6, updateCount.length); assertEquals(1, updateCount[0]); assertEquals(1, updateCount[1]); assertEquals(2, updateCount[2]); assertEquals(2, updateCount[3]); assertEquals(1, updateCount[4]); assertEquals(1, updateCount[5]); assertEquals(-1, stmt.getUpdateCount()); assertFalse(stmt.getMoreResults()); ResultSet resultSet = stmt.getGeneratedKeys(); assertTrue(resultSet.next()); assertEquals(1, resultSet.getInt(1)); assertTrue(resultSet.next()); assertEquals(3, resultSet.getInt(1)); assertTrue(resultSet.next()); assertEquals(5, resultSet.getInt(1)); assertTrue(resultSet.next()); assertEquals(7, resultSet.getInt(1)); assertTrue(resultSet.next()); assertEquals(9, resultSet.getInt(1)); assertFalse(resultSet.next()); } } }