/*
* Copyright 2013-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.repository.support;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import org.springframework.core.MethodParameter;
import org.springframework.core.convert.ConversionException;
import org.springframework.core.convert.ConversionService;
import org.springframework.core.convert.TypeDescriptor;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.repository.core.CrudMethods;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.query.Param;
import org.springframework.data.repository.util.QueryExecutionConverters;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
/**
 * Base {@link RepositoryInvoker} using reflection to invoke methods on Spring Data Repositories.
 *
 * @author Oliver Gierke
 * @since 1.10
 */
class ReflectionRepositoryInvoker implements RepositoryInvoker {

	private static final AnnotationAttribute PARAM_ANNOTATION = new AnnotationAttribute(Param.class);
	private static final String NAME_NOT_FOUND = "Unable to detect parameter names for query method %s! Use @Param or compile with -parameters on JDK 8.";

	private final Object repository;
	private final CrudMethods methods;
	private final Class<?> idType;
	private final ConversionService conversionService;

	/**
	 * Creates a new {@link ReflectionRepositoryInvoker} for the given repository, {@link RepositoryMetadata} and
	 * {@link ConversionService}.
	 *
	 * @param repository must not be {@literal null}.
	 * @param metadata must not be {@literal null}.
	 * @param conversionService must not be {@literal null}.
	 */
	public ReflectionRepositoryInvoker(Object repository, RepositoryMetadata metadata,
			ConversionService conversionService) {

		Assert.notNull(repository, "Repository must not be null!");
		Assert.notNull(metadata, "RepositoryMetadata must not be null!");
		Assert.notNull(conversionService, "ConversionService must not be null!");

		this.repository = repository;
		this.methods = metadata.getCrudMethods();
		this.idType = metadata.getIdType();
		this.conversionService = conversionService;
	}

	/*
	 * (non-Javadoc)
	 * @see org.springframework.data.repository.support.RepositoryInvocationInformation#hasFindAllMethod()
	 */
	@Override
	public boolean hasFindAllMethod() {
		return methods.hasFindAllMethod();
	}

	/*
	 * (non-Javadoc)
	 * @see org.springframework.data.repository.support.RepositoryInvoker#invokeFindAll(org.springframework.data.domain.Sort)
	 */
	@Override
	public Iterable<Object> invokeFindAll(Sort sort) {
		return invokeFindAllReflectively(sort);
	}

	/*
	 * (non-Javadoc)
	 * @see org.springframework.data.repository.support.RepositoryInvoker#invokeFindAll(org.springframework.data.domain.Pageable)
	 */
	@Override
	public Iterable<Object> invokeFindAll(Pageable pageable) {
		return invokeFindAllReflectively(pageable);
	}

	/*
	 * (non-Javadoc)
	 * @see org.springframework.data.repository.support.RepositoryInvocationInformation#hasSaveMethod()
	 */
	@Override
	public boolean hasSaveMethod() {
		return methods.hasSaveMethod();
	}

	/*
	 * (non-Javadoc)
	 * @see org.springframework.data.repository.support.RepositoryInvoker#invokeSave(java.lang.Object)
	 */
	@Override
	public <T> T invokeSave(T object) {

		Method method = methods.getSaveMethod()//
				.orElseThrow(() -> new IllegalStateException("Repository doesn't have a save-method declared!"));

		return invoke(method, object);
	}

	/*
	 * (non-Javadoc)
	 * @see org.springframework.data.repository.support.RepositoryInvocationInformation#hasFindOneMethod()
	 */
	@Override
	public boolean hasFindOneMethod() {
		return methods.hasFindOneMethod();
	}

	/*
	 * (non-Javadoc)
	 * @see org.springframework.data.repository.support.RepositoryInvoker#invokeFindById(java.lang.Object)
	 */
	@Override
	public <T> Optional<T> invokeFindById(Object id) {

		Method method = methods.getFindOneMethod()//
				.orElseThrow(() -> new IllegalStateException("Repository doesn't have a find-one-method declared!"));

		return returnAsOptional(invoke(method, convertId(id)));
	}

	/*
	 * (non-Javadoc)
	 * @see org.springframework.data.repository.support.RepositoryInvocationInformation#hasDeleteMethod()
	 */
	@Override
	public boolean hasDeleteMethod() {
		return methods.hasDelete();
	}

