/* * Copyright (c) 2013, 2014, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. 
*/
package org.graalvm.compiler.truffle;

import static org.graalvm.compiler.truffle.TruffleCompilerOptions.TruffleFunctionInlining;
import static org.graalvm.compiler.truffle.TruffleCompilerOptions.TruffleInliningMaxCallerSize;
import static org.graalvm.compiler.truffle.TruffleCompilerOptions.TruffleMaximumRecursiveInlining;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;

import com.oracle.truffle.api.CallTarget;
import com.oracle.truffle.api.CompilerOptions;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.NodeUtil;
import com.oracle.truffle.api.nodes.NodeVisitor;

/**
 * A tree of inlining decisions for a call target. Each element of {@link #getCallSites()} is a
 * {@link TruffleInliningDecision} for one {@link OptimizedDirectCallNode} found in the target's
 * AST; a decision is itself a {@link TruffleInlining} holding the decisions for the call sites of
 * its own target.
 */
public class TruffleInlining implements Iterable<TruffleInliningDecision> {

    /**
     * Multiplier applied to the {@code TruffleInliningMaxCallerSize} option to bound the total
     * number of call sites explored while building one decision tree.
     */
    private static final int MAX_EXPLORED_CALL_SITES_FACTOR = 100;

    /**
     * Exploration stops descending once the call stack (including the compiled root target)
     * reaches this many entries.
     */
    private static final int MAX_EXPLORATION_STACK_SIZE = 15;

    private final List<TruffleInliningDecision> callSites;

    /**
     * Wraps an already computed list of call-site decisions. The list is stored as-is, not copied.
     */
    protected TruffleInlining(List<TruffleInliningDecision> callSites) {
        this.callSites = callSites;
    }

    /**
     * Builds the inlining decision tree for {@code sourceTarget} using {@code policy}. The tree is
     * empty when the {@code TruffleFunctionInlining} option is disabled.
     */
    public TruffleInlining(OptimizedCallTarget sourceTarget, TruffleInliningPolicy policy) {
        this(createDecisions(sourceTarget, policy, sourceTarget.getRootNode().getCompilerOptions()));
    }

    private static List<TruffleInliningDecision> createDecisions(OptimizedCallTarget sourceTarget, TruffleInliningPolicy policy, CompilerOptions options) {
        if (!TruffleCompilerOptions.getValue(TruffleFunctionInlining)) {
            return Collections.emptyList();
        }
        // Single-element array used as a mutable counter shared by the whole recursive exploration.
        int[] visitedNodes = {0};
        int nodeCount = sourceTarget.getNonTrivialNodeCount();
        List<TruffleInliningDecision> exploredCallSites = exploreCallSites(new ArrayList<>(Arrays.asList(sourceTarget)), nodeCount, policy, visitedNodes, new HashMap<>());
        return decideInlining(exploredCallSites, policy, nodeCount, options);
    }

