/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.search.aggregations.metrics.scripted; import org.apache.lucene.index.LeafReaderContext; import org.elasticsearch.script.ExecutableScript; import org.elasticsearch.script.LeafSearchScript; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.SearchScript; import org.elasticsearch.search.SearchParseException; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; import org.elasticsearch.search.aggregations.metrics.MetricsAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.support.AggregationContext; import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; public class ScriptedMetricAggregator extends MetricsAggregator { private final SearchScript mapScript; private final ExecutableScript combineScript; private final Script reduceScript; private Map<String, Object> params; protected ScriptedMetricAggregator(String name, Script initScript, Script mapScript, Script combineScript, Script reduceScript, Map<String, Object> params, AggregationContext context, Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException { super(name, context, parent, pipelineAggregators, metaData); this.params = params; ScriptService scriptService = context.searchContext().scriptService(); if (initScript != null) { scriptService.executable(initScript, ScriptContext.Standard.AGGS, context.searchContext(), Collections.<String, String>emptyMap()).run(); } this.mapScript = scriptService.search(context.searchContext().lookup(), mapScript, ScriptContext.Standard.AGGS, Collections.<String, String>emptyMap()); if (combineScript != null) { this.combineScript = scriptService.executable(combineScript, ScriptContext.Standard.AGGS, context.searchContext(), Collections.<String, String>emptyMap()); } else { this.combineScript = null; } this.reduceScript = reduceScript; } @Override public boolean needsScores() { return true; // TODO: how can we know if the script relies on scores? } @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { final LeafSearchScript leafMapScript = mapScript.getLeafSearchScript(ctx); return new LeafBucketCollectorBase(sub, mapScript) { @Override public void collect(int doc, long bucket) throws IOException { assert bucket == 0 : bucket; leafMapScript.setDocument(doc); leafMapScript.run(); } }; } @Override public InternalAggregation buildAggregation(long owningBucketOrdinal) { Object aggregation; if (combineScript != null) { aggregation = combineScript.run(); } else { aggregation = params.get("_agg"); } return new InternalScriptedMetric(name, aggregation, reduceScript, pipelineAggregators(), metaData()); } @Override public InternalAggregation buildEmptyAggregation() { return new InternalScriptedMetric(name, null, reduceScript, pipelineAggregators(), metaData()); } public static class Factory extends AggregatorFactory { private Script initScript; private Script mapScript; private Script combineScript; private Script reduceScript; private Map<String, Object> params; public Factory(String name, Script initScript, Script mapScript, Script combineScript, Script reduceScript, Map<String, Object> params) { super(name, InternalScriptedMetric.TYPE.name()); this.initScript = initScript; this.mapScript = mapScript; this.combineScript = combineScript; this.reduceScript = reduceScript; this.params = params; } @Override public Aggregator createInternal(AggregationContext context, Aggregator parent, boolean collectsFromSingleBucket, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException { if (collectsFromSingleBucket == false) { return asMultiBucketAggregator(this, context, parent); } Map<String, Object> params = this.params; if (params != null) { params = deepCopyParams(params, context.searchContext()); } else { params = new HashMap<>(); params.put("_agg", new HashMap<String, Object>()); } return new ScriptedMetricAggregator(name, insertParams(initScript, params), insertParams(mapScript, params), insertParams( combineScript, params), deepCopyScript(reduceScript, context.searchContext()), params, context, parent, pipelineAggregators, metaData); } private static Script insertParams(Script script, Map<String, Object> params) { if (script == null) { return null; } return new Script(script.getScript(), script.getType(), script.getLang(), params); } private static Script deepCopyScript(Script script, SearchContext context) { if (script != null) { Map<String, Object> params = script.getParams(); if (params != null) { params = deepCopyParams(params, context); } return new Script(script.getScript(), script.getType(), script.getLang(), params); } else { return null; } } @SuppressWarnings({ "unchecked" }) private static <T> T deepCopyParams(T original, SearchContext context) { T clone; if (original instanceof Map) { Map<?, ?> originalMap = (Map<?, ?>) original; Map<Object, Object> clonedMap = new HashMap<>(); for (Entry<?, ?> e : originalMap.entrySet()) { clonedMap.put(deepCopyParams(e.getKey(), context), deepCopyParams(e.getValue(), context)); } clone = (T) clonedMap; } else if (original instanceof List) { List<?> originalList = (List<?>) original; List<Object> clonedList = new ArrayList<Object>(); for (Object o : originalList) { clonedList.add(deepCopyParams(o, context)); } clone = (T) clonedList; } else if (original instanceof String || original instanceof Integer || original instanceof Long || original instanceof Short || original instanceof Byte || original instanceof Float || original instanceof Double || original instanceof Character || original instanceof Boolean) { clone = original; } else { throw new SearchParseException(context, "Can only clone primitives, String, ArrayList, and HashMap. Found: " + original.getClass().getCanonicalName(), null); } return clone; } } }