package graphql.execution.batched;

import graphql.ExecutionResult;
import graphql.ExecutionResultImpl;
import graphql.GraphQLException;
import graphql.execution.ExecutionContext;
import graphql.execution.ExecutionParameters;
import graphql.execution.ExecutionStrategy;
import graphql.execution.FieldCollectorParameters;
import graphql.execution.TypeResolutionParameters;
import graphql.language.Field;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.DataFetchingEnvironmentImpl;
import graphql.schema.DataFetchingFieldSelectionSet;
import graphql.schema.DataFetchingFieldSelectionSetImpl;
import graphql.schema.GraphQLEnumType;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLInterfaceType;
import graphql.schema.GraphQLList;
import graphql.schema.GraphQLNonNull;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLOutputType;
import graphql.schema.GraphQLScalarType;
import graphql.schema.GraphQLType;
import graphql.schema.GraphQLUnionType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Array;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;

import static graphql.execution.FieldCollectorParameters.newParameters;
import static java.util.Collections.singletonList;

/**
 * Execution strategy that minimizes calls to data fetchers when used in conjunction with {@link DataFetcher}s whose
 * {@link DataFetcher#get(DataFetchingEnvironment)} methods are annotated with {@link Batched}. See the javadoc of
 * {@link Batched} for a more detailed description of batched data fetchers.
 * <p>
 * The strategy walks the query breadth-first. For each node it resolves every field exactly once, passing the list of
 * all relevant source objects for that field to the data fetcher in a single call.
 * </p>
 * <p>
 * Normal (non-batched) DataFetchers can be used as well (they are passed through {@link BatchedDataFetcherFactory}),
 * but they will not benefit from batching, since they expect a single source object at a time.
 * </p>
 */
public class BatchedExecutionStrategy extends ExecutionStrategy {

    private static final Logger log = LoggerFactory.getLogger(BatchedExecutionStrategy.class);

    private final BatchedDataFetcherFactory batchingFactory = new BatchedDataFetcherFactory();

    /**
     * Executes the fields in {@code parameters} against {@code parameters.source()} using batched, breadth-first
     * resolution.
     *
     * @param executionContext the execution context of the current operation
     * @param parameters       the root fields, the root source object and the (object) type of the root
     * @return the assembled result data together with the errors recorded in {@code executionContext}
     */
    @Override
    public ExecutionResult execute(ExecutionContext executionContext, ExecutionParameters parameters) {
        GraphQLExecutionNodeDatum data = new GraphQLExecutionNodeDatum(new LinkedHashMap<>(), parameters.source());
        GraphQLObjectType type = parameters.typeInfo().castType(GraphQLObjectType.class);
        GraphQLExecutionNode root = new GraphQLExecutionNode(type, parameters.fields(), singletonList(data));
        return execute(executionContext, root);
    }

    /**
     * Processes nodes breadth-first starting at {@code root}. Each field of a node is resolved for all of the node's
     * data at once; any child nodes produced by that resolution are queued and processed later.
     */
    private ExecutionResult execute(ExecutionContext executionContext, GraphQLExecutionNode root) {
        Queue<GraphQLExecutionNode> nodes = new ArrayDeque<>();
        nodes.add(root);
        while (!nodes.isEmpty()) {
            GraphQLExecutionNode node = nodes.poll();
            for (Map.Entry<String, List<Field>> fieldEntry : node.getFields().entrySet()) {
                List<GraphQLExecutionNode> childNodes = resolveField(executionContext, node.getParentType(),
                        node.getData(), fieldEntry.getKey(), fieldEntry.getValue());
                nodes.addAll(childNodes);
            }
        }
        return new ExecutionResultImpl(getOnlyElement(root.getData()).getParentResult(), executionContext.getErrors());
    }

    /**
     * Returns the first element of {@code list}. Used for the root node, which {@link #execute(ExecutionContext,
     * ExecutionParameters)} creates with a single datum.
     */
    private GraphQLExecutionNodeDatum getOnlyElement(List<GraphQLExecutionNodeDatum> list) {
        return list.get(0);
    }

