/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.jdbi.v3.sqlobject.statement; import java.lang.reflect.Method; import java.lang.reflect.Type; import java.util.Iterator; import java.util.stream.Collector; import java.util.stream.Stream; import org.jdbi.v3.core.generic.GenericTypes; import org.jdbi.v3.core.result.ResultIterable; import org.jdbi.v3.core.statement.StatementContext; import org.jdbi.v3.sqlobject.SingleValue; import static org.jdbi.v3.core.generic.GenericTypes.getErasedType; /** * Helper class used by the {@link CustomizingStatementHandler}s to assemble * the result Collection, Iterable, etc. */ abstract class ResultReturner { /** * If the return type is {@code void}, swallow results. * @param extensionType * @param method * @see ResultReturner#forMethod(Class, Method) if the return type is not void * @return */ static ResultReturner forOptionalReturn(Class<?> extensionType, Method method) { if (method.getReturnType() == void.class) { return new VoidReturner(); } return forMethod(extensionType, method); } /** * Inspect a Method for its return type, and choose a ResultReturner subclass * that handles any container that might wrap the results. * @param extensionType the type that owns the Method * @param method the method whose return type chooses the ResultReturner * @return an instance that takes a ResultIterable and constructs the return value */ static ResultReturner forMethod(Class<?> extensionType, Method method) { Type returnType = GenericTypes.resolveType(method.getGenericReturnType(), extensionType); Class<?> returnClass = getErasedType(returnType); if (Void.TYPE.equals(returnClass)) { throw new IllegalStateException(String.format( "Method %s#%s is annotated as if it should return a value, but the method is void.", method.getDeclaringClass().getName(), method.getName())); } else if (ResultIterable.class.isAssignableFrom(returnClass)) { return new ResultIterableResultReturner(returnType); } else if (Stream.class.isAssignableFrom(returnClass)) { return new StreamReturner(returnType); } else if (Iterator.class.isAssignableFrom(returnClass)) { return new IteratorResultReturner(returnType); } else if (method.isAnnotationPresent(SingleValue.class)) { return new SingleValueResultReturner(returnType); } else { return new CollectedResultReturner(returnType); } } protected abstract Object result(ResultIterable<?> iterable, StatementContext ctx); protected abstract Type elementType(StatementContext ctx); static class VoidReturner extends ResultReturner { @Override protected Object result(ResultIterable<?> iterable, StatementContext ctx) { iterable.stream().forEach(i -> {}); // Make sure to consume the result return null; } @Override protected Type elementType(StatementContext ctx) { return null; } } static class StreamReturner extends ResultReturner { private final Type elementType; StreamReturner(Type returnType) { elementType = GenericTypes.findGenericParameter(returnType, Stream.class) .orElseThrow(() -> new IllegalStateException( "Cannot reflect Stream<T> element type T in method return type " + returnType)); } @Override protected Stream<?> result(ResultIterable<?> iterable, StatementContext ctx) { return iterable.stream(); } @Override protected Type elementType(StatementContext ctx) { return elementType; } } static class CollectedResultReturner extends ResultReturner { private final Type returnType; CollectedResultReturner(Type returnType) { this.returnType = returnType; } @Override @SuppressWarnings({ "unchecked", "rawtypes" }) protected Object result(ResultIterable<?> iterable, StatementContext ctx) { Collector collector = ctx.findCollectorFor(returnType).orElse(null); if (collector != null) { return iterable.collect(collector); } return checkResult(iterable.findFirst().orElse(null), returnType); } @Override protected Type elementType(StatementContext ctx) { // if returnType is not supported by a collector factory, assume it to be a single-value return type. return ctx.findElementTypeFor(returnType).orElse(returnType); } } static class SingleValueResultReturner extends ResultReturner { private final Type returnType; SingleValueResultReturner(Type returnType) { this.returnType = returnType; } @Override protected Object result(ResultIterable<?> iterable, StatementContext ctx) { return checkResult(iterable.findFirst().orElse(null), returnType); } @Override protected Type elementType(StatementContext ctx) { return returnType; } } private static Object checkResult(Object result, Type type) { if (result == null && getErasedType(type).isPrimitive()) { throw new IllegalStateException("SQL method returns primitive " + type + ", but statement returned no results"); } return result; } static class ResultIterableResultReturner extends ResultReturner { private final Type elementType; ResultIterableResultReturner(Type returnType) { // extract T from Query<T> elementType = GenericTypes.findGenericParameter(returnType, ResultIterable.class) .orElseThrow(() -> new IllegalStateException( "Cannot reflect ResultIterable<T> element type T in method return type " + returnType)); } @Override protected Object result(ResultIterable<?> iterable, StatementContext ctx) { return iterable; } @Override protected Type elementType(StatementContext ctx) { return elementType; } } static class IteratorResultReturner extends ResultReturner { private final Type elementType; IteratorResultReturner(Type returnType) { this.elementType = GenericTypes.findGenericParameter(returnType, Iterator.class) .orElseThrow(() -> new IllegalStateException( "Cannot reflect Iterator<T> element type T in method return type " + returnType)); } @Override protected Object result(ResultIterable<?> iterable, StatementContext ctx) { return iterable.iterator(); } @Override protected Type elementType(StatementContext ctx) { return elementType; } } }