    /**
     * Explores every direct call site of the target on top of {@code stack} and returns one
     * decision per call site, in the order in which the call nodes were found.
     *
     * <p>
     * Decisions rejected by {@code policy} are memoized per call target in
     * {@code rejectedDecisionsCache} while this invocation (and any nested exploration) runs, so
     * that further call sites of the same target reuse the already explored subtree. Entries added
     * here are removed again before returning.
     *
     * @param stack the current call stack; its last element is the target whose call sites are
     *            explored. Entries are pushed and popped temporarily; the stack is unchanged on
     *            return.
     * @param callStackNodeCount sum of the node counts of the targets on {@code stack}
     * @param visitedNodes single-element counter of explored call sites, shared across recursion
     * @param rejectedDecisionsCache memo of rejected decisions keyed by call target
     * @return the explored decisions (not yet marked inline)
     */
    private static List<TruffleInliningDecision> exploreCallSites(List<OptimizedCallTarget> stack, int callStackNodeCount, TruffleInliningPolicy policy, int[] visitedNodes,
                    Map<OptimizedCallTarget, TruffleInliningDecision> rejectedDecisionsCache) {
        List<TruffleInliningDecision> exploredCallSites = new ArrayList<>();
        List<OptimizedCallTarget> toRemoveFromCache = new LinkedList<>();
        OptimizedCallTarget parentTarget = stack.get(stack.size() - 1);
        for (OptimizedDirectCallNode callNode : getCallNodes(parentTarget)) {
            OptimizedCallTarget currentTarget = callNode.getCurrentCallTarget();
            stack.add(currentTarget); // push
            TruffleInliningDecision decision = rejectedDecisionsCache.get(currentTarget);
            if (decision == null) {
                // Cache miss: explore the call site for real.
                decision = exploreCallSite(stack, callStackNodeCount, policy, callNode, visitedNodes, rejectedDecisionsCache);
                if (!policy.isAllowed(decision.getProfile(), callStackNodeCount, callNode.getRootNode().getCompilerOptions())) {
                    rejectedDecisionsCache.put(currentTarget, decision);
                    toRemoveFromCache.add(currentTarget);
                }
            } else {
                // Cache hit: reuse the explored subtree with a profile bound to this call node.
                TruffleInliningProfile cachedProfile = decision.getProfile();
                TruffleInliningProfile newProfile = new TruffleInliningProfile(callNode, cachedProfile.getNodeCount(), cachedProfile.getDeepNodeCount(), cachedProfile.getFrequency(),
                                cachedProfile.getRecursions());
                newProfile.setCached(true);
                /*
                 * Like every profile built in exploreCallSite, the copy needs a score: the caller
                 * sorts decisions before querying the policy in decideInlining.
                 */
                newProfile.setScore(policy.calculateScore(newProfile));
                /*
                 * Give the copy its own child list. exploreCallSite clears the child list of every
                 * non-inlined decision; a shared list would also wipe the children of an inlined
                 * copy (or of the inlined original) of the same cached decision.
                 */
                decision = new TruffleInliningDecision(decision.getTarget(), newProfile, new ArrayList<>(decision.getCallSites()));
            }
            exploredCallSites.add(decision);
            stack.remove(stack.size() - 1); // pop
        }
        for (OptimizedCallTarget target : toRemoveFromCache) {
            rejectedDecisionsCache.remove(target);
        }
        return exploredCallSites;
    }

    /** Collects all {@link OptimizedDirectCallNode}s reachable from the target's root node. */
    private static List<OptimizedDirectCallNode> getCallNodes(OptimizedCallTarget target) {
        final List<OptimizedDirectCallNode> callNodes = new ArrayList<>();
        target.getRootNode().accept(new NodeVisitor() {
            @Override
            public boolean visit(Node node) {
                if (node instanceof OptimizedDirectCallNode) {
                    callNodes.add((OptimizedDirectCallNode) node);
                }
                return true;
            }
        });
        return callNodes;
    }

