/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.jdbi.v3.sqlobject.statement; import java.lang.annotation.Annotation; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.lang.reflect.Method; import java.lang.reflect.Parameter; import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.NoSuchElementException; import java.util.function.Function; import java.util.stream.IntStream; import java.util.stream.Stream; import org.jdbi.v3.core.Handle; import org.jdbi.v3.core.extension.HandleSupplier; import org.jdbi.v3.core.generic.GenericTypes; import org.jdbi.v3.core.internal.IterableLike; import org.jdbi.v3.core.mapper.RowMapper; import org.jdbi.v3.core.result.ResultIterable; import org.jdbi.v3.core.result.ResultIterator; import org.jdbi.v3.core.statement.PreparedBatch; import org.jdbi.v3.core.statement.StatementContext; import org.jdbi.v3.core.statement.UnableToCreateStatementException; import org.jdbi.v3.sqlobject.SingleValue; import org.jdbi.v3.sqlobject.SqlMethodAnnotation; import org.jdbi.v3.sqlobject.UnableToCreateSqlObjectException; /** * Annotate a method to indicate that it will create and execute a SQL batch. At least one * bound argument must be an Iterator or Iterable, values from this will be taken and applied * to each row of the batch. Non iterable bound arguments will be treated as constant values and * bound to each row. * <p> * Unfortunately, because of how batches work, statement customizers and sql statement customizers * which affect SQL generation will *not* work with batches. This primarily effects statement location * and rewriting, which will always use the values defined on the bound Handle. * <p> * If you want to chunk up the logical batch into a number of smaller batches (say around 1000 rows at * a time in order to not wreck havoc on the transaction log, you should see * {@link BatchChunkSize} */ @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.METHOD}) @SqlMethodAnnotation(SqlBatch.Impl.class) public @interface SqlBatch { /** * @return the SQL string (or name) */ String value() default ""; /** * @return whether to execute the batch chunks in a transaction. Default is true (and it will be strange if you * want otherwise). */ boolean transactional() default true; class Impl extends CustomizingStatementHandler<PreparedBatch> { private final SqlBatch sqlBatch; private final ChunkSizeFunction batchChunkSize; private final Function<PreparedBatch, ResultIterator<?>> batchIntermediate; private final ResultReturner magic; public Impl(Class<?> sqlObjectType, Method method) { super(sqlObjectType, method); this.sqlBatch = method.getAnnotation(SqlBatch.class); this.batchChunkSize = determineBatchChunkSize(sqlObjectType, method); final GetGeneratedKeys getGeneratedKeys = method.getAnnotation(GetGeneratedKeys.class); if (getGeneratedKeys == null) { if (!returnTypeIsValid(method.getReturnType())) { throw new UnableToCreateSqlObjectException(invalidReturnTypeMessage(method)); } Function<PreparedBatch,ResultIterator<?>> modCounts = PreparedBatch::executeAndGetModCount; batchIntermediate = method.getReturnType().equals(boolean[].class) ? mapToBoolean(modCounts) : modCounts; magic = ResultReturner.forOptionalReturn(sqlObjectType, method); } else { String[] columnNames = getGeneratedKeys.value(); magic = ResultReturner.forMethod(sqlObjectType, method); if (method.isAnnotationPresent(UseRowMapper.class)) { RowMapper<?> mapper = rowMapperFor(method.getAnnotation(UseRowMapper.class)); batchIntermediate = batch -> batch.executeAndReturnGeneratedKeys(columnNames) .map(mapper) .iterator(); } else { batchIntermediate = batch -> batch.executeAndReturnGeneratedKeys(columnNames) .mapTo(magic.elementType(batch.getContext())) .iterator(); } } } private Function<PreparedBatch, ResultIterator<?>> mapToBoolean(Function<PreparedBatch, ResultIterator<?>> modCounts) { return modCounts.andThen(iterator -> new ResultIterator<Boolean>() { @Override public boolean hasNext() { return iterator.hasNext(); } @Override public Boolean next() { return ((Integer) iterator.next()) > 0; } @Override public void close() { iterator.close(); } @Override public StatementContext getContext() { return iterator.getContext(); } }); } private ChunkSizeFunction determineBatchChunkSize(Class<?> sqlObjectType, Method method) { // this next big if chain determines the batch chunk size. It looks from most specific // scope to least, that is: as an argument, then on the method, then on the class, // then default to Integer.MAX_VALUE int batchChunkSizeParameterIndex; if ((batchChunkSizeParameterIndex = indexOfBatchChunkSizeParameter(method)) >= 0) { return new ParamBasedChunkSizeFunction(batchChunkSizeParameterIndex); } else if (method.isAnnotationPresent(BatchChunkSize.class)) { final int size = method.getAnnotation(BatchChunkSize.class).value(); if (size <= 0) { throw new IllegalArgumentException("Batch chunk size must be >= 0"); } return new ConstantChunkSizeFunction(size); } else if (sqlObjectType.isAnnotationPresent(BatchChunkSize.class)) { final int size = sqlObjectType.getAnnotation(BatchChunkSize.class).value(); return new ConstantChunkSizeFunction(size); } else { return new ConstantChunkSizeFunction(Integer.MAX_VALUE); } } private int indexOfBatchChunkSizeParameter(Method method) { Annotation[][] parameterAnnotations = method.getParameterAnnotations(); return IntStream.range(0, parameterAnnotations.length) .filter(i -> Stream.of(parameterAnnotations[i]).anyMatch(BatchChunkSize.class::isInstance)) .findFirst() .orElse(-1); } @Override PreparedBatch createStatement(Handle handle, String locatedSql) { return handle.prepareBatch(locatedSql); } @Override void configureReturner(PreparedBatch stmt, SqlObjectStatementConfiguration cfg) { } @Override Type getParameterType(Parameter parameter) { Type type = super.getParameterType(parameter); if (!parameter.isAnnotationPresent(SingleValue.class)) { Class<?> erasedType = GenericTypes.getErasedType(type); if (Iterable.class.isAssignableFrom(erasedType)) { return GenericTypes.findGenericParameter(type, Iterable.class).get(); } else if (Iterator.class.isAssignableFrom(erasedType)) { return GenericTypes.findGenericParameter(type, Iterator.class).get(); } else if (GenericTypes.isArray(type)) { return ((Class<?>) type).getComponentType(); } } return type; } @Override public Object invoke(Object target, Object[] args, HandleSupplier h) { final Handle handle = h.getHandle(); final String sql = locateSql(handle); final int chunkSize = batchChunkSize.call(args); final Iterator<Object[]> batchArgs = zipArgs(getMethod(), args); ResultIterator<Object> result; if (batchArgs.hasNext()) { result = new ResultIterator<Object>() { ResultIterator<?> batchResult; boolean closed = false; { hasNext(); // Ensure our batchResult is prepared, so we can get its context } @Override public boolean hasNext() { if (closed) { throw new IllegalStateException("closed"); } // first, any elements already buffered? if (batchResult != null) { if (batchResult.hasNext()) { return true; } // no more in this chunk, release resources batchResult.close(); } // more chunks? if (!batchArgs.hasNext()) { return false; } // execute a single chunk and buffer PreparedBatch batch = handle.prepareBatch(sql); for (int i = 0; i < chunkSize && batchArgs.hasNext(); i++) { applyCustomizers(batch, batchArgs.next()); batch.add(); } batchResult = executeBatch(handle, batch); return hasNext(); // recurse to ensure we actually got elements } @Override public Object next() { if (closed) { throw new IllegalStateException("closed"); } if (!hasNext()) { throw new NoSuchElementException(); } return batchResult.next(); } @Override public StatementContext getContext() { return batchResult.getContext(); } @Override public void close() { closed = true; batchResult.close(); } }; } else { PreparedBatch dummy = handle.prepareBatch(sql); result = new ResultIterator<Object>() { @Override public void close() { // no op } @Override public StatementContext getContext() { return dummy.getContext(); } @Override public boolean hasNext() { return false; } @Override public Object next() { throw new NoSuchElementException(); } }; } ResultIterable<Object> iterable = ResultIterable.of(result); return magic.result(iterable, result.getContext()); } private Iterator<Object[]> zipArgs(Method method, Object[] args) { boolean foundIterator = false; List<Iterator<?>> extras = new ArrayList<>(); for (int paramIdx = 0; paramIdx < method.getParameterCount(); paramIdx++) { final boolean singleValue = method.getParameters()[paramIdx].isAnnotationPresent(SingleValue.class); final Object arg = args[paramIdx]; if (!singleValue && IterableLike.isIterable(arg)) { extras.add(IterableLike.of(arg)); foundIterator = true; } else { extras.add(Stream.generate(() -> arg).iterator()); } } if (!foundIterator) { throw new UnableToCreateStatementException("@SqlBatch method has no Iterable or array parameters," + " did you mean @SqlQuery?", null, null); } final Object[] sharedArg = new Object[args.length]; return new Iterator<Object[]>() { @Override public boolean hasNext() { for (Iterator<?> extra : extras) { if (!extra.hasNext()) { return false; } } return true; } @Override public Object[] next() { for (int i = 0; i < extras.size(); i++) { sharedArg[i] = extras.get(i).next(); } return sharedArg; } }; } private ResultIterator<?> executeBatch(final Handle handle, final PreparedBatch batch) { if (!handle.isInTransaction() && sqlBatch.transactional()) { // it is safe to use same prepared batch as the inTransaction passes in the same // Handle instance. return handle.inTransaction(c -> batchIntermediate.apply(batch)); } else { return batchIntermediate.apply(batch); } } private interface ChunkSizeFunction { int call(Object[] args); } private static class ConstantChunkSizeFunction implements ChunkSizeFunction { private final int value; ConstantChunkSizeFunction(int value) { this.value = value; } @Override public int call(Object[] args) { return value; } } private static class ParamBasedChunkSizeFunction implements ChunkSizeFunction { private final int index; ParamBasedChunkSizeFunction(int index) { this.index = index; } @Override public int call(Object[] args) { return (Integer) args[index]; } } private static boolean returnTypeIsValid(Class<?> type) { if (type.equals(Void.TYPE)) { return true; } if (type.isArray()) { Class<?> componentType = type.getComponentType(); return componentType.equals(Integer.TYPE) || componentType.equals(Boolean.TYPE); } return false; } private static String invalidReturnTypeMessage(Method method) { return method.getDeclaringClass() + "." + method.getName() + " method is annotated with @SqlBatch so should return void, int[], or boolean[] but is returning: " + method.getReturnType(); } } }