package com.github.davidmoten.rx.jdbc; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; import java.sql.Statement; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; import org.mockito.InOrder; import org.mockito.Mockito; import com.github.davidmoten.rx.Actions; import com.github.davidmoten.rx.jdbc.exceptions.SQLRuntimeException; import com.github.davidmoten.rx.testing.TestingHelper; import rx.Observable; import rx.functions.Func1; public final class BatchingTest { @Test public void testUnmocked() { Database db = DatabaseCreator.db(); int numPeopleBefore = db.select("select count(*) from person") // .getAs(Integer.class) // .toBlocking().single(); Observable<String> names = Observable.just("NANCY", "WARREN", "ALFRED", "BARRY", "ROBERTO"); Observable<Integer> count = db.update("insert into person(name,score) values(?,0)") .dependsOn(db.beginTransaction()) // set batch size .batchSize(3) // get parameters from last query .parameters(names) // go .count() // end transaction .count(); assertTrue(db.commit(count).toBlocking().single()); int numPeople = db.select("select count(*) from person") // .getAs(Integer.class) // .toBlocking().single(); assertEquals(numPeopleBefore + 5, numPeople); } @Test public void testBatchingCanOnlyBeUsedWithinATransaction() { Database db = DatabaseCreator.db(); Observable<String> names = Observable.just("NANCY", "WARREN", "ALFRED", "BARRY", "ROBERTO"); Observable<Integer> count = db.update("insert into person(name,score) values(?,0)") // set batch size .batchSize(3) // get parameters from last query .parameters(names) // go .count().count(); count // .to(TestingHelper.<Integer> test()) // .assertError(SQLRuntimeException.class); } @Test public void testMocked() throws SQLException { String sql = "insert into person(name,score) values(?, 0)"; final Connection con = Mockito.mock(Connection.class); PreparedStatement ps = Mockito.mock(PreparedStatement.class); Mockito.when(con.prepareStatement(sql, Statement.NO_GENERATED_KEYS)).thenReturn(ps); Mockito.when(ps.executeBatch()) // .thenReturn(new int[] { 1, 2, 3 }) // .thenReturn(new int[] { 4, 5 }); Mockito.when(con.getAutoCommit()).thenReturn(false); Mockito.when(con.isClosed()).thenReturn(false); ConnectionProvider cp = createConnectionProvider(con); Database db = Database.from(cp); Observable<String> names = Observable.just("NANCY", "WARREN", "ALFRED", "BARRY", "ROBERTO"); AtomicInteger records = new AtomicInteger(); Observable<Integer> count = db.update(sql) // .dependsOn(db.beginTransaction()) // set batch size .batchSize(3) // get parameters from last query .parameters(names) // go .count() // end transaction .toList() // sum record counts .map(new Func1<List<Integer>, Integer>() { @Override public Integer call(List<Integer> list) { return sum(list); } }) // set result to variable .doOnNext(Actions.setAtomic(records)) // .count(); db.commit(count).toBlocking().single(); InOrder in = Mockito.inOrder(con, ps); in.verify(con, Mockito.times(1)).prepareStatement(sql, Statement.NO_GENERATED_KEYS); in.verify(ps, Mockito.times(1)).setObject(1, "NANCY"); in.verify(ps, Mockito.times(1)).addBatch(); in.verify(ps, Mockito.times(1)).setObject(1, "WARREN"); in.verify(ps, Mockito.times(1)).addBatch(); in.verify(ps, Mockito.times(1)).setObject(1, "ALFRED"); in.verify(ps, Mockito.times(1)).addBatch(); in.verify(ps, Mockito.times(1)).executeBatch(); in.verify(ps, Mockito.times(1)).setObject(1, "BARRY"); in.verify(ps, Mockito.times(1)).addBatch(); in.verify(ps, Mockito.times(1)).setObject(1, "ROBERTO"); in.verify(ps, Mockito.times(1)).addBatch(); in.verify(ps, Mockito.times(1)).executeBatch(); // in.verify(con, Mockito.times(1)).commit(); in.verify(con, Mockito.times(1)).isClosed(); in.verify(con, Mockito.times(1)).close(); in.verifyNoMoreInteractions(); assertFalse(db.connectionProvider() instanceof ConnectionProviderBatch); assertEquals(1 + 2 + 3 + 4 + 5, records.get()); } private static int sum(List<Integer> list) { int sum = 0; for (Integer n : list) { sum += n; } return sum; } @Test(expected = IllegalArgumentException.class) public void cannotReturnGeneratedKeysWhenBatching() { Database db = DatabaseCreator.db(); Observable<String> names = Observable.just("NANCY"); db.update("insert into person(name,score) values(?,0)").dependsOn(db.beginTransaction()) // set batch size .batchSize(3) // get parameters from last query .parameters(names) // .returnGeneratedKeys(); } private static ConnectionProvider createConnectionProvider(final Connection con) { return new ConnectionProvider() { @Override public Connection get() { return con; } @Override public void close() { } }; } }