/*
* Copyright (c) 2006-2013 Rogério Liesenfeld
* This file is subject to the terms of the MIT license (see LICENSE.txt).
*/
package mockit.internal.expectations;
import java.io.*;
import java.lang.reflect.*;
import java.nio.*;
import java.util.*;
import mockit.internal.expectations.invocation.*;
import mockit.internal.util.*;
final class ReturnTypeConversion
{
private final Expectation expectation;
private final Class<?> returnType;
private final Object value;
ReturnTypeConversion(Expectation expectation, Class<?> returnType, Object value)
{
this.expectation = expectation;
this.returnType = returnType;
this.value = value;
}
void addConvertedValueOrValues()
{
boolean valueIsArray = value != null && value.getClass().isArray();
boolean valueIsIterable = value instanceof Iterable<?>;
if (valueIsArray || valueIsIterable || value instanceof Iterator<?>) {
if (returnType == void.class || hasReturnOfDifferentType()) {
if (valueIsArray) {
expectation.getResults().addReturnValues(value);
}
else if (valueIsIterable) {
expectation.getResults().addReturnValues((Iterable<?>) value);
}
else {
expectation.getResults().addDeferredReturnValues((Iterator<?>) value);
}
return;
}
}
expectation.substituteCascadedMockToBeReturnedIfNeeded(value);
expectation.getResults().addReturnValue(value);
}
private boolean hasReturnOfDifferentType()
{
return
!returnType.isArray() &&
!Iterable.class.isAssignableFrom(returnType) && !Iterator.class.isAssignableFrom(returnType) &&
!returnType.isAssignableFrom(value.getClass());
}
void addConvertedValue()
{
Class<?> wrapperType =
AutoBoxing.isWrapperOfPrimitiveType(returnType) ? returnType : AutoBoxing.getWrapperType(returnType);
Class<?> valueType = value.getClass();
if (valueType == wrapperType) {
expectation.getResults().addReturnValueResult(value);
}
else if (wrapperType != null && AutoBoxing.isWrapperOfPrimitiveType(valueType)) {
addPrimitiveValueConvertingAsNeeded(wrapperType);
}
else {
boolean valueIsArray = valueType.isArray();
if (valueIsArray || value instanceof Iterable<?> || value instanceof Iterator<?>) {
addMultiValuedResultBasedOnTheReturnType(valueIsArray);
}
else if (wrapperType != null) {
throw newIncompatibleTypesException();
}
else {
addResultFromSingleValue();
}
}
}
private void addMultiValuedResultBasedOnTheReturnType(boolean valueIsArray)
{
if (returnType == void.class) {
addMultiValuedResult(valueIsArray);
}
else if (returnType == Object.class) {
expectation.getResults().addReturnValueResult(value);
}
else if (valueIsArray && addCollectionOrMapWithElementsFromArray()) {
return;
}
else if (hasReturnOfDifferentType()) {
addMultiValuedResult(valueIsArray);
}
else {
expectation.getResults().addReturnValueResult(value);
}
}
private void addMultiValuedResult(boolean valueIsArray)
{
if (valueIsArray) {
expectation.getResults().addResults(value);
}
else if (value instanceof Iterable<?>) {
expectation.getResults().addResults((Iterable<?>) value);
}
else {
expectation.getResults().addDeferredResults((Iterator<?>) value);
}
}
private boolean addCollectionOrMapWithElementsFromArray()
{
int n = Array.getLength(value);
Object values = null;
if (returnType.isAssignableFrom(ListIterator.class)) {
List<Object> list = new ArrayList<Object>(n);
addArrayElements(list, n);
values = list.listIterator();
}
else if (returnType.isAssignableFrom(List.class)) {
values = addArrayElements(new ArrayList<Object>(n), n);
}
else if (returnType.isAssignableFrom(Set.class)) {
values = addArrayElements(new LinkedHashSet<Object>(n), n);
}
else if (returnType.isAssignableFrom(SortedSet.class)) {
values = addArrayElements(new TreeSet<Object>(), n);
}
else if (returnType.isAssignableFrom(Map.class)) {
values = addArrayElements(new LinkedHashMap<Object, Object>(n), n);
}
else if (returnType.isAssignableFrom(SortedMap.class)) {
values = addArrayElements(new TreeMap<Object, Object>(), n);
}
if (values != null) {
expectation.getResults().addReturnValue(values);
return true;
}
return false;
}
private Object addArrayElements(Collection<Object> values, int elementCount)
{
for (int i = 0; i < elementCount; i++) {
Object element = Array.get(value, i);
values.add(element);
}
return values;
}
private Object addArrayElements(Map<Object, Object> values, int elementPairCount)
{
for (int i = 0; i < elementPairCount; i++) {
Object keyAndValue = Array.get(value, i);
if (keyAndValue == null || !keyAndValue.getClass().isArray()) {
return null;
}
Object key = Array.get(keyAndValue, 0);
Object element = Array.getLength(keyAndValue) > 1 ? Array.get(keyAndValue, 1) : null;
values.put(key, element);
}
return values;
}
private void addResultFromSingleValue()
{
if (returnType == Object.class) {
expectation.getResults().addReturnValueResult(value);
}
else if (returnType == void.class) {
throw newIncompatibleTypesException();
}
else if (returnType.isArray()) {
Object array = Array.newInstance(returnType.getComponentType(), 1);
Array.set(array, 0, value);
expectation.getResults().addReturnValueResult(array);
}
else if (returnType.isAssignableFrom(ArrayList.class)) {
addCollectionWithSingleElement(new ArrayList<Object>(1));
}
else if (returnType.isAssignableFrom(LinkedList.class)) {
addCollectionWithSingleElement(new LinkedList<Object>());
}
else if (returnType.isAssignableFrom(HashSet.class)) {
addCollectionWithSingleElement(new HashSet<Object>(1));
}
else if (returnType.isAssignableFrom(TreeSet.class)) {
addCollectionWithSingleElement(new TreeSet<Object>());
}
else if (returnType.isAssignableFrom(ListIterator.class)) {
List<Object> l = new ArrayList<Object>(1);
l.add(value);
expectation.getResults().addReturnValueResult(l.listIterator());
}
else if (value instanceof CharSequence) {
addCharSequence((CharSequence) value);
}
else {
throw newIncompatibleTypesException();
}
}
private void addCollectionWithSingleElement(Collection<Object> container)
{
container.add(value);
expectation.getResults().addReturnValueResult(container);
}
private void addCharSequence(CharSequence value)
{
Object convertedValue = value;
if (returnType.isAssignableFrom(ByteArrayInputStream.class)) {
convertedValue = new ByteArrayInputStream(value.toString().getBytes());
}
else if (returnType.isAssignableFrom(StringReader.class)) {
convertedValue = new StringReader(value.toString());
}
else if (!(value instanceof StringBuilder) && returnType.isAssignableFrom(StringBuilder.class)) {
convertedValue = new StringBuilder(value);
}
else if (!(value instanceof CharBuffer) && returnType.isAssignableFrom(CharBuffer.class)) {
convertedValue = CharBuffer.wrap(value);
}
expectation.getResults().addReturnValueResult(convertedValue);
}
private IllegalArgumentException newIncompatibleTypesException()
{
ExpectedInvocation invocation = expectation.invocation;
String valueTypeName = value.getClass().getName().replace("java.lang.", "");
String returnTypeName = returnType.getName().replace("java.lang.", "");
StringBuilder msg = new StringBuilder(200);
msg.append("Value of type ").append(valueTypeName);
msg.append(" incompatible with return type ").append(returnTypeName).append(" of ");
msg.append(new MethodFormatter(invocation.getClassDesc(), invocation.getMethodNameAndDescription()));
return new IllegalArgumentException(msg.toString());
}
private void addPrimitiveValueConvertingAsNeeded(Class<?> targetType)
{
Object convertedValue = null;
if (value instanceof Number) {
convertedValue = convertFromNumber(targetType, (Number) value);
}
else if (value instanceof Character) {
convertedValue = convertFromChar(targetType, (Character) value);
}
if (convertedValue == null) {
throw newIncompatibleTypesException();
}
expectation.getResults().addReturnValueResult(convertedValue);
}
private Object convertFromNumber(Class<?> targetType, Number number)
{
if (targetType == Integer.class) {
return number.intValue();
}
else if (targetType == Short.class) {
return number.shortValue();
}
else if (targetType == Long.class) {
return number.longValue();
}
else if (targetType == Byte.class) {
return number.byteValue();
}
else if (targetType == Double.class) {
return number.doubleValue();
}
else if (targetType == Float.class) {
return number.floatValue();
}
else if (targetType == Character.class) {
return (char) number.intValue();
}
return null;
}
private Object convertFromChar(Class<?> targetType, char c)
{
if (targetType == Integer.class) {
return (int) c;
}
else if (targetType == Short.class) {
return (short) c;
}
else if (targetType == Long.class) {
return (long) c;
}
else if (targetType == Byte.class) {
return (byte) c;
}
else if (targetType == Double.class) {
return (double) c;
}
else if (targetType == Float.class) {
return (float) c;
}
return null;
}
}