    /**
     * Fetches and completes one field for every datum of a node.
     * <p>
     * The {@code source} of each datum is what the data fetcher receives; the datum itself is the container the
     * completed value is written into. Completed values are either leaf values or empty child containers; the latter
     * are returned as new nodes that still need to be processed.
     * </p>
     *
     * @return the child nodes to process next, or an empty list if {@code parentType} has no such field
     */
    private List<GraphQLExecutionNode> resolveField(ExecutionContext executionContext, GraphQLObjectType parentType,
                                                    List<GraphQLExecutionNodeDatum> nodeData, String fieldName,
                                                    List<Field> fields) {
        GraphQLFieldDefinition fieldDef = getFieldDef(executionContext.getGraphQLSchema(), parentType, fields.get(0));
        if (fieldDef == null) {
            return Collections.emptyList();
        }
        // Resolve the arguments once: they are needed both by the data fetcher and by type resolution
        // (interfaces/unions) during completion.
        Map<String, Object> argumentValues = valuesResolver.getArgumentValues(
                fieldDef.getArguments(), fields.get(0).getArguments(), executionContext.getVariables());
        List<GraphQLExecutionNodeValue> values =
                fetchData(executionContext, parentType, nodeData, fields, fieldDef, argumentValues);
        return completeValues(executionContext, parentType, values, fieldName, fields, fieldDef.getType(), argumentValues);
    }

    /**
     * Writes the fetched {@code values} into their result containers and returns the child nodes (if any) that
     * still need to be resolved.
     *
     * @throws GraphQLException         if a null value is found for a non-null type
     * @throws IllegalArgumentException if {@code outputType} is not a scalar/enum, object-like or list type
     */
    private List<GraphQLExecutionNode> completeValues(ExecutionContext executionContext, GraphQLObjectType parentType,
                                                      List<GraphQLExecutionNodeValue> values, String fieldName,
                                                      List<Field> fields, GraphQLOutputType outputType,
                                                      Map<String, Object> argumentValues) {
        GraphQLType fieldType = handleNonNullType(outputType, values, parentType, fields);

        if (isPrimitive(fieldType)) {
            handlePrimitives(values, fieldName, fieldType);
            return Collections.emptyList();
        } else if (isObject(fieldType)) {
            return handleObject(executionContext, argumentValues, values, fieldName, fields, fieldType);
        } else if (isList(fieldType)) {
            return handleList(executionContext, argumentValues, values, fieldName, fields, parentType,
                    (GraphQLList) fieldType);
        } else {
            throw new IllegalArgumentException("Unrecognized type: " + fieldType);
        }
    }

    /**
     * Completes a list-typed field. Non-null values get an empty child list in their result container. Their items
     * are flattened into one batch and completed against the list's wrapped type, which keeps nested lists batched.
     */
    private List<GraphQLExecutionNode> handleList(ExecutionContext executionContext, Map<String, Object> argumentValues,
                                                  List<GraphQLExecutionNodeValue> values, String fieldName,
                                                  List<Field> fields, GraphQLObjectType parentType,
                                                  GraphQLList listType) {
        List<GraphQLExecutionNodeValue> flattenedNodeValues = new ArrayList<>();

        for (GraphQLExecutionNodeValue value : values) {
            if (value.getValue() == null) {
                value.getResultContainer().putResult(fieldName, null);
            } else {
                GraphQLExecutionResultList flattenedDatum =
                        value.getResultContainer().createAndPutEmptyChildList(fieldName);
                for (Object rawValue : toIterable(value.getValue())) {
                    flattenedNodeValues.add(new GraphQLExecutionNodeValue(flattenedDatum, rawValue));
                }
            }
        }

        GraphQLOutputType subType = (GraphQLOutputType) listType.getWrappedType();
        return completeValues(executionContext, parentType, flattenedNodeValues, fieldName, fields, subType,
                argumentValues);
    }

    /**
     * Adapts a fetched list value for iteration. Accepts any {@link Iterable} (not only {@link List}) and any
     * array, including primitive arrays.
     *
     * @throws GraphQLException if {@code value} is neither an Iterable nor an array
     */
    private static Iterable<?> toIterable(Object value) {
        if (value instanceof Iterable) {
            return (Iterable<?>) value;
        }
        if (value.getClass().isArray()) {
            int length = Array.getLength(value);
            List<Object> items = new ArrayList<>(length);
            for (int i = 0; i < length; i++) {
                items.add(Array.get(value, i));
            }
            return items;
        }
        throw new GraphQLException("Expected an Iterable or array for a list type but got: "
                + value.getClass().getName());
    }

    /**
     * Completes an object, interface or union field by creating child result containers and grouping them into
     * child nodes by their resolved object type.
     */
    private List<GraphQLExecutionNode> handleObject(ExecutionContext executionContext,
                                                    Map<String, Object> argumentValues,
                                                    List<GraphQLExecutionNodeValue> values, String fieldName,
                                                    List<Field> fields, GraphQLType fieldType) {
        ChildDataCollector collector = createAndPopulateChildData(executionContext, fields.get(0), values, fieldName,
                fieldType, argumentValues);
        return createChildNodes(executionContext, fields, collector);
    }