	/**
	 * Deletes the aggregate identified by the given id.
	 * <p>
	 * If the repository's delete method accepts the identifier type (or {@link Object}), the converted identifier is
	 * handed to it directly. Otherwise the delete method is assumed to take the entity itself: the entity is looked up
	 * via {@link #invokeFindById(Object)} and only deleted if it exists. An unknown id therefore results in a no-op
	 * rather than a {@literal null} being passed to the delete method.
	 *
	 * @param id must not be {@literal null}.
	 * @throws IllegalStateException if the repository does not declare a delete method.
	 */
	@Override
	public void invokeDeleteById(Object id) {

		Assert.notNull(id, "Identifier must not be null!");

		Method method = methods.getDeleteMethod()
				.orElseThrow(() -> new IllegalStateException("Repository doesn't have a delete-method declared!"));

		Class<?> parameterType = method.getParameterTypes()[0];
		List<Class<?>> idTypes = Arrays.asList(idType, Object.class);

		if (idTypes.contains(parameterType)) {
			invoke(method, convertId(id));
		} else {
			// Entity-based delete: only invoke it if there is something to delete. Handing the method a null entity
			// is not meaningful and delete implementations may reject it.
			this.<Object> invokeFindById(id).ifPresent(entity -> invoke(method, entity));
		}
	}

	/*
	 * (non-Javadoc)
	 * @see org.springframework.data.repository.support.RepositoryInvoker#invokeQueryMethod(java.lang.reflect.Method, org.springframework.util.MultiValueMap, org.springframework.data.domain.Pageable, org.springframework.data.domain.Sort)
	 */
	@Override
	public Optional<Object> invokeQueryMethod(Method method, MultiValueMap<String, ? extends Object> parameters,
			Pageable pageable, Sort sort) {

		Assert.notNull(method, "Method must not be null!");
		Assert.notNull(parameters, "Parameters must not be null!");
		Assert.notNull(pageable, "Pageable must not be null!");
		Assert.notNull(sort, "Sort must not be null!");

		ReflectionUtils.makeAccessible(method);

		return returnAsOptional(invoke(method, prepareParameters(method, parameters, pageable, sort)));
	}

	/**
	 * Builds the argument array for the given query method.
	 * <p>
	 * {@link Pageable} parameters receive the given {@link Pageable}. {@link Sort} parameters receive the
	 * {@link Pageable}'s sort if it has one, or the given {@link Sort} otherwise. All other parameters are looked up
	 * by name in the raw parameters and converted to the parameter type if necessary.
	 *
	 * @param method the query method to prepare arguments for.
	 * @param rawParameters the raw request parameters keyed by parameter name.
	 * @param pageable the {@link Pageable} to pass to {@link Pageable} parameters.
	 * @param sort the fallback {@link Sort} to pass to {@link Sort} parameters.
	 * @return the arguments to invoke the method with, never {@literal null}.
	 * @throws IllegalArgumentException if the name of a non-paging, non-sorting parameter cannot be detected.
	 */
	private Object[] prepareParameters(Method method, MultiValueMap<String, ? extends Object> rawParameters,
			Pageable pageable, Sort sort) {

		List<MethodParameter> parameters = new MethodParameters(method, Optional.of(PARAM_ANNOTATION)).getParameters();

		if (parameters.isEmpty()) {
			return new Object[0];
		}

		Object[] result = new Object[parameters.size()];
		Sort sortToUse = pageable.getSortOr(sort);

		for (int i = 0; i < result.length; i++) {

			MethodParameter param = parameters.get(i);
			Class<?> targetType = param.getParameterType();

			if (Pageable.class.isAssignableFrom(targetType)) {
				result[i] = pageable;
			} else if (Sort.class.isAssignableFrom(targetType)) {
				result[i] = sortToUse;
			} else {

				String parameterName = param.getParameterName();

				if (!StringUtils.hasText(parameterName)) {
					throw new IllegalArgumentException(String.format(NAME_NOT_FOUND, ClassUtils.getQualifiedMethodName(method)));
				}

				Object value = unwrapSingleElement(rawParameters.get(parameterName));

				// Skip the conversion round-trip if the raw value already has the required type.
				result[i] = targetType.isInstance(value) ? value : convert(value, param);
			}
		}

		return result;
	}

