package gov.nasa.jpl.mbee.mdk.generator.graphs;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
 * Directed hyper-graph that, on top of whatever the undirected superclass
 * stores, keeps two hash-based indexes from a vertex to its incident edges:
 * one for edges that list the vertex as a source and one for edges that list
 * it as a target.
 *
 * @param <VertexType> type of the vertices
 * @param <EdgeType>   type of the directed hyper-edges
 */
public class DirectedHyperGraphHashSet<VertexType, EdgeType extends DirectedHyperEdge<VertexType>> extends
        UndirectedHyperGraphHashSet<VertexType, EdgeType> implements DirectedHyperGraph<VertexType, EdgeType> {

    /** Index: vertex -> edges that have the vertex among their source vertices. */
    protected Map<VertexType, Set<EdgeType>> Vs2E;

    /** Index: vertex -> edges that have the vertex among their target vertices. */
    protected Map<VertexType, Set<EdgeType>> Vt2E;

    /** Creates an empty graph with empty source and target indexes. */
    public DirectedHyperGraphHashSet() {
        super();
        Vs2E = new HashMap<VertexType, Set<EdgeType>>();
        Vt2E = new HashMap<VertexType, Set<EdgeType>>();
    }

    /**
     * Builds a {@code DirectedHyperEdgeVector} from the given vertex sets and
     * adds it through {@link #addEdge(DirectedHyperEdge)}.
     *
     * @param sourceVertices source vertices of the new edge
     * @param targetVertices target vertices of the new edge
     * @return the result of {@link #addEdge(DirectedHyperEdge)} for the new edge
     */
    @SuppressWarnings("unchecked")
    @Override
    public boolean addEdge(Set<VertexType> sourceVertices, Set<VertexType> targetVertices) {
        DirectedHyperEdge<VertexType> edge = new DirectedHyperEdgeVector<VertexType>(sourceVertices,
                targetVertices);
        // NOTE(review): unchecked cast; only sound when EdgeType is compatible
        // with DirectedHyperEdgeVector<VertexType> - confirm against callers.
        return addEdge((EdgeType) edge);
    }

    /**
     * Adds a vertex and makes sure both indexes have an entry for it.
     *
     * @param vertex vertex to add
     * @return whatever the superclass reports for the insertion
     */
    @Override
    public boolean addVertex(VertexType vertex) {
        boolean added = super.addVertex(vertex);
        if (added) {
            // Only create missing entries: addEdge(EdgeType) may already have
            // indexed edges under this vertex, and those must not be discarded.
            ensureIndexed(vertex);
        }
        return added;
    }

    /**
     * Removes a vertex and drops its entries from both indexes.
     *
     * @param vertex vertex to remove
     * @return whatever the superclass reports for the removal
     */
    @Override
    public boolean removeVertex(VertexType vertex) {
        boolean removed = super.removeVertex(vertex);
        if (removed) {
            // NOTE(review): edges indexed under OTHER vertices are not touched
            // here; whether incident edges are removed depends on the
            // superclass - confirm.
            Vs2E.remove(vertex);
            Vt2E.remove(vertex);
        }
        return removed;
    }

    /** Clears the superclass state and both indexes. */
    @Override
    public void clear() {
        super.clear();
        Vs2E.clear();
        Vt2E.clear();
    }

    /**
     * Adds an edge and, when the superclass accepted it, records it in the
     * source index for each of its source vertices and in the target index for
     * each of its target vertices. Every vertex touched gets an entry in both
     * indexes.
     *
     * @param edge edge to add
     * @return whatever the superclass reports for the insertion
     */
    @Override
    public boolean addEdge(EdgeType edge) {
        boolean added = super.addEdge(edge);
        if (added) {
            for (VertexType v : edge.getSourceVertices()) {
                ensureIndexed(v);
                Vs2E.get(v).add(edge);
            }
            for (VertexType v : edge.getTargetVertices()) {
                ensureIndexed(v);
                Vt2E.get(v).add(edge);
            }
        }
        return added;
    }

    /**
     * Removes an edge and, when the superclass removed it, deletes it from the
     * index entries of its source and target vertices.
     *
     * @param edge edge to remove
     * @return whatever the superclass reports for the removal
     */
    @Override
    public boolean removeEdge(EdgeType edge) {
        boolean removed = super.removeEdge(edge);
        if (removed) {
            for (VertexType v : edge.getSourceVertices()) {
                Set<EdgeType> outgoing = Vs2E.get(v);
                // The assert still flags a broken index when run with -ea; the
                // guard avoids a NullPointerException when assertions are off
                // (e.g. the vertex entry was already dropped by removeVertex).
                assert (outgoing != null);
                if (outgoing != null) {
                    outgoing.remove(edge);
                }
            }
            for (VertexType v : edge.getTargetVertices()) {
                Set<EdgeType> incoming = Vt2E.get(v);
                assert (incoming != null);
                if (incoming != null) {
                    incoming.remove(edge);
                }
            }
        }
        return removed;
    }

    /**
     * Collects the target vertices of every edge returned by
     * {@link #findEdgesWithSourceVertex(Object)} for the given vertex.
     *
     * @param vertex vertex whose outgoing edges are followed
     * @return a new set of target vertices; empty when the vertex has no
     *         index entry
     */
    @Override
    public Set<VertexType> findNeighborsOf(VertexType vertex) {
        Set<VertexType> vertices = new HashSet<VertexType>();
        Set<EdgeType> outgoing = findEdgesWithSourceVertex(vertex);
        if (outgoing == null) {
            // Unknown vertex: nothing to follow.
            return vertices;
        }
        for (EdgeType e : outgoing) {
            vertices.addAll(e.getTargetVertices());
        }
        return vertices;
    }

    /**
     * Collects the target vertices of every edge indexed under the given
     * vertex as a source.
     *
     * @param vertex vertex whose outgoing edges are followed
     * @return a new set of target vertices; empty when the vertex has no
     *         index entry
     */
    @Override
    public Set<VertexType> findChildrenOf(VertexType vertex) {
        Set<VertexType> vertices = new HashSet<VertexType>();
        Set<EdgeType> outgoing = Vs2E.get(vertex);
        if (outgoing == null) {
            // Unknown vertex: nothing to follow.
            return vertices;
        }
        for (EdgeType edge : outgoing) {
            vertices.addAll(edge.getTargetVertices());
        }
        return vertices;
    }

    /**
     * Looks up the edges indexed under the given vertex as a source.
     *
     * @param vertex vertex to look up
     * @return the internal (live, mutable) set stored in the source index, or
     *         {@code null} when the vertex has no entry
     */
    @Override
    public Set<EdgeType> findEdgesWithSourceVertex(VertexType vertex) {
        return Vs2E.get(vertex);
    }

    /**
     * Finds the edges that have every one of the given vertices as a source.
     *
     * @param vertices vertices that must all be sources of a returned edge
     * @return a new set holding the intersection of the per-vertex results;
     *         empty when {@code vertices} is empty or any vertex has no
     *         index entry
     */
    @Override
    public Set<EdgeType> findEdgesWithSourceVertices(Set<VertexType> vertices) {
        return intersectIncidentEdges(vertices, true);
    }

    /**
     * Looks up the edges indexed under the given vertex as a target.
     *
     * @param vertex vertex to look up
     * @return the internal (live, mutable) set stored in the target index, or
     *         {@code null} when the vertex has no entry
     */
    @Override
    public Set<EdgeType> findEdgesWithTargetVertex(VertexType vertex) {
        return Vt2E.get(vertex);
    }

    /**
     * Finds the edges that have every one of the given vertices as a target.
     *
     * @param vertices vertices that must all be targets of a returned edge
     * @return a new set holding the intersection of the per-vertex results;
     *         empty when {@code vertices} is empty or any vertex has no
     *         index entry
     */
    @Override
    public Set<EdgeType> findEdgesWithTargetVertices(Set<VertexType> vertices) {
        return intersectIncidentEdges(vertices, false);
    }

    /** Creates an empty entry for the vertex in each index that lacks one. */
    private void ensureIndexed(VertexType vertex) {
        if (Vs2E.get(vertex) == null) {
            Vs2E.put(vertex, new HashSet<EdgeType>());
        }
        if (Vt2E.get(vertex) == null) {
            Vt2E.put(vertex, new HashSet<EdgeType>());
        }
    }

    /**
     * Intersects the per-vertex edge sets obtained through the single-vertex
     * lookup methods, stopping as soon as the intersection becomes empty.
     */
    private Set<EdgeType> intersectIncidentEdges(Set<VertexType> vertices, boolean asSource) {
        Set<EdgeType> common = new HashSet<EdgeType>();
        boolean isFirstPass = true;
        for (VertexType v : vertices) {
            Set<EdgeType> incident = asSource ? findEdgesWithSourceVertex(v) : findEdgesWithTargetVertex(v);
            if (incident == null) {
                // A vertex without an index entry has no edges, so no edge
                // can be shared by all requested vertices.
                common.clear();
                break;
            }
            if (isFirstPass) {
                common.addAll(incident);
                isFirstPass = false;
            }
            else {
                common.retainAll(incident);
            }
            if (common.isEmpty()) {
                break;
            }
        }
        return common;
    }
}