    /**
     * Builds the decision for a single call site whose target is on top of {@code callStack}, with
     * the caller directly below it.
     *
     * <p>
     * The callee's own call sites are explored only while the global exploration budget, the
     * maximum stack size and the recursion limit are respected, and only if the policy would
     * allow the call site under an optimistic profile (deep node count equal to the callee's own
     * node count). Child decisions that are not inlined have their call-site lists cleared.
     */
    private static TruffleInliningDecision exploreCallSite(List<OptimizedCallTarget> callStack, int callStackNodeCount, TruffleInliningPolicy policy, OptimizedDirectCallNode callNode,
                    int[] visitedNodes, Map<OptimizedCallTarget, TruffleInliningDecision> rejectedDecisionsCache) {
        OptimizedCallTarget parentTarget = callStack.get(callStack.size() - 2);
        OptimizedCallTarget currentTarget = callStack.get(callStack.size() - 1);
        List<TruffleInliningDecision> childCallSites = Collections.emptyList();
        double frequency = calculateFrequency(parentTarget, callNode);
        int nodeCount = callNode.getCurrentCallTarget().getNonTrivialNodeCount();
        int recursions = countRecursions(callStack);
        int deepNodeCount = nodeCount;
        if (++visitedNodes[0] < (MAX_EXPLORED_CALL_SITES_FACTOR * TruffleCompilerOptions.getValue(TruffleInliningMaxCallerSize)) && callStack.size() < MAX_EXPLORATION_STACK_SIZE &&
                        recursions <= TruffleCompilerOptions.getValue(TruffleMaximumRecursiveInlining)) {
            /*
             * We make a preliminary optimistic inlining decision with best possible characteristics
             * to avoid the exploration of unnecessary paths in the inlining tree.
             */
            final CompilerOptions options = callNode.getRootNode().getCompilerOptions();
            if (policy.isAllowed(new TruffleInliningProfile(callNode, nodeCount, nodeCount, frequency, recursions), callStackNodeCount, options)) {
                List<TruffleInliningDecision> exploredCallSites = exploreCallSites(callStack, callStackNodeCount + nodeCount, policy, visitedNodes, rejectedDecisionsCache);
                childCallSites = decideInlining(exploredCallSites, policy, nodeCount, options);
                for (TruffleInliningDecision childCallSite : childCallSites) {
                    if (childCallSite.isInline()) {
                        deepNodeCount += childCallSite.getProfile().getDeepNodeCount();
                    } else {
                        /* we don't need those anymore. */
                        childCallSite.getCallSites().clear();
                    }
                }
            }
        }

        TruffleInliningProfile profile = new TruffleInliningProfile(callNode, nodeCount, deepNodeCount, frequency, recursions);
        profile.setScore(policy.calculateScore(profile));
        return new TruffleInliningDecision(currentTarget, profile, childCallSites);
    }

    /**
     * Ratio of the call node's call count to the caller's interpreter call count, with both
     * clamped to at least 1 to avoid division by zero.
     */
    private static double calculateFrequency(OptimizedCallTarget target, OptimizedDirectCallNode ocn) {
        return (double) Math.max(1, ocn.getCallCount()) / (double) Math.max(1, target.getCompilationProfile().getInterpreterCallCount());
    }

    /** Counts how often the top-of-stack target already occurs further down the stack. */
    private static int countRecursions(List<OptimizedCallTarget> stack) {
        int count = 0;
        OptimizedCallTarget top = stack.get(stack.size() - 1);
        for (int i = 0; i < stack.size() - 1; i++) {
            if (stack.get(i) == top) {
                count++;
            }
        }
        return count;
    }

    /**
     * Sorts {@code callSites} in place, records each profile's query index, and marks as inline
     * every call site the policy allows given the node count accumulated so far (starting at
     * {@code nodeCount} and growing by the deep node count of each inlined site).
     *
     * @return the same (now sorted) list
     */
    private static List<TruffleInliningDecision> decideInlining(List<TruffleInliningDecision> callSites, TruffleInliningPolicy policy, int nodeCount, CompilerOptions options) {
        int deepNodeCount = nodeCount;
        int index = 0;
        /* First sort the call sites. */
        Collections.sort(callSites);
        for (TruffleInliningDecision callSite : callSites) {
            TruffleInliningProfile profile = callSite.getProfile();
            profile.setQueryIndex(index++);
            if (policy.isAllowed(profile, deepNodeCount, options)) {
                callSite.setInline(true);
                deepNodeCount += profile.getDeepNodeCount();
            }
        }
        return callSites;
    }

    /** Sums the deep node counts of the direct call sites that are inlined. */
    public int getInlinedNodeCount() {
        int sum = 0;
        for (TruffleInliningDecision callSite : getCallSites()) {
            if (callSite.isInline()) {
                sum += callSite.getProfile().getDeepNodeCount();
            }
        }
        return sum;
    }

    /**
     * Counts call sites in this tree. Every direct call site counts once; inlined call sites
     * additionally contribute the calls of their own subtree.
     */
    public int countCalls() {
        int sum = 0;
        for (TruffleInliningDecision callSite : getCallSites()) {
            sum += callSite.isInline() ? callSite.countCalls() + 1 : 1;
        }
        return sum;
    }