	/**
	 * Converts the given raw value into the type of the given {@link MethodParameter} using the configured
	 * {@link ConversionService}.
	 *
	 * @param value can be {@literal null}.
	 * @param parameter must not be {@literal null}.
	 * @return the converted value.
	 * @throws QueryMethodParameterConversionException if the {@link ConversionService} fails to convert the value.
	 */
	private Object convert(Object value, MethodParameter parameter) {

		try {
			return conversionService.convert(value, TypeDescriptor.forObject(value), new TypeDescriptor(parameter));
		} catch (ConversionException o_O) {
			throw new QueryMethodParameterConversionException(value, parameter, o_O);
		}
	}

	/**
	 * Invokes the given method with the given arguments on the backing repository.
	 *
	 * @param method the repository method to invoke.
	 * @param arguments the arguments to hand to the method.
	 * @return the method's return value, cast to the type expected by the caller.
	 */
	@SuppressWarnings("unchecked")
	private <T> T invoke(Method method, Object... arguments) {
		return (T) ReflectionUtils.invokeMethod(method, repository, arguments);
	}

	/**
	 * Returns the given source as is if it already is an {@link Optional}. Otherwise the source is unwrapped via
	 * {@link QueryExecutionConverters#unwrap(Object)} and the (possibly {@literal null}) result is wrapped into an
	 * {@link Optional}.
	 *
	 * @param source can be {@literal null}.
	 * @return never {@literal null}.
	 */
	@SuppressWarnings("unchecked")
	private <T> Optional<T> returnAsOptional(Object source) {

		return (Optional<T>) (Optional.class.isInstance(source) ? source
				: Optional.ofNullable(QueryExecutionConverters.unwrap(source)));
	}

	/**
	 * Converts the given id into the id type of the backing repository.
	 *
	 * @param id must not be {@literal null}.
	 * @return the id converted by the {@link ConversionService} into the repository's id type.
	 */
	protected Object convertId(Object id) {

		Assert.notNull(id, "Id must not be null!");

		return conversionService.convert(id, idType);
	}

	/**
	 * Invokes the repository's find-all method with the given {@link Pageable}.
	 * <p>
	 * A parameterless find-all method is invoked without arguments. A method taking a {@link Pageable} receives the
	 * given one. Any other variant falls back to {@link #invokeFindAll(Sort)} with the {@link Pageable}'s sort.
	 *
	 * @param pageable must not be {@literal null}.
	 * @return the result of the find-all invocation.
	 * @throws IllegalStateException if the repository does not declare a find-all method.
	 */
	protected Iterable<Object> invokeFindAllReflectively(Pageable pageable) {

		Method method = methods.getFindAllMethod()
				.orElseThrow(() -> new IllegalStateException("Repository doesn't have a find-all-method declared!"));

		if (method.getParameterCount() == 0) {
			return invoke(method);
		}

		Class<?>[] types = method.getParameterTypes();

		if (Pageable.class.isAssignableFrom(types[0])) {
			return invoke(method, pageable);
		}

		return invokeFindAll(pageable.getSort());
	}

	/**
	 * Invokes the repository's find-all method with the given {@link Sort}. A parameterless find-all method is invoked
	 * without arguments; otherwise the {@link Sort} is handed to the method as its only argument.
	 *
	 * @param sort the {@link Sort} to apply.
	 * @return the result of the find-all invocation.
	 * @throws IllegalStateException if the repository does not declare a find-all method.
	 */
	protected Iterable<Object> invokeFindAllReflectively(Sort sort) {

		Method method = methods.getFindAllMethod()
				.orElseThrow(() -> new IllegalStateException("Repository doesn't have a find-all-method declared!"));

		if (method.getParameterCount() == 0) {
			return invoke(method);
		}

		return invoke(method, sort);
	}

	/**
	 * Unwraps the first item if the given source has exactly one element.
	 *
	 * @param source can be {@literal null}.
	 * @return {@literal null} if the source is {@literal null}, the single element if the source contains exactly one
	 *         element, or the source list itself otherwise.
	 */
	private static Object unwrapSingleElement(List<? extends Object> source) {
		return source == null ? null : source.size() == 1 ? source.get(0) : source;
	}
}