/*
* Copyright © 2016 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package co.cask.cdap.etl.planner;
import co.cask.cdap.etl.proto.Connection;
import com.google.common.base.Joiner;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multiset;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.UUID;
/**
 * A DAG (directed acyclic graph) where edges represent a happens-before relationship.
 * In these types of scenarios, certain edges may be redundant and can be removed.
 * This simplifies the dag into something that is much easier to work with if it needs
 * to be used as a fork-join type of dag for workflow execution.
 */
public class ControlDag extends Dag {
  private static final Set<String> EMPTY = ImmutableSet.of();
  // number of times each node has been passed to visit(); cleared by resetVisitCounts()
  private final Multiset<String> nodeVisits;

  public ControlDag(Collection<Connection> connections) {
    super(connections);
    this.nodeVisits = HashMultiset.create();
  }

  /**
   * Record that this node was visited.
   *
   * @param node node that was visited
   * @return the number of times this node was visited, including the visit from this call
   */
  public int visit(String node) {
    nodeVisits.add(node);
    return nodeVisits.count(node);
  }

  /**
   * Resets the number of times each node was visited back to 0.
   */
  public void resetVisitCounts() {
    nodeVisits.clear();
  }

  /**
   * Flattens the control dag to remove connections between branches of different forks, which would
   * make the dag unusable in pure fork-join workflows.
   *
   * For example the following dag is not a fork-join dag:
   *
   *          |--> n2 -----|
   *          |            |--> n5
   *          |--> n3 -----|
   *  n1 -----|           |
   *          |           v
   *          |--> n4 --> n6
   *
   * There are many ways to turn this into a fork-join while still respecting all happens-before
   * relationships, but for simplicity we'll use an algorithm that doesn't have any nested forks and
   * will turn the above into:
   *
   *          |--> n2 --|
   *          |         |                |--> n5 --|
   *  n1 -----|--> n3 --|--> n2.n3.n4 ---|         |--> n5.n6
   *          |         |                |--> n6 --|
   *          |--> n4 --|
   *
   * The algorithm is to insert a join node whenever it sees a fork. Every time there is a fork, we will
   * follow each branch to its endpoint (a node that forks, merges, or is a sink), then insert a join node
   * that each branch endpoint connects to. The join node is skipped only when every branch already
   * merges into the same single node.
   *
   * @throws IllegalStateException if the dag has no sources
   */
  public void flatten() {
    // this should never be the case, as it should be checked when the dag is created.
    if (sources.isEmpty()) {
      throw new IllegalStateException("There are no sources in the graph, which means there is a cycle.");
    }
    // drop redundant edges first so that only meaningful forks and merges remain
    trim();

    String source;
    // if we have multiple sources, insert a fork node as the new source
    if (sources.size() > 1) {
      // copy to avoid concurrent modification: addNode() removes the old sources from 'sources'
      Set<String> sourcesCopy = new HashSet<>(sources);
      String newId = generateJoinNodeName(sourcesCopy);
      addNode(newId, EMPTY, sourcesCopy);
      source = newId;
    } else {
      source = sources.iterator().next();
    }
    flattenFrom(source);
  }

  /**
   * Flattens the part of the dag downstream of the given node. Whenever a fork is found, each of its
   * branches is followed to its endpoint, and a join node is inserted after those endpoints unless
   * every branch already merges into the same single node.
   *
   * @param node the node to continue flattening from
   */
  private void flattenFrom(String node) {
    Set<String> outputs = outgoingConnections.get(node);
    if (outputs.isEmpty()) {
      return;
    }
    if (outputs.size() == 1) {
      flattenFrom(outputs.iterator().next());
      return;
    }

    Multimap<String, String> branchEndpointOutputs = HashMultimap.create();
    // can't just use branchEndpointOutputs.keySet(),
    // because that won't track branch endpoints that had no output (sinks)
    Set<String> branchEndpoints = new HashSet<>();
    boolean branchEndsInSink = false;
    for (String output : outputs) {
      String branchEndpoint = findBranchEnd(output);
      branchEndpoints.add(branchEndpoint);
      Set<String> outputsOfEndpoint = outgoingConnections.get(branchEndpoint);
      if (outputsOfEndpoint.isEmpty()) {
        branchEndsInSink = true;
      }
      branchEndpointOutputs.putAll(branchEndpoint, outputsOfEndpoint);
    }

    // if every branch endpoint connects to the same single node, the branches already merge there and
    // there is no need to add a join node. A branch that ends in a sink does not merge anywhere, so in
    // that case a join node is still needed to close the fork.
    Set<String> endpointOutputs = new HashSet<>(branchEndpointOutputs.values());
    if (endpointOutputs.size() == 1 && !branchEndsInSink) {
      flattenFrom(endpointOutputs.iterator().next());
      return;
    }

    // add a connection from each branch endpoint to a newly added join node
    // then move all outgoing connections from each branch endpoint so that they are coming out of the new join node
    String newJoinNode = generateJoinNodeName(branchEndpoints);
    addNode(newJoinNode, branchEndpoints, endpointOutputs);
    // remove the outgoing connections from endpoints that aren't going to our new join node
    for (Map.Entry<String, String> endpointEntry : branchEndpointOutputs.entries()) {
      removeConnection(endpointEntry.getKey(), endpointEntry.getValue());
    }

    /*
      have to trim again due to reshuffling of nodes. For example, if we have:

                       |--> n3
             |--> n2 --|
             |         |--> n4
        n1 --|              |
             |              v
             |--> n5 -----> n6

      after we insert the new join node we'll have:

             |--> n2 --|            |--> n3
             |         |            |
        n1 --|         |--> join ---|--> n4
             |         |            |    |
             |--> n5 --|            |    v
                                    |--> n6

      and we need to remove the connection from join -> n6, otherwise the algorithm will get messed up
     */
    trim();

    // then keep flattening from the new join node
    flattenFrom(newJoinNode);
  }