    /** Counts the inlined call sites in this tree, recursing into inlined subtrees. */
    public int countInlinedCalls() {
        int sum = 0;
        for (TruffleInliningDecision callSite : getCallSites()) {
            if (callSite.isInline()) {
                sum += callSite.countInlinedCalls() + 1;
            }
        }
        return sum;
    }

    /** Returns the (mutable, not copied) list of direct call-site decisions. */
    public final List<TruffleInliningDecision> getCallSites() {
        return callSites;
    }

    /** Iterates over the direct call-site decisions; backed by the underlying list. */
    @Override
    public Iterator<TruffleInliningDecision> iterator() {
        return callSites.iterator();
    }

    /**
     * Returns the direct call-site decision whose profile refers to exactly {@code callNode}
     * (identity comparison), or {@code null} if there is none. Nested decisions are not searched.
     */
    public TruffleInliningDecision findByCall(OptimizedDirectCallNode callNode) {
        for (TruffleInliningDecision d : getCallSites()) {
            if (d.getProfile().getCallNode() == callNode) {
                return d;
            }
        }
        return null;
    }

    /**
     * Visits all nodes of the {@link CallTarget} and all of its inlined calls.
     *
     * <p>
     * Note: a call node for which this tree holds no matching decision is not passed to
     * {@code visitor}.
     */
    public void accept(OptimizedCallTarget target, NodeVisitor visitor) {
        target.getRootNode().accept(new CallTreeNodeVisitorImpl(visitor));
    }

    /**
     * Creates an iterator for all nodes of the {@link CallTarget} and all of its inlined calls.
     */
    public Iterator<Node> makeNodeIterator(OptimizedCallTarget target) {
        return new CallTreeNodeIterator(target);
    }

    /**
     * This visitor extends the {@link NodeVisitor} interface to be usable for traversing the full
     * call tree.
     */
    public interface CallTreeNodeVisitor extends NodeVisitor {

        /**
         * Visits {@code node}. {@code decisionStack} starts with the root {@link TruffleInlining};
         * every further element is the {@link TruffleInliningDecision} of an enclosing call site.
         * It is {@code null} when invoked through {@link #visit(Node)}.
         */
        boolean visit(List<TruffleInlining> decisionStack, Node node);

        @Override
        default boolean visit(Node node) {
            return visit(null, node);
        }

        /**
         * Depth of {@code node} in the full call tree: its depth within its own AST plus the depth
         * of the call node of every enclosing decision. Index 0 (the root inlining, which has no
         * call node) is skipped.
         */
        static int getNodeDepth(List<TruffleInlining> decisionStack, Node node) {
            int depth = calculateNodeDepth(node);
            if (decisionStack != null) {
                for (int i = decisionStack.size() - 1; i > 0; i--) {
                    TruffleInliningDecision decision = (TruffleInliningDecision) decisionStack.get(i);
                    depth += calculateNodeDepth(decision.getProfile().getCallNode());
                }
            }
            return depth;
        }

        /** Number of nodes from {@code node} up to its root, counting {@code node} itself. */
        static int calculateNodeDepth(Node node) {
            int depth = 0;
            Node traverseNode = node;
            while (traverseNode != null) {
                depth++;
                traverseNode = traverseNode.getParent();
            }
            return depth;
        }

        /**
         * Returns the innermost enclosing decision, or {@code null} if the stack is {@code null}
         * or only contains the root inlining.
         */
        static TruffleInliningDecision getCurrentInliningDecision(List<TruffleInlining> decisionStack) {
            if (decisionStack == null || decisionStack.size() <= 1) {
                return null;
            }
            return (TruffleInliningDecision) decisionStack.get(decisionStack.size() - 1);
        }
    }

