/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.plan;

import com.facebook.presto.sql.planner.Symbol;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;

import javax.annotation.concurrent.Immutable;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toSet;

/**
 * Plan node that holds a list of grouping sets together with the mappings
 * needed to relate its output symbols back to the symbols of its source.
 *
 * <p>Grouping sets are expressed in terms of output symbols. Both
 * {@code groupingSetMappings} and {@code argumentMappings} map an output
 * symbol to the input symbol it is derived from, and their key sets must be
 * disjoint. All collections passed to the constructor are defensively copied
 * (including the inner lists of {@code groupingSets}), so later changes to
 * the caller's collections are not visible through this node.
 */
@Immutable
public class GroupIdNode
        extends PlanNode
{
    private final PlanNode source;

    // in terms of output symbols
    private final List<List<Symbol>> groupingSets;

    // from output to input symbols
    private final Map<Symbol, Symbol> groupingSetMappings;
    private final Map<Symbol, Symbol> argumentMappings;

    private final Symbol groupIdSymbol;

    /**
     * Creates a new node.
     *
     * @param id identifier of this plan node
     * @param source the single child of this node
     * @param groupingSets grouping sets, in terms of output symbols
     * @param groupingSetMappings output-to-input symbol mapping for grouping columns
     * @param argumentMappings output-to-input symbol mapping for arguments
     * @param groupIdSymbol symbol appended last to the output symbols
     * @throws NullPointerException if any argument other than {@code id} is null
     *         ({@code id} is passed to the superclass unchecked here), or if
     *         {@code groupingSets} contains a null list
     * @throws IllegalArgumentException if the key sets of the two mappings overlap
     */
    @JsonCreator
    public GroupIdNode(@JsonProperty("id") PlanNodeId id,
            @JsonProperty("source") PlanNode source,
            @JsonProperty("groupingSets") List<List<Symbol>> groupingSets,
            @JsonProperty("groupingSetMappings") Map<Symbol, Symbol> groupingSetMappings,
            @JsonProperty("argumentMappings") Map<Symbol, Symbol> argumentMappings,
            @JsonProperty("groupIdSymbol") Symbol groupIdSymbol)
    {
        super(id);
        this.source = requireNonNull(source, "source is null");
        // Deep copy: copying only the outer list would leave the inner lists
        // shared with (and mutable by) the caller.
        this.groupingSets = copyGroupingSets(requireNonNull(groupingSets, "groupingSets is null"));
        this.groupingSetMappings = ImmutableMap.copyOf(requireNonNull(groupingSetMappings, "groupingSetMappings is null"));
        this.argumentMappings = ImmutableMap.copyOf(requireNonNull(argumentMappings, "argumentMappings is null"));
        this.groupIdSymbol = requireNonNull(groupIdSymbol, "groupIdSymbol is null");

        checkArgument(Sets.intersection(this.groupingSetMappings.keySet(), this.argumentMappings.keySet()).isEmpty(),
                "argument outputs and grouping outputs must be a disjoint set");
    }

    // Returns an immutable copy of the given grouping sets in which every inner list is immutable as well.
    private static List<List<Symbol>> copyGroupingSets(List<List<Symbol>> groupingSets)
    {
        ImmutableList.Builder<List<Symbol>> copy = ImmutableList.builder();
        for (List<Symbol> groupingSet : groupingSets) {
            copy.add(ImmutableList.copyOf(requireNonNull(groupingSet, "grouping set is null")));
        }
        return copy.build();
    }

    /**
     * Returns the distinct symbols appearing in any grouping set, followed by
     * the argument output symbols, followed by the group id symbol.
     *
     * <p>The distinct grouping symbols are gathered through a {@code Set}
     * collector, so their relative order is that set's iteration order rather
     * than the order in which they appear in the grouping sets.
     */
    @Override
    public List<Symbol> getOutputSymbols()
    {
        return ImmutableList.<Symbol>builder()
                .addAll(groupingSets.stream()
                        .flatMap(Collection::stream)
                        .collect(toSet()))
                .addAll(argumentMappings.keySet())
                .add(groupIdSymbol)
                .build();
    }

    @Override
    public List<PlanNode> getSources()
    {
        return ImmutableList.of(source);
    }

    @JsonProperty
    public PlanNode getSource()
    {
        return source;
    }

    /**
     * Returns the grouping sets, in terms of output symbols. Both the returned
     * list and its inner lists are immutable.
     */
    @JsonProperty
    public List<List<Symbol>> getGroupingSets()
    {
        return groupingSets;
    }

    /**
     * Returns the mapping from grouping output symbols to input symbols.
     */
    @JsonProperty
    public Map<Symbol, Symbol> getGroupingSetMappings()
    {
        return groupingSetMappings;
    }

    /**
     * Returns the mapping from argument output symbols to input symbols.
     */
    @JsonProperty
    public Map<Symbol, Symbol> getArgumentMappings()
    {
        return argumentMappings;
    }

    @JsonProperty
    public Symbol getGroupIdSymbol()
    {
        return groupIdSymbol;
    }

    @Override
    public <C, R> R accept(PlanVisitor<C, R> visitor, C context)
    {
        return visitor.visitGroupId(this, context);
    }

    /**
     * Returns the input symbols this node refers to: the values of
     * {@code argumentMappings} plus, for every symbol in every grouping set,
     * the input symbol it maps to in {@code groupingSetMappings}.
     *
     * <p>NOTE(review): a grouping symbol with no entry in
     * {@code groupingSetMappings} yields a null lookup result here; the
     * constructor does not check for that case.
     */
    public Set<Symbol> getInputSymbols()
    {
        return ImmutableSet.<Symbol>builder()
                .addAll(argumentMappings.values())
                .addAll(groupingSets.stream()
                        .map(set -> set.stream()
                                .map(groupingSetMappings::get).collect(Collectors.toList()))
                        .flatMap(Collection::stream)
                        .collect(toSet()))
                .build();
    }

    /**
     * Returns the grouping columns common to every grouping set, in terms of
     * output symbols. Returns an empty set when there are no grouping sets.
     */
    public Set<Symbol> getCommonGroupingColumns()
    {
        // Guard the get(0) below: the constructor does not forbid an empty list.
        if (groupingSets.isEmpty()) {
            return ImmutableSet.of();
        }

        Set<Symbol> intersection = new HashSet<>(groupingSets.get(0));
        for (int i = 1; i < groupingSets.size(); i++) {
            intersection.retainAll(groupingSets.get(i));
        }
        return ImmutableSet.copyOf(intersection);
    }

    /**
     * Returns a copy of this node with the same id and properties whose source
     * is the single element of {@code newChildren}.
     */
    @Override
    public PlanNode replaceChildren(List<PlanNode> newChildren)
    {
        return new GroupIdNode(getId(), Iterables.getOnlyElement(newChildren), groupingSets, groupingSetMappings, argumentMappings, groupIdSymbol);
    }
}