/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hadoop.hive.ql.optimizer.spark;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
import org.apache.hadoop.hive.ql.optimizer.OperatorComparatorFactory;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
import org.apache.hadoop.hive.ql.plan.SparkWork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* CombineEquivalentWorkResolver searches each SparkTask's SparkWork graph, finds works that are
* equivalent (same operator plan, same input paths and partitions, same parent works), and
* combines every such group into a single work so the same computation is not executed twice.
* For example, two identical MapWorks that feed the same ReduceWork are merged into one MapWork.
*/
public class CombineEquivalentWorkResolver implements PhysicalPlanResolver {
private static final Logger LOG = LoggerFactory.getLogger(CombineEquivalentWorkResolver.class);
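
/**
* Walk all root tasks with a TaskGraphWalker; the EquivalentWorkMatcher dispatcher does the
* actual matching and combining for every SparkTask it encounters.
*/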
@Override
public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
List<Node> topNodes = new ArrayList<Node>();
topNodes.addAll(pctx.getRootTasks());
TaskGraphWalker taskWalker = new TaskGraphWalker(new EquivalentWorkMatcher());
HashMap<Node, Object> nodeOutput = Maps.newHashMap();
taskWalker.startWalking(topNodes, nodeOutput);
return pctx;
}
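
/**
* Dispatcher that, for each SparkTask it visits, walks the task's SparkWork graph level by level
* starting from the root works and merges works found to be equivalent.
*/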
class EquivalentWorkMatcher implements Dispatcher {
private Comparator<BaseWork> baseWorkComparator = new Comparator<BaseWork>() {
@Override
public int compare(BaseWork o1, BaseWork o2) {
return o1.getName().compareTo(o2.getName());
}
};
@Override
public Object dispatch(Node nd, Stack<Node> stack, Object... nodeOutputs) throws SemanticException {
if (nd instanceof SparkTask) {
SparkTask sparkTask = (SparkTask) nd;
SparkWork sparkWork = sparkTask.getWork();
Set<BaseWork> roots = sparkWork.getRoots();
compareWorksRecursively(roots, sparkWork);
}
return null;
}
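
/**
* Combine equivalent works within the given set, then recurse into the children of the works
* that remain in the graph after the merge.
*/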
private void compareWorksRecursively(Set<BaseWork> works, SparkWork sparkWork) {
// find out all equivalent works in the Set.
Set<Set<BaseWork>> equivalentWorks = compareChildWorks(works, sparkWork);
// combine each group of equivalent works into a single work in SparkWork's work graph.
Set<BaseWork> removedWorks = combineEquivalentWorks(equivalentWorks, sparkWork);
// recursively try to combine the next level of works.
for (BaseWork work : works) {
if (!removedWorks.contains(work)) {
Set<BaseWork> children = Sets.newHashSet();
children.addAll(sparkWork.getChildren(work));
if (children.size() > 0) {
compareWorksRecursively(children, sparkWork);
}
}
}
}
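
/**
* Partition the given works into groups of mutually equivalent works. Each group is kept in a
* TreeSet ordered by work name so that the resulting plan is deterministic across runs.
*/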
private Set<Set<BaseWork>> compareChildWorks(Set<BaseWork> children, SparkWork sparkWork) {
Set<Set<BaseWork>> equivalentChildren = Sets.newHashSet();
if (children.size() > 1) {
for (BaseWork work : children) {
boolean assigned = false;
for (Set<BaseWork> set : equivalentChildren) {
if (belongToSet(set, work, sparkWork)) {
set.add(work);
assigned = true;
break;
}
}
if (!assigned) {
// sort the works so that we get a consistent query plan across executions (for test verification).
Set<BaseWork> newSet = Sets.newTreeSet(baseWorkComparator);
newSet.add(work);
equivalentChildren.add(newSet);
}
}
}
return equivalentChildren;
}
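
/**
* Return true if the given work is equivalent to the works already in the set; an empty set
* accepts any work.
*/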
private boolean belongToSet(Set<BaseWork> set, BaseWork work, SparkWork sparkWork) {
return set.isEmpty() || compareWork(set.iterator().next(), work, sparkWork);
}
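
/**
* For each group of equivalent works, keep the first work and merge the others into it.
* Returns the set of works that were removed from the SparkWork graph.
*/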
private Set<BaseWork> combineEquivalentWorks(Set<Set<BaseWork>> equivalentWorks, SparkWork sparkWork) {
Set<BaseWork> removedWorks = Sets.newHashSet();
for (Set<BaseWork> workSet : equivalentWorks) {
if (workSet.size() > 1) {
Iterator<BaseWork> iterator = workSet.iterator();
BaseWork first = iterator.next();
while (iterator.hasNext()) {
BaseWork next = iterator.next();
replaceWork(next, first, sparkWork);
removedWorks.add(next);
}
}
}
return removedWorks;
}
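
/**
* Remove {@code previous} from the SparkWork graph, reconnecting its children to
* {@code current} with the original edge properties and updating operator references to it.
*/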
private void replaceWork(BaseWork previous, BaseWork current, SparkWork sparkWork) {
updateReference(previous, current, sparkWork);
List<BaseWork> parents = sparkWork.getParents(previous);
List<BaseWork> children = sparkWork.getChildren(previous);
if (parents != null) {
for (BaseWork parent : parents) {
// no need to connect the parent to the counterpart work, since equivalent works already share the same parents.
sparkWork.disconnect(parent, previous);
}
}
if (children != null) {
for (BaseWork child : children) {
SparkEdgeProperty edgeProperty = sparkWork.getEdgeProperty(previous, child);
sparkWork.disconnect(previous, child);
sparkWork.connect(current, child, edgeProperty);
}
}
sparkWork.remove(previous);
}

/*
* Update the work name referred to by operators in other works, e.g. the parent-to-input
* mapping kept by MapJoin operators.
*/
private void updateReference(BaseWork previous, BaseWork current, SparkWork sparkWork) {
String previousName = previous.getName();
String currentName = current.getName();
List<BaseWork> children = sparkWork.getAllWork();
for (BaseWork child : children) {
Set<Operator<?>> allOperators = child.getAllOperators();
for (Operator<?> operator : allOperators) {
if (operator instanceof MapJoinOperator) {
MapJoinDesc mapJoinDesc = ((MapJoinOperator) operator).getConf();
Map<Integer, String> parentToInput = mapJoinDesc.getParentToInput();
for (Map.Entry<Integer, String> entry : parentToInput.entrySet()) {
if (entry.getValue().equals(previousName)) {
entry.setValue(currentName);
}
}
}
}
}
}
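
/**
* Two works are equivalent when they are of the same class, have the same parents, are not both
* leaves, read the same paths and partitions (for MapWork), and their root operator chains
* compare equal pairwise in iteration order.
*/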
private boolean compareWork(BaseWork first, BaseWork second, SparkWork sparkWork) {
if (!first.getClass().getName().equals(second.getClass().getName())) {
return false;
}
if (!hasSameParent(first, second, sparkWork)) {
return false;
}
// A leaf work's output may be read by a subsequent SparkWork/FetchWork, so we should not combine
// leaf works without updating that SparkWork/FetchWork accordingly.
if (sparkWork.getLeaves().contains(first) && sparkWork.getLeaves().contains(second)) {
return false;
}
// need to check paths and partition desc for MapWorks
if (first instanceof MapWork && !compareMapWork((MapWork) first, (MapWork) second)) {
return false;
}
Set<Operator<?>> firstRootOperators = first.getAllRootOperators();
Set<Operator<?>> secondRootOperators = second.getAllRootOperators();
if (firstRootOperators.size() != secondRootOperators.size()) {
return false;
}
Iterator<Operator<?>> firstIterator = firstRootOperators.iterator();
Iterator<Operator<?>> secondIterator = secondRootOperators.iterator();
while (firstIterator.hasNext()) {
boolean result = compareOperatorChain(firstIterator.next(), secondIterator.next());
if (!result) {
return result;
}
}
return true;
}
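
/**
* Return true if both MapWorks read exactly the same paths with equal partition descriptors.
*/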
private boolean compareMapWork(MapWork first, MapWork second) {
Map<Path, PartitionDesc> pathToPartition1 = first.getPathToPartitionInfo();
Map<Path, PartitionDesc> pathToPartition2 = second.getPathToPartitionInfo();
if (pathToPartition1.size() == pathToPartition2.size()) {
for (Map.Entry<Path, PartitionDesc> entry : pathToPartition1.entrySet()) {
Path path1 = entry.getKey();
PartitionDesc partitionDesc1 = entry.getValue();
PartitionDesc partitionDesc2 = pathToPartition2.get(path1);
if (!partitionDesc1.equals(partitionDesc2)) {
return false;
}
}
return true;
}
return false;
}
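
/**
* Return true if both works have exactly the same set of parent works in the SparkWork graph.
*/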
private boolean hasSameParent(BaseWork first, BaseWork second, SparkWork sparkWork) {
boolean result = true;
List<BaseWork> firstParents = sparkWork.getParents(first);
List<BaseWork> secondParents = sparkWork.getParents(second);
if (firstParents.size() != secondParents.size()) {
result = false;
}
for (BaseWork parent : firstParents) {
if (!secondParents.contains(parent)) {
result = false;
break;
}
}
return result;
}
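
/**
* Recursively compare two operator trees: the operators themselves, the number of children, and
* each pair of children at the same position.
*/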
private boolean compareOperatorChain(Operator<?> firstOperator, Operator<?> secondOperator) {
boolean result = compareCurrentOperator(firstOperator, secondOperator);
if (!result) {
return result;
}
List<Operator<? extends OperatorDesc>> firstOperatorChildOperators = firstOperator.getChildOperators();
List<Operator<? extends OperatorDesc>> secondOperatorChildOperators = secondOperator.getChildOperators();
if (firstOperatorChildOperators == null && secondOperatorChildOperators != null) {
return false;
} else if (firstOperatorChildOperators != null && secondOperatorChildOperators == null) {
return false;
} else if (firstOperatorChildOperators != null && secondOperatorChildOperators != null) {
if (firstOperatorChildOperators.size() != secondOperatorChildOperators.size()) {
return false;
}
int size = firstOperatorChildOperators.size();
for (int i = 0; i < size; i++) {
result = compareOperatorChain(firstOperatorChildOperators.get(i), secondOperatorChildOperators.get(i));
if (!result) {
return false;
}
}
}
return true;
}

/**
* Compare two operators using the comparator obtained from OperatorComparatorFactory for the
* operator's class. Operators of different classes are never considered equal.
*
* @param firstOperator  first operator to compare
* @param secondOperator second operator to compare
* @return true if the two operators are considered equivalent
*/
private boolean compareCurrentOperator(Operator<?> firstOperator, Operator<?> secondOperator) {
if (!firstOperator.getClass().getName().equals(secondOperator.getClass().getName())) {
return false;
}
OperatorComparatorFactory.OperatorComparator operatorComparator =
OperatorComparatorFactory.getOperatorComparator(firstOperator.getClass());
return operatorComparator.equals(firstOperator, secondOperator);
}
}
}