package org.infinispan.objectfilter.impl.util; import java.util.ArrayList; import java.util.List; /** * An Interval tree is an ordered tree data structure to hold Intervals. Specifically, it allows one to efficiently find * all Intervals that contain any given value in O(log n) time (see http://en.wikipedia.org/wiki/Interval_tree). * <p/> * The implementation is based on red-black trees (http://en.wikipedia.org/wiki/Red–black_tree). Additions and removals * are efficient and require only minimal rebalancing of the tree as opposed to other implementation approaches that * perform a full rebuild after insertion. Duplicate intervals are not stored but are coped for. * * @author anistor@redhat.com * @since 7.0 */ public final class IntervalTree<K extends Comparable<K>, V> { public static final class Node<K extends Comparable<K>, V> { /** * The interval. The low value is the key of this node within the search tree. */ public final Interval<K> interval; //todo maybe it's wise to make it private and expose getter /** * A user payload value. */ public V value; //todo maybe it's wise to make it private and expose getter and setter /** * The maximum value of any Interval endpoint stored in the subtree rooted at this node. */ private K max; /** * The parent node. */ private Node<K, V> parent; /** * The left child. */ private Node<K, V> left; /** * The right child. */ private Node<K, V> right; /** * Indicates the color of this node (either red or black). */ private boolean isRed = false; private Node(Interval<K> interval) { this.interval = interval; this.max = interval.up; } private Node() { interval = null; } } private final Node<K, V> sentinel; /** * The root of the tree. */ private Node<K, V> root; public IntervalTree() { sentinel = new Node<K, V>(); sentinel.left = sentinel; sentinel.right = sentinel; sentinel.parent = sentinel; root = sentinel; } private int compare(K k1, K k2) { if (k1 == Interval.getMinusInf() || k2 == Interval.getPlusInf()) return -1; if (k1 == Interval.getPlusInf() || k2 == Interval.getMinusInf()) return 1; return k1.compareTo(k2); } private K max(K k1, K k2) { return compare(k1, k2) >= 0 ? k1 : k2; } private boolean compareLowerBound(Interval<K> i1, Interval<K> i2) { int res = compare(i1.low, i2.low); return res < 0 || res == 0 && (i1.includeLower || !i2.includeUpper); } /** * Compare two Intervals. * * @return a negative integer, zero, or a positive integer depending if Interval i1 is to the left of i2, overlaps * with it, or is to the right of i2. */ private int compareIntervals(Interval<K> i1, Interval<K> i2) { int res1 = compare(i1.up, i2.low); if (res1 < 0 || res1 <= 0 && (!i1.includeUpper || !i2.includeLower)) { return -1; } int res2 = compare(i2.up, i1.low); if (res2 < 0 || res2 <= 0 && (!i2.includeUpper || !i1.includeLower)) { return 1; } return 0; } private void checkValidInterval(Interval<K> interval) { if (interval == null) { throw new IllegalArgumentException("Interval cannot be null"); } if (compare(interval.low, interval.up) > 0) { throw new IllegalArgumentException("Interval lower bound cannot be higher than the upper bound"); } } /** * Add the {@code Interval} into this {@code IntervalTree} and return the Node. Possible duplicates are found and the * existing Node is returned instead of adding a new one. * * @param i an Interval to be inserted */ public Node<K, V> add(Interval<K> i) { checkValidInterval(i); return add(new Node<K, V>(i)); } private Node<K, V> add(Node<K, V> n) { n.left = n.right = sentinel; Node<K, V> y = root; Node<K, V> x = root != null ? root.left : null; while (x != sentinel) { y = x; if (x.interval.equals(n.interval)) { return x; } if (compareLowerBound(n.interval, y.interval)) { x = x.left; } else { x = x.right; } y.max = max(n.max, y.max); if (y.parent == root) { root.max = y.max; } } n.parent = y; if (root != null && y == root) { root.max = n.max; } if (y != null) { if (y == root || compareLowerBound(n.interval, y.interval)) { y.left = n; } else { y.right = n; } } rebalanceAfterAdd(n); return n; } private void rebalanceAfterAdd(Node<K, V> z) { z.isRed = true; while (z.parent.isRed) { if (z.parent == z.parent.parent.left) { Node<K, V> y = z.parent.parent.right; if (y.isRed) { z.parent.isRed = false; y.isRed = false; z.parent.parent.isRed = true; z = z.parent.parent; } else { if (z == z.parent.right) { z = z.parent; rotateLeft(z); } z.parent.isRed = false; z.parent.parent.isRed = true; rotateRight(z.parent.parent); } } else { Node<K, V> y = z.parent.parent.left; if (y.isRed) { z.parent.isRed = false; y.isRed = false; z.parent.parent.isRed = true; z = z.parent.parent; } else { if (z == z.parent.left) { z = z.parent; rotateRight(z); } z.parent.isRed = false; z.parent.parent.isRed = true; rotateLeft(z.parent.parent); } } } root.left.isRed = false; } private void rotateLeft(Node<K, V> x) { Node<K, V> y = x.right; x.right = y.left; if (y.left != sentinel) y.left.parent = x; y.parent = x.parent; if (x == x.parent.left) { x.parent.left = y; } else { x.parent.right = y; } y.left = x; x.parent = y; if (y.parent == root) { root.max = x.max; } y.max = x.max; x.max = max(x.interval.up, max(x.left.max, x.right.max)); } private void rotateRight(Node<K, V> x) { Node<K, V> y = x.left; x.left = y.right; if (y.right != sentinel) { y.right.parent = x; } y.parent = x.parent; if (x == x.parent.left) { x.parent.left = y; } else { x.parent.right = y; } y.right = x; x.parent = y; if (y.parent == root) { root.max = x.max; } y.max = x.max; x.max = max(x.interval.up, max(x.left.max, x.right.max)); } /** * Removes the Interval. * * @param i the interval to remove */ public boolean remove(Interval<K> i) { checkValidInterval(i); return remove(root.left, i); } private boolean remove(Node<K, V> n, Interval<K> i) { if (n == sentinel || compare(i.low, n.max) > 0) { return false; } if (n.interval.equals(i)) { remove(n); return true; } if (n.left != sentinel && remove(n.left, i)) { return true; } if (compareIntervals(i, n.interval) < 0) { return false; } return n.right != sentinel && remove(n.right, i); } public void remove(Node<K, V> n) { n.max = Interval.<K>getMinusInf(); for (Node<K, V> i = n.parent; i != root; i = i.parent) { i.max = max(i.left.max, i.right.max); if (i.parent == root) { root.max = i.max; } } Node<K, V> y; Node<K, V> x; if (n.left == sentinel || n.right == sentinel) { y = n; } else { y = findSuccessor(n); } if (y.left == sentinel) { x = y.right; } else { x = y.left; } x.parent = y.parent; if (root == x.parent) { root.left = x; } else if (y == y.parent.left) { y.parent.left = x; } else { y.parent.right = x; } if (y != n) { if (!y.isRed) { rebalanceAfterRemove(x); } y.left = n.left; y.right = n.right; y.parent = n.parent; y.isRed = n.isRed; n.left.parent = n.right.parent = y; if (n == n.parent.left) { n.parent.left = y; } else { n.parent.right = y; } } else if (!y.isRed) { rebalanceAfterRemove(x); } } private Node<K, V> findSuccessor(Node<K, V> x) { Node<K, V> successor = x.right; if (successor != sentinel) { while (successor.left != sentinel) { successor = successor.left; } return successor; } successor = x.parent; while (x == successor.right) { x = successor; successor = successor.parent; } if (successor == root) { return sentinel; } return successor; } private void rebalanceAfterRemove(Node<K, V> x) { while (x != root.left && !x.isRed) { if (x == x.parent.left) { Node<K, V> w = x.parent.right; if (w.isRed) { w.isRed = false; x.parent.isRed = true; rotateLeft(x.parent); w = x.parent.right; } if (!w.left.isRed && !w.right.isRed) { w.isRed = true; x = x.parent; } else { if (!w.right.isRed) { w.left.isRed = false; w.isRed = true; rotateRight(w); w = x.parent.right; } w.isRed = x.parent.isRed; x.parent.isRed = false; w.right.isRed = false; rotateLeft(x.parent); x = root.left; } } else { Node<K, V> w = x.parent.left; if (w.isRed) { w.isRed = false; x.parent.isRed = true; rotateRight(x.parent); w = x.parent.left; } if (!w.right.isRed && !w.left.isRed) { w.isRed = true; x = x.parent; } else { if (!w.left.isRed) { w.right.isRed = false; w.isRed = true; rotateLeft(w); w = x.parent.left; } w.isRed = x.parent.isRed; x.parent.isRed = false; w.left.isRed = false; rotateRight(x.parent); x = root.left; } } } x.isRed = false; } /** * Checks if this {@code IntervalTree} does not have any Intervals. * * @return {@code true} if this {@code IntervalTree} is empty, {@code false} otherwise. */ public boolean isEmpty() { return root.left == sentinel; } /** * Find all Intervals that contain a given value. * * @param k the value to search for * @return a non-null List of intervals that contain the value */ public List<Node<K, V>> stab(K k) { Interval<K> i = new Interval<K>(k, true, k, true); final List<Node<K, V>> nodes = new ArrayList<Node<K, V>>(); findOverlap(root.left, i, new NodeCallback<K, V>() { @Override public void handle(Node<K, V> node) { nodes.add(node); } }); return nodes; } public void stab(K k, NodeCallback<K, V> nodeCallback) { Interval<K> i = new Interval<K>(k, true, k, true); findOverlap(root.left, i, nodeCallback); } private void findOverlap(Node<K, V> n, Interval<K> i, NodeCallback<K, V> nodeCallback) { if (n == sentinel || compare(i.low, n.max) > 0) { return; } if (n.left != sentinel) { findOverlap(n.left, i, nodeCallback); } if (compareIntervals(n.interval, i) == 0) { nodeCallback.handle(n); } if (compareIntervals(i, n.interval) < 0) { return; } if (n.right != sentinel) { findOverlap(n.right, i, nodeCallback); } } public Node<K, V> findNode(Interval<K> i) { checkValidInterval(i); return findNode(root.left, i); } private Node<K, V> findNode(Node<K, V> n, Interval<K> i) { if (n == sentinel || compare(i.low, n.max) > 0) { return null; } if (n.interval.equals(i)) { return n; } if (n.left != sentinel) { Node<K, V> w = findNode(n.left, i); if (w != null) { return w; } } if (compareIntervals(i, n.interval) < 0) { return null; } if (n.right != sentinel) { return findNode(n.right, i); } return null; } public interface NodeCallback<K extends Comparable<K>, V> { void handle(Node<K, V> node); } public void inorderTraversal(NodeCallback<K, V> nodeCallback) { inorderTraversal(root.left, nodeCallback); } private void inorderTraversal(Node<K, V> n, NodeCallback<K, V> nodeCallback) { if (n != sentinel) { inorderTraversal(n.left, nodeCallback); nodeCallback.handle(n); inorderTraversal(n.right, nodeCallback); } } @Override public String toString() { final StringBuilder sb = new StringBuilder(); inorderTraversal(new NodeCallback<K, V>() { @Override public void handle(Node<K, V> n) { if (sb.length() > 0) { sb.append(", "); } sb.append(n.interval); sb.append("->{"); sb.append(n.value); sb.append('}'); } }); return sb.toString(); } }