  /**
   * Goes down a branch until it finds a node with multiple outputs, a node whose single output has
   * multiple inputs, or a sink.
   *
   * @param node the first node of the branch
   * @return the last node of the branch
   */
  private String findBranchEnd(String node) {
    Set<String> outputs = outgoingConnections.get(node);
    // if this is a sink, or if this is a fork on a branch
    if (outputs.isEmpty() || outputs.size() > 1) {
      return node;
    }

    // if the next node is a join node
    String output = outputs.iterator().next();
    if (incomingConnections.get(output).size() > 1) {
      return node;
    }

    // otherwise keep going down this branch
    return findBranchEnd(output);
  }

  /**
   * Returns whether the stop node can be reached from the start node through some path other than the
   * direct edge start -> stop. This is equivalent to there being more than one path from start to stop
   * when that direct edge exists, but it is decided by a reachability search instead of enumerating
   * every path (which grows exponentially with the number of diamonds in the dag).
   *
   * @param start the node to start from
   * @param stop the node to end at
   * @return true if stop is reachable from start without using the direct edge between them
   */
  private boolean hasIndirectPath(String start, String stop) {
    // shared across outputs: a node already explored without reaching 'stop' cannot reach it later
    Set<String> visited = new HashSet<>();
    for (String output : getNodeOutputs(start)) {
      if (!output.equals(stop) && isReachable(output, stop, visited)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Depth-first reachability check. A node reaches itself.
   *
   * @param start the node to search from
   * @param stop the node to look for
   * @param visited nodes already explored that did not lead to stop; updated by this call
   * @return true if stop is reachable from start
   */
  private boolean isReachable(String start, String stop, Set<String> visited) {
    if (start.equals(stop)) {
      return true;
    }
    if (!visited.add(start)) {
      return false;
    }
    for (String output : getNodeOutputs(start)) {
      if (isReachable(output, stop, visited)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Trims any redundant control connections.
   *
   * For example:
   *
   *   n1 ------> n2
   *   |          |
   *   |          v
   *   |--------> n3
   *
   * has a redundant edge n1 -> n3, because the edge from n2 -> n3 already enforces n1 -> n3.
   * The approach is to look at each node (call it nodeB). For each input into nodeB (call it nodeA),
   * if there is another path from nodeA to nodeB besides the direct edge, we can remove the edge
   * nodeA -> nodeB.
   *
   * @return number of connections removed.
   */
  public int trim() {
    int numRemoved = 0;
    for (String node : nodes) {
      // collect first and remove afterwards so the inputs are not modified while being iterated
      Set<Connection> toRemove = new HashSet<>();
      for (String nodeInput : getNodeInputs(node)) {
        if (hasIndirectPath(nodeInput, node)) {
          toRemove.add(new Connection(nodeInput, node));
        }
      }
      for (Connection conn : toRemove) {
        removeConnection(conn.getFrom(), conn.getTo());
      }
      numRemoved += toRemove.size();
    }
    return numRemoved;
  }

  /**
   * Adds a node connected from each of the given inputs and to each of the given outputs, updating the
   * source and sink sets accordingly.
   *
   * @param node the node to add
   * @param inputs nodes that should connect to the new node; they stop being sinks
   * @param outputs nodes the new node should connect to; they stop being sources
   */
  private void addNode(String node, Collection<String> inputs, Collection<String> outputs) {
    nodes.add(node);
    for (String output : outputs) {
      outgoingConnections.put(node, output);
      incomingConnections.put(output, node);
      sources.remove(output);
    }
    for (String input : inputs) {
      incomingConnections.put(node, input);
      outgoingConnections.put(input, node);
      sinks.remove(input);
    }
    if (outputs.isEmpty()) {
      sinks.add(node);
    }
    if (inputs.isEmpty()) {
      sources.add(node);
    }
  }

  /**
   * Generates a name for a new node by joining the given node names with '.'.
   *
   * @param inputs the nodes the generated name is derived from
   * @return the joined name, with a random UUID appended if that name is already a node in the dag
   */
  private String generateJoinNodeName(Set<String> inputs) {
    // sorted so that the name does not depend on set iteration order; it is only fully deterministic
    // when it does not collide with an existing node, since a random suffix is added on collision
    String name = Joiner.on('.').join(new TreeSet<>(inputs));
    if (nodes.contains(name)) {
      name += UUID.randomUUID().toString();
    }
    return name;
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    if (!super.equals(o)) {
      return false;
    }

    ControlDag that = (ControlDag) o;

    return Objects.equals(nodeVisits, that.nodeVisits);
  }

  @Override
  public int hashCode() {
    return Objects.hash(super.hashCode(), nodeVisits);
  }

  @Override
  public String toString() {
    return "ControlDag{" +
      "nodeVisits=" + nodeVisits +
      "} " + super.toString();
  }
}