    /**
     * Builds one child node per entry of {@code collector}, with the sub-selection collected for that entry's
     * object type.
     */
    private List<GraphQLExecutionNode> createChildNodes(ExecutionContext executionContext, List<Field> fields,
                                                        ChildDataCollector collector) {
        List<GraphQLExecutionNode> childNodes = new ArrayList<>();
        for (ChildDataCollector.Entry entry : collector.getEntries()) {
            Map<String, List<Field>> childFields = getChildFields(executionContext, entry.getObjectType(), fields);
            childNodes.add(new GraphQLExecutionNode(entry.getObjectType(), childFields, entry.getData()));
        }
        return childNodes;
    }

    /**
     * For each non-null value, creates a child datum in its result container and records it under the value's
     * resolved object type. Null values are written as null and produce no child.
     */
    private ChildDataCollector createAndPopulateChildData(ExecutionContext executionContext, Field field,
                                                          List<GraphQLExecutionNodeValue> values, String fieldName,
                                                          GraphQLType fieldType, Map<String, Object> argumentValues) {
        ChildDataCollector collector = new ChildDataCollector();
        for (GraphQLExecutionNodeValue value : values) {
            if (value.getValue() == null) {
                // We hit a null: insert the null and do not create a child.
                value.getResultContainer().putResult(fieldName, null);
            } else {
                GraphQLExecutionNodeDatum childDatum =
                        value.getResultContainer().createAndPutChildDatum(fieldName, value.getValue());
                GraphQLObjectType graphQLObjectType =
                        getGraphQLObjectType(executionContext, field, fieldType, value.getValue(), argumentValues);
                collector.putChildData(graphQLObjectType, childDatum);
            }
        }
        return collector;
    }

    /**
     * Unwraps any {@link GraphQLNonNull} layers from {@code fieldType}. For non-null types it first verifies that
     * none of the values is null.
     *
     * @param parentType may be null; only used for the error message
     * @param fields     may be null; only used for the error message
     * @throws GraphQLException if {@code fieldType} is non-null and one of the values is null
     */
    private GraphQLType handleNonNullType(GraphQLType fieldType, List<GraphQLExecutionNodeValue> values,
            /*Nullable*/ GraphQLObjectType parentType, /*Nullable*/ List<Field> fields) {
        if (isNonNull(fieldType)) {
            for (GraphQLExecutionNodeValue value : values) {
                if (value.getValue() == null) {
                    // parentType is nullable, so do not dereference it unconditionally while reporting the error.
                    String parentName = parentType == null ? null : parentType.getName();
                    throw new GraphQLException("Found null value for non-null type with parent: '"
                            + parentName + "' for fields: " + fields);
                }
            }
            while (isNonNull(fieldType)) {
                fieldType = ((GraphQLNonNull) fieldType).getWrappedType();
            }
        }
        return fieldType;
    }

    private boolean isNonNull(GraphQLType fieldType) {
        return fieldType instanceof GraphQLNonNull;
    }

    /**
     * Collects the sub-selection of {@code fields} that applies to {@code resolvedType}, honoring fragments and
     * variables from the execution context.
     */
    private Map<String, List<Field>> getChildFields(ExecutionContext executionContext, GraphQLObjectType resolvedType,
                                                    List<Field> fields) {
        FieldCollectorParameters collectorParameters = newParameters(executionContext.getGraphQLSchema(), resolvedType)
                .fragments(executionContext.getFragmentsByName())
                .variables(executionContext.getVariables())
                .build();
        return fieldCollector.collectFields(collectorParameters, fields);
    }

    /**
     * Resolves the concrete object type of {@code value}. Interface and union types are resolved through the
     * inherited type resolution; object types are returned as-is. Returns null for any other type.
     */
    private GraphQLObjectType getGraphQLObjectType(ExecutionContext executionContext, Field field,
                                                   GraphQLType fieldType, Object value,
                                                   Map<String, Object> argumentValues) {
        GraphQLObjectType resolvedType = null;
        if (fieldType instanceof GraphQLInterfaceType) {
            resolvedType = resolveTypeForInterface(TypeResolutionParameters.newParameters()
                    .graphQLInterfaceType((GraphQLInterfaceType) fieldType)
                    .field(field)
                    .value(value)
                    .argumentValues(argumentValues)
                    .schema(executionContext.getGraphQLSchema())
                    .build());
        } else if (fieldType instanceof GraphQLUnionType) {
            resolvedType = resolveTypeForUnion(TypeResolutionParameters.newParameters()
                    .graphQLUnionType((GraphQLUnionType) fieldType)
                    .field(field)
                    .value(value)
                    .argumentValues(argumentValues)
                    .schema(executionContext.getGraphQLSchema())
                    .build());
        } else if (fieldType instanceof GraphQLObjectType) {
            resolvedType = (GraphQLObjectType) fieldType;
        }
        return resolvedType;
    }

