/*********************************************************************************************************************** * * Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu) * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. * **********************************************************************************************************************/ package eu.stratosphere.util.dag; import java.io.Serializable; import java.lang.reflect.Array; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import eu.stratosphere.util.IdentitySet; /** * Skeleton implementation of {@link SubGraph}. * * @param <Node> * the type of all node * @param <InputNode> * the type of all input nodes * @param <OutputNode> * the type of all output nodes */ public abstract class GraphModule<Node, InputNode extends Node, OutputNode extends Node> implements SubGraph<Node, InputNode, OutputNode>, Serializable { /** * */ private static final long serialVersionUID = -8802006043539156002L; /** * The outputs of the module. */ protected final OutputNode[] outputNodes; /** * internal outputs */ protected final List<OutputNode> internalOutputNodes = new ArrayList<OutputNode>(); /** * The inputs of the module. */ protected final InputNode[] inputNodes; private final ConnectionNavigator<Node> navigator; private final String name; /** * Initializes a PactModule having the given inputs, outputs, and {@link Navigator}. * * @param inputNodes * the inputs * @param outputNodes * the outputs * @param navigator * the navigator used to traverse the graph of nodes */ protected GraphModule(final String name, final InputNode[] inputNodes, final OutputNode[] outputNodes, final ConnectionNavigator<Node> navigator) { this.name = name; this.inputNodes = inputNodes; this.outputNodes = outputNodes; this.navigator = navigator; } @Override public void addInternalOutput(final OutputNode output) { this.internalOutputNodes.add(output); } /** * Returns the internalOutputNodes. * * @return the internalOutputNodes */ @SuppressWarnings("unchecked") public OutputNode[] getInternalOutputNodes() { return this.internalOutputNodes.toArray((OutputNode[]) Array.newInstance( this.outputNodes.getClass().getComponentType(), this.internalOutputNodes.size())); } @SuppressWarnings("unchecked") @Override public OutputNode[] getAllOutputs() { if (this.internalOutputNodes.isEmpty()) return this.outputNodes; final OutputNode[] allOutputs = (OutputNode[]) Array.newInstance(this.outputNodes.getClass() .getComponentType(), this.outputNodes.length + this.internalOutputNodes.size()); System.arraycopy(this.outputNodes, 0, allOutputs, 0, this.outputNodes.length); for (int index = 0; index < allOutputs.length; index++) allOutputs[index + this.outputNodes.length] = this.internalOutputNodes.get(index); return allOutputs; } @Override public InputNode getInput(final int index) { return this.inputNodes[index]; } @Override public void setInput(final int index, final InputNode input) { this.inputNodes[index] = input; } @Override public InputNode[] getInputs() { return this.inputNodes; } public String getName() { return this.name; } @Override public OutputNode getOutput(final int index) { return this.outputNodes[index]; } @Override public OutputNode getInternalOutputNodes(final int index) { return this.internalOutputNodes.get(index); } @Override public void setOutput(final int index, final OutputNode output) { this.outputNodes[index] = output; } @Override public OutputNode[] getOutputs() { return this.outputNodes; } @Override public Iterable<? extends Node> getReachableNodes() { return OneTimeTraverser.INSTANCE.getReachableNodes(this.getAllOutputs(), this.navigator); } @Override public String toString() { final GraphPrinter<Node> dagPrinter = new GraphPrinter<Node>(); dagPrinter.setWidth(80); return dagPrinter.toString(this.getAllOutputs(), this.navigator); } @Override public void validate() { for (final OutputNode output : this.getAllOutputs()) for (final Node node : this.navigator.getConnectedNodes(output)) if (node == null) throw new IllegalStateException(String.format("%s: output %s is not fully connected", this.getName(), output)); final Iterable<? extends Node> reachableNodes = this.getReachableNodes(); final List<InputNode> inputList = new LinkedList<InputNode>(Arrays.asList(this.inputNodes)); for (final Node node : reachableNodes) inputList.remove(node); if (!inputList.isEmpty()) throw new IllegalStateException( String.format("%s: inputs %s are not fully connected", this.getName(), inputList)); } @Override public int hashCode() { final int prime = 31; int result = 1; for (final Node node : this.getReachableNodes()) result = prime * result + node.hashCode(); return result; } @Override public boolean equals(final Object obj) { if (this == obj) return true; if (obj == null) return false; if (this.getClass() != obj.getClass()) return false; @SuppressWarnings({ "rawtypes", "unchecked" }) final GraphModule<Node, InputNode, OutputNode> other = (GraphModule) obj; return this.getUnmatchingNodes(other) == null; } public List<Node> getUnmatchingNodes(final GraphModule<Node, InputNode, OutputNode> other) { final IdentitySet<Node> seen = new IdentitySet<Node>(); return this.getUnmatchingNode(Arrays.asList(this.getAllOutputs()), Arrays.asList(other.getAllOutputs()), seen); } /** * @param allOutputs * @param allOutputs2 * @param seen * @return */ @SuppressWarnings("unchecked") private List<Node> getUnmatchingNode(final Iterable<? extends Node> nodes1, final Iterable<? extends Node> nodes2, final IdentitySet<Node> seen) { final Iterator<? extends Node> iterator1 = nodes1.iterator(); final Iterator<? extends Node> iterator2 = nodes2.iterator(); while (iterator1.hasNext() && iterator2.hasNext()) { final Node node1 = iterator1.next(); final Node node2 = iterator2.next(); if (!node1.equals(node2)) return Arrays.asList(node1, node2); final List<Node> unmatching = this.getUnmatchingNode(this.navigator.getConnectedNodes(node1), this.navigator.getConnectedNodes(node2), seen); if (unmatching != null) return unmatching; } return null; } }