    /**
     * This visitor wraps an existing {@link NodeVisitor} or {@link CallTreeNodeVisitor} and
     * traverses the full Truffle tree including inlined call sites.
     */
    private final class CallTreeNodeVisitorImpl implements NodeVisitor {

        protected final List<TruffleInlining> stack = new ArrayList<>();
        private final NodeVisitor visitor;
        // Last result of the wrapped visitor; also propagates a stop request out of inlined trees.
        private boolean continueTraverse = true;

        CallTreeNodeVisitorImpl(NodeVisitor visitor) {
            stack.add(TruffleInlining.this);
            this.visitor = visitor;
        }

        @Override
        public boolean visit(Node node) {
            if (node instanceof OptimizedDirectCallNode) {
                OptimizedDirectCallNode callNode = (OptimizedDirectCallNode) node;
                TruffleInlining inlining = stack.get(stack.size() - 1);
                if (inlining != null) {
                    TruffleInliningDecision childInlining = inlining.findByCall(callNode);
                    // NOTE: call nodes without a matching decision are not passed to the visitor.
                    if (childInlining != null) {
                        // The call node is visited with its own decision on top of the stack.
                        stack.add(childInlining);
                        continueTraverse = visitNode(node);
                        if (continueTraverse && childInlining.isInline()) {
                            childInlining.getTarget().getRootNode().accept(this);
                        }
                        stack.remove(stack.size() - 1);
                    }
                }
                return continueTraverse;
            } else {
                continueTraverse = visitNode(node);
                return continueTraverse;
            }
        }

        /** Dispatches to the stack-aware callback when the wrapped visitor supports it. */
        private boolean visitNode(Node node) {
            if (visitor instanceof CallTreeNodeVisitor) {
                return ((CallTreeNodeVisitor) visitor).visit(stack, node);
            } else {
                return visitor.visit(node);
            }
        }
    }

    /**
     * Iterates the nodes of a target and, after each call node whose decision is inlined,
     * continues with the nodes of the inlined target before resuming in the caller.
     */
    private final class CallTreeNodeIterator implements Iterator<Node> {

        // Parallel stacks: element i of each belongs to the same (inlined) target.
        private final List<TruffleInlining> inliningDecisionStack = new ArrayList<>();
        private final List<Iterator<Node>> iteratorStack = new ArrayList<>();

        CallTreeNodeIterator(OptimizedCallTarget target) {
            inliningDecisionStack.add(TruffleInlining.this);
            iteratorStack.add(NodeUtil.makeRecursiveIterator(target.getRootNode()));
        }

        @Override
        public boolean hasNext() {
            return peekIterator() != null;
        }

        @Override
        public Node next() {
            Iterator<Node> iterator = peekIterator();
            if (iterator == null) {
                throw new NoSuchElementException();
            }
            Node node = iterator.next();
            if (node instanceof OptimizedDirectCallNode) {
                visitInlinedCall(node);
            }
            return node;
        }

        /** Pushes an iterator over the inlined target if {@code node}'s decision is inlined. */
        private void visitInlinedCall(Node node) {
            TruffleInlining currentDecision = inliningDecisionStack.get(inliningDecisionStack.size() - 1);
            if (currentDecision == null) {
                return;
            }
            TruffleInliningDecision decision = currentDecision.findByCall((OptimizedDirectCallNode) node);
            if (decision != null && decision.isInline()) {
                inliningDecisionStack.add(decision);
                iteratorStack.add(NodeUtil.makeRecursiveIterator(decision.getTarget().getRootNode()));
            }
        }

        /**
         * Returns the innermost iterator that still has elements, popping exhausted levels from
         * both stacks; {@code null} once everything has been consumed.
         */
        private Iterator<Node> peekIterator() {
            int tos = iteratorStack.size() - 1;
            while (tos >= 0) {
                Iterator<Node> iterable = iteratorStack.get(tos);
                if (iterable.hasNext()) {
                    return iterable;
                } else {
                    iteratorStack.remove(tos);
                    inliningDecisionStack.remove(tos--);
                }
            }
            return null;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }
}