    /**
     * Serializes scalar/enum values and writes them into their result containers. NaN doubles are written as null.
     */
    private void handlePrimitives(List<GraphQLExecutionNodeValue> values, String fieldName, GraphQLType type) {
        for (GraphQLExecutionNodeValue value : values) {
            Object coercedValue = coerce(type, value.getValue());
            // 6.6.1 http://facebook.github.io/graphql/#sec-Field-entries
            if (coercedValue instanceof Double && ((Double) coercedValue).isNaN()) {
                coercedValue = null;
            }
            value.getResultContainer().putResult(fieldName, coercedValue);
        }
    }

    /** Serializes {@code value} with the coercing of the enum or scalar {@code type}. */
    private Object coerce(GraphQLType type, Object value) {
        if (type instanceof GraphQLEnumType) {
            return ((GraphQLEnumType) type).getCoercing().serialize(value);
        } else {
            return ((GraphQLScalarType) type).getCoercing().serialize(value);
        }
    }

    private boolean isList(GraphQLType type) {
        return type instanceof GraphQLList;
    }

    private boolean isPrimitive(GraphQLType type) {
        return type instanceof GraphQLScalarType || type instanceof GraphQLEnumType;
    }

    private boolean isObject(GraphQLType type) {
        return type instanceof GraphQLObjectType
                || type instanceof GraphQLInterfaceType
                || type instanceof GraphQLUnionType;
    }

    /**
     * Calls the (batched) data fetcher once with the sources of all {@code nodeData} and pairs each datum with the
     * fetched value at the same index.
     * <p>
     * If the data fetcher throws, the exception is logged and reported via {@code handleDataFetchingException}.
     * The field then resolves to null for every datum.
     * </p>
     */
    @SuppressWarnings("unchecked")
    private List<GraphQLExecutionNodeValue> fetchData(ExecutionContext executionContext, GraphQLObjectType parentType,
                                                      List<GraphQLExecutionNodeDatum> nodeData, List<Field> fields,
                                                      GraphQLFieldDefinition fieldDef,
                                                      Map<String, Object> argumentValues) {
        List<Object> sources = new ArrayList<>(nodeData.size());
        for (GraphQLExecutionNodeDatum datum : nodeData) {
            sources.add(datum.getSource());
        }

        GraphQLOutputType fieldType = fieldDef.getType();
        // Named selectionSet so it does not shadow the inherited fieldCollector used by getChildFields.
        DataFetchingFieldSelectionSet selectionSet =
                DataFetchingFieldSelectionSetImpl.newCollector(executionContext, fieldType, fields);

        DataFetchingEnvironment environment = new DataFetchingEnvironmentImpl(
                sources,
                argumentValues,
                executionContext.getRoot(),
                fields,
                fieldType,
                parentType,
                executionContext.getGraphQLSchema(),
                executionContext.getFragmentsByName(),
                executionContext.getExecutionId(),
                selectionSet);

        List<Object> values;
        try {
            values = (List<Object>) getDataFetcher(fieldDef).get(environment);
        } catch (Exception e) {
            // One null per source, so every datum below receives a (null) value. An empty list pre-sized with
            // ArrayList(int) would have size 0 and make values.get(i) throw IndexOutOfBoundsException.
            values = Collections.nCopies(nodeData.size(), null);
            log.warn("Exception while fetching data", e);
            handleDataFetchingException(executionContext, fieldDef, argumentValues, e);
        }

        assert nodeData.size() == values.size()
                : "Data fetcher for field '" + fieldDef.getName() + "' returned " + values.size()
                + " values for " + nodeData.size() + " sources";

        List<GraphQLExecutionNodeValue> retVal = new ArrayList<>(nodeData.size());
        for (int i = 0; i < nodeData.size(); i++) {
            retVal.add(new GraphQLExecutionNodeValue(nodeData.get(i), values.get(i)));
        }
        return retVal;
    }

    /** Returns the field's data fetcher as adapted by the {@link BatchedDataFetcherFactory}. */
    private BatchedDataFetcher getDataFetcher(GraphQLFieldDefinition fieldDef) {
        DataFetcher supplied = fieldDef.getDataFetcher();
        return batchingFactory.create(supplied);
    }
}