/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.hive.ql.exec; import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.hadoop.hive.ql.exec.NodeUtils.Function; import org.apache.hadoop.hive.ql.plan.MapJoinDesc; import org.apache.hadoop.hive.ql.plan.OperatorDesc; import org.apache.hadoop.mapred.OutputCollector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; public class OperatorUtils { private static final Logger LOG = LoggerFactory.getLogger(OperatorUtils.class); public static <T> Set<T> findOperators(Operator<?> start, Class<T> clazz) { return findOperators(start, clazz, new HashSet<T>()); } public static <T> T findSingleOperator(Operator<?> start, Class<T> clazz) { Set<T> found = findOperators(start, clazz, new HashSet<T>()); return found.size() == 1 ? found.iterator().next() : null; } public static <T> Set<T> findOperators(Collection<Operator<?>> starts, Class<T> clazz) { Set<T> found = new HashSet<T>(); for (Operator<?> start : starts) { if (start == null) { continue; } findOperators(start, clazz, found); } return found; } @SuppressWarnings("unchecked") private static <T> Set<T> findOperators(Operator<?> start, Class<T> clazz, Set<T> found) { if (clazz.isInstance(start)) { found.add((T) start); } if (start.getChildOperators() != null) { for (Operator<?> child : start.getChildOperators()) { findOperators(child, clazz, found); } } return found; } public static <T> Set<T> findOperatorsUpstream(Operator<?> start, Class<T> clazz) { return findOperatorsUpstream(start, clazz, new HashSet<T>()); } public static <T> T findSingleOperatorUpstream(Operator<?> start, Class<T> clazz) { Set<T> found = findOperatorsUpstream(start, clazz, new HashSet<T>()); return found.size() == 1 ? found.iterator().next() : null; } public static <T> T findSingleOperatorUpstreamJoinAccounted(Operator<?> start, Class<T> clazz) { Set<T> found = findOperatorsUpstreamJoinAccounted(start, clazz, new HashSet<T>()); return found.size() >= 1 ? found.iterator().next(): null; } public static <T> Set<T> findOperatorsUpstream(Collection<Operator<?>> starts, Class<T> clazz) { Set<T> found = new HashSet<T>(); for (Operator<?> start : starts) { findOperatorsUpstream(start, clazz, found); } return found; } @SuppressWarnings("unchecked") private static <T> Set<T> findOperatorsUpstream(Operator<?> start, Class<T> clazz, Set<T> found) { if (clazz.isInstance(start)) { found.add((T) start); } if (start.getParentOperators() != null) { for (Operator<?> parent : start.getParentOperators()) { findOperatorsUpstream(parent, clazz, found); } } return found; } public static <T> Set<T> findOperatorsUpstreamJoinAccounted(Operator<?> start, Class<T> clazz, Set<T> found) { if (clazz.isInstance(start)) { found.add((T) start); } int onlyIncludeIndex = -1; if (start instanceof AbstractMapJoinOperator) { AbstractMapJoinOperator mapJoinOp = (AbstractMapJoinOperator) start; MapJoinDesc desc = (MapJoinDesc) mapJoinOp.getConf(); onlyIncludeIndex = desc.getPosBigTable(); } if (start.getParentOperators() != null) { int i = 0; for (Operator<?> parent : start.getParentOperators()) { if (onlyIncludeIndex >= 0) { if (onlyIncludeIndex == i) { findOperatorsUpstream(parent, clazz, found); } } else { findOperatorsUpstream(parent, clazz, found); } i++; } } return found; } public static void setChildrenCollector(List<Operator<? extends OperatorDesc>> childOperators, OutputCollector out) { if (childOperators == null) { return; } for (Operator<? extends OperatorDesc> op : childOperators) { if (op.getName().equals(ReduceSinkOperator.getOperatorName())) { op.setOutputCollector(out); } else { setChildrenCollector(op.getChildOperators(), out); } } } public static void setChildrenCollector(List<Operator<? extends OperatorDesc>> childOperators, Map<String, OutputCollector> outMap) { if (childOperators == null) { return; } for (Operator<? extends OperatorDesc> op : childOperators) { if (op.getIsReduceSink()) { String outputName = op.getReduceOutputName(); if (outMap.containsKey(outputName)) { LOG.info("Setting output collector: " + op + " --> " + outputName); op.setOutputCollector(outMap.get(outputName)); } } else { setChildrenCollector(op.getChildOperators(), outMap); } } } /** * Starting at the input operator, finds the last operator in the stream that * is an instance of the input class. * * @param op the starting operator * @param clazz the class that the operator that we are looking for instantiates * @return null if no such operator exists or multiple branches are found in * the stream, the last operator otherwise */ @SuppressWarnings("unchecked") public static <T> T findLastOperator(Operator<?> op, Class<T> clazz) { Operator<?> currentOp = op; T lastOp = null; while (currentOp != null) { if (clazz.isInstance(currentOp)) { lastOp = (T) currentOp; } if (currentOp.getChildOperators().size() == 1) { currentOp = currentOp.getChildOperators().get(0); } else { currentOp = null; } } return lastOp; } /** * Starting at the input operator, finds the last operator upstream that is * an instance of the input class. * * @param op the starting operator * @param clazz the class that the operator that we are looking for instantiates * @return null if no such operator exists or multiple branches are found in * the stream, the last operator otherwise */ @SuppressWarnings("unchecked") public static <T> T findLastOperatorUpstream(Operator<?> op, Class<T> clazz) { Operator<?> currentOp = op; T lastOp = null; while (currentOp != null) { if (clazz.isInstance(currentOp)) { lastOp = (T) currentOp; } if (currentOp.getParentOperators().size() == 1) { currentOp = currentOp.getParentOperators().get(0); } else { currentOp = null; } } return lastOp; } public static void iterateParents(Operator<?> operator, Function<Operator<?>> function) { iterateParents(operator, function, new HashSet<Operator<?>>()); } private static void iterateParents(Operator<?> operator, Function<Operator<?>> function, Set<Operator<?>> visited) { if (!visited.add(operator)) { return; } function.apply(operator); if (operator.getNumParent() > 0) { for (Operator<?> parent : operator.getParentOperators()) { iterateParents(parent, function, visited); } } } public static boolean sameRowSchema(Operator<?> operator1, Operator<?> operator2) { return operator1.getSchema().equals(operator2.getSchema()); } /** * Given an operator and a set of classes, it classifies the operators it finds * in the stream depending on the classes they instantiate. * * If a given operator object is an instance of more than one of the input classes, * e.g. the operator instantiates one of the classes in the input set that is a * subclass of another class in the set, the operator will be associated to both * classes in the output map. * * @param start the start operator * @param classes the set of classes * @return a multimap from each of the classes to the operators that instantiate * them */ public static Multimap<Class<? extends Operator<?>>, Operator<?>> classifyOperators( Operator<?> start, Set<Class<? extends Operator<?>>> classes) { ImmutableMultimap.Builder<Class<? extends Operator<?>>, Operator<?>> resultMap = new ImmutableMultimap.Builder<Class<? extends Operator<?>>, Operator<?>>(); List<Operator<?>> ops = new ArrayList<Operator<?>>(); ops.add(start); while (!ops.isEmpty()) { List<Operator<?>> allChildren = new ArrayList<Operator<?>>(); for (Operator<?> op: ops) { for (Class<? extends Operator<?>> clazz: classes) { if (clazz.isInstance(op)) { resultMap.put(clazz, op); } } allChildren.addAll(op.getChildOperators()); } ops = allChildren; } return resultMap.build(); } /** * Given an operator and a set of classes, it classifies the operators it finds * upstream depending on the classes it instantiates. * * If a given operator object is an instance of more than one of the input classes, * e.g. the operator instantiates one of the classes in the input set that is a * subclass of another class in the set, the operator will be associated to both * classes in the output map. * * @param start the start operator * @param classes the set of classes * @return a multimap from each of the classes to the operators that instantiate * them */ public static Multimap<Class<? extends Operator<?>>, Operator<?>> classifyOperatorsUpstream( Operator<?> start, Set<Class<? extends Operator<?>>> classes) { ImmutableMultimap.Builder<Class<? extends Operator<?>>, Operator<?>> resultMap = new ImmutableMultimap.Builder<Class<? extends Operator<?>>, Operator<?>>(); List<Operator<?>> ops = new ArrayList<Operator<?>>(); ops.add(start); while (!ops.isEmpty()) { List<Operator<?>> allParent = new ArrayList<Operator<?>>(); for (Operator<?> op: ops) { for (Class<? extends Operator<?>> clazz: classes) { if (clazz.isInstance(op)) { resultMap.put(clazz, op); } } if (op.getParentOperators() != null) { allParent.addAll(op.getParentOperators()); } } ops = allParent; } return resultMap.build(); } /** * Given an operator and a set of classes, it returns the number of operators it finds * upstream that instantiate any of the given classes. * * @param start the start operator * @param classes the set of classes * @return the number of operators */ public static int countOperatorsUpstream(Operator<?> start, Set<Class<? extends Operator<?>>> classes) { Multimap<Class<? extends Operator<?>>, Operator<?>> ops = classifyOperatorsUpstream(start, classes); int numberOperators = 0; Set<Operator<?>> uniqueOperators = new HashSet<Operator<?>>(); for (Operator<?> op : ops.values()) { if (uniqueOperators.add(op)) { numberOperators++; } } return numberOperators; } public static void setMemoryAvailable(final List<Operator<? extends OperatorDesc>> operators, final long memoryAvailableToTask) { if (operators == null) { return; } for (Operator<? extends OperatorDesc> op : operators) { if (op.getConf() != null) { op.getConf().setMaxMemoryAvailable(memoryAvailableToTask); } if (op.getChildOperators() != null && !op.getChildOperators().isEmpty()) { setMemoryAvailable(op.getChildOperators(), memoryAvailableToTask); } } } /** * Given the input operator 'op', walk up the operator tree from 'op', and collect all the * roots that can be reached from it. The results are stored in 'roots'. */ public static void findRoots(Operator<?> op, Collection<Operator<?>> roots) { List<Operator<?>> parents = op.getParentOperators(); if (parents == null || parents.isEmpty()) { roots.add(op); return; } for (Operator<?> p : parents) { findRoots(p, roots); } } /** * Remove the branch that contains the specified operator. Do nothing if there's no branching, * i.e. all the upstream operators have only one child. */ public static void removeBranch(Operator<?> op) { Operator<?> child = op; Operator<?> curr = op; while (curr.getChildOperators().size() <= 1) { child = curr; if (curr.getParentOperators() == null || curr.getParentOperators().isEmpty()) { return; } curr = curr.getParentOperators().get(0); } curr.removeChild(child); } public static String getOpNamePretty(Operator<?> op) { if (op instanceof TableScanOperator) { return op.toString() + " (" + ((TableScanOperator) op).getConf().getAlias() + ")"; } return op.toString(); } }