/*
* IntervalTree.java
*
* Copyright (C) 2014 Leo Osvald <leo.osvald@gmail.com>
*
* This file is part of SGLJ.
*
* SGLJ is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* SGLJ is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this library. If not, see <http://www.gnu.org/licenses/>.
*/
package org.sglj.util.struct;
import java.util.Collection;
import java.util.Comparator;
import java.util.NavigableSet;
import java.util.TreeSet;
/**
 * An interval tree layered on {@code AvlTree}.
 * <p>
 * Every node is keyed by a point interval (see
 * {@code IntervalTraits.pointInterval}) built from the <tt>from</tt> endpoint
 * of the interval that created the node. Each node keeps the intervals
 * assigned to it in two sorted sets: {@code asc} (traits' ascending
 * comparator) and {@code desc} (traits' descending comparator). The tree
 * itself is ordered by the traits' overlap comparator, which the constructor
 * installs.
 * <p>
 * NOTE(review): node placement is implemented in {@code AvlTree}, which is not
 * visible here. This class appears to rely on each interval being stored at
 * the first node on its search path whose point it overlaps. Confirm against
 * {@code AvlTree}.
 *
 * @param <E> the interval type
 * @param <K> the endpoint type
 */
public class IntervalTree<E, K extends Comparable<? super K>>
extends AvlTree<E> {
/** Endpoint accessors and comparators used by this tree and its nodes. */
protected final IntervalTraits<E, K> traits;
/**
 * Creates a tree that uses <tt>traits</tt> to read interval endpoints and
 * orders its nodes with {@code traits.getOverlapComparator()}.
 *
 * @param traits endpoint accessors and comparators for the interval type
 */
public IntervalTree(final IntervalTraits<E, K> traits) {
this.traits = traits;
setComparator(traits.getOverlapComparator());
}
/**
 * Adds to <tt>result</tt> every interval in this tree that overlaps
 * <tt>point</tt>. <tt>result</tt> is only appended to, never cleared.
 * <p>
 * Walks a single path down from the root. At each node the query's point
 * interval is compared with the node's key:
 * <ul>
 * <li>If the result is negative, the walk continues left. The prefix of the
 * node's {@code asc} set up to the query's point interval is added, i.e. the
 * intervals starting at or before the query.</li>
 * <li>If the result is positive, the walk continues right. The prefix of the
 * node's {@code desc} set up to the query's point interval is added, i.e. the
 * intervals ending at or after the query.</li>
 * <li>If the result is zero, all of the node's intervals are added and the
 * walk stops.</li>
 * </ul>
 * Both prefix readings depend on the {@code pointInterval} tie-ordering
 * contract.
 * <p>
 * NOTE(review): assumes {@code compareKey} returns a negative value when the
 * query lies left of the node's point. Confirm against {@code AvlTree}.
 *
 * @param point the query point
 * @param result collection receiving the overlapping intervals
 */
@SuppressWarnings("unchecked")
public void findOverlapping(K point, Collection<E> result) {
final E pointInterval = traits.pointInterval(point);
// presumably the overlap comparator installed by the constructor
final Comparator<Object> pointComparator = getComparator();
// 'root' is the current node of the downward walk, not just the tree root
Node<E, K> root = (Node<E, K>)getRoot();
while (root != null) {
int cmp = root.compareKey(pointInterval, pointComparator);
NavigableSet<E> overlappingSet;
if (cmp < 0) {
overlappingSet = root.asc;
root = (Node<E, K>)root.left;
} else if (cmp > 0) {
overlappingSet = root.desc;
root = (Node<E, K>)root.right;
} else { // query overlaps the node's point: add all of its intervals, stop
if (root.asc.size() == 1)
result.add(root.asc.first());
else
result.addAll(root.asc);
break;
}
// last interval of the prefix that can overlap the query, or null if none
E last = overlappingSet.floor(pointInterval);
if (last != null) {
if (overlappingSet.first() == last) { // prefix is one interval: no view needed
result.add(last);
} else {
overlappingSet = overlappingSet.headSet(pointInterval, true);
result.addAll(overlappingSet);
}
}
}
}
/**
 * Creates a node keyed by the point interval at <tt>interval</tt>'s
 * <tt>from</tt> endpoint. The new node's interval sets start empty.
 * <p>
 * NOTE(review): presumably {@code AvlTree} then stores <tt>interval</tt> via
 * {@code incrementMultiplicity}. Confirm.
 *
 * @param interval the interval whose start becomes the node's point
 * @return the new node
 */
@Override
protected AvlNode<E> createNode(E interval) {
return new Node<E, K>(interval, traits);
}
/**
 * Returns whether <tt>o</tt> is stored in this tree.
 * <p>
 * Locates a node with {@code Node.search} under the tree's comparator. That
 * method is defined outside this file and presumably returns the first node
 * whose point <tt>o</tt> overlaps. Only that node's {@code asc} set is then
 * checked.
 * <p>
 * NOTE(review): <tt>o</tt> is cast unchecked to <tt>E</tt>. An object of
 * another type will likely fail inside the comparator or traits (e.g. with a
 * ClassCastException) rather than yield <tt>false</tt>.
 *
 * @param o the interval to look for
 * @return true if the located node holds <tt>o</tt>
 */
@Override
public boolean contains(Object o) {
@SuppressWarnings("unchecked")
Node<E, K> node = (Node<E, K>)Node.search(
getRoot(), (E)o, getComparator());
if (node == null)
return false;
return node.asc.contains(o);
}
/**
 * Describes how to read interval endpoints and supplies the comparators that
 * the tree and its nodes use.
 * <p>
 * Subclasses implement {@code from}, {@code to} and {@code pointInterval}.
 * The derived comparators are:
 * <ul>
 * <li>overlap: negative if the first interval ends before the second starts,
 * positive if it starts after the second ends, zero otherwise. Endpoints are
 * inclusive, so touching intervals compare as zero. This relation is not
 * transitive (e.g. [0,1]~[1,2] and [1,2]~[2,3], yet [0,1] &lt; [2,3]). It is
 * only suitable for locating a node, not as a total order.</li>
 * <li>ascending: by <tt>from</tt> ascending, then <tt>to</tt> descending,
 * then by the id comparator reversed.</li>
 * <li>descending: by <tt>to</tt> descending, then <tt>from</tt> ascending,
 * then by the id comparator reversed.</li>
 * </ul>
 *
 * @param <I> the interval type
 * @param <K> the endpoint type
 */
protected static abstract class IntervalTraits<I, K extends Comparable<? super K>> {
/** Orders disjoint intervals by position; overlapping intervals compare 0. */
protected Comparator<I> overlapComparator;
/** from ascending, to descending, then id reversed. */
protected Comparator<I> ascComparator;
/** to descending, from ascending, then id reversed. */
protected Comparator<I> descComparator;
/**
 * Creates traits that break endpoint ties with the intervals' natural
 * ordering. Intervals must implement {@link Comparable}; otherwise a
 * ClassCastException is thrown when such a tie is compared.
 */
public IntervalTraits() {
this(new Comparator<I>() {
@SuppressWarnings("unchecked")
@Override
public int compare(I o1, I o2) {
return ((Comparable<I>)o1).compareTo(o2);
}
});
}
/**
 * Creates traits whose ascending and descending comparators break ties
 * between intervals with identical endpoints using <tt>idComparator</tt>,
 * applied in reverse argument order.
 *
 * @param idComparator tie breaker for intervals with equal endpoints. It
 *        should distinguish distinct intervals, because a {@code TreeSet}
 *        treats comparator-equal elements as duplicates.
 */
public IntervalTraits(final Comparator<I> idComparator) {
overlapComparator = new Comparator<I>() {
@Override
public int compare(I o1, I o2) {
// o1 ends strictly before o2 starts
int cmp = to(o1).compareTo(from(o2));
if (cmp < 0)
return cmp;
// o1 starts strictly after o2 ends
cmp = from(o1).compareTo(to(o2));
if (cmp > 0)
return cmp;
return 0;
}
};
ascComparator = new Comparator<I>() {
@Override
public int compare(I o1, I o2) {
int cmp = from(o1).compareTo(from(o2));
if (cmp != 0)
return cmp;
// longer interval first among equal starts
cmp = to(o2).compareTo(to(o1));
if (cmp != 0)
return cmp;
// reversed so that a point interval (id-smallest, see pointInterval)
// sorts after equal-endpoint intervals
return idComparator.compare(o2, o1);
}
};
descComparator = new Comparator<I>() {
@Override
public int compare(I o1, I o2) {
int cmp = to(o2).compareTo(to(o1));
if (cmp != 0)
return cmp;
cmp = from(o1).compareTo(from(o2));
if (cmp != 0)
return cmp;
// reversed for the same reason as in ascComparator
return idComparator.compare(o2, o1);
}
};
}
/**
 * Returns the lower endpoint of <tt>interval</tt>. This is presumably
 * expected to be no greater than {@code to(interval)}.
 */
public abstract K from(I interval);
/** Returns the upper endpoint of <tt>interval</tt>. */
public abstract K to(I interval);
/**
 * Creates the point interval <tt>[endpoint, endpoint]</tt> used as a search
 * key. Among intervals <tt>[endpoint, endpoint]</tt>, the returned interval
 * should be smaller than every other one under the id comparator (by default
 * {@link Comparable#compareTo(Object)}).
 * <p>
 * The ascending and descending comparators reverse the id order, so the
 * point interval then sorts after those intervals. As a result, inclusive
 * {@code floor}/{@code headSet} queries bounded by it include them.
 *
 * @param endpoint the point
 * @return the interval <tt>[endpoint, endpoint]</tt>
 */
public abstract I pointInterval(K endpoint);
/** Returns the comparator that orders disjoint intervals by position. */
public Comparator<I> getOverlapComparator() {
return overlapComparator;
}
/** Returns the comparator: from ascending, to descending, id reversed. */
public Comparator<I> getAscendingComparator() {
return ascComparator;
}
/** Returns the comparator: to descending, from ascending, id reversed. */
public Comparator<I> getDescendingComparator() {
return descComparator;
}
}
/**
 * A tree node keyed by a point interval. The intervals stored at the node are
 * kept twice, in {@code asc} and in {@code desc}. This lets a query, or a move
 * between nodes, take a prefix of whichever ordering matches its direction.
 */
protected static class Node<E, K extends Comparable<? super K>>
extends AvlNode<E> {
/** The node's key: the point interval at the creating interval's start. */
protected E point;
/** Intervals stored here, in the traits' ascending order. */
protected final TreeSet<E> asc;
/** The same intervals, in the traits' descending order. */
protected final TreeSet<E> desc;
/**
 * Creates a node keyed by the point interval at <tt>interval</tt>'s
 * <tt>from</tt> endpoint. Does not add <tt>interval</tt> to the node's sets.
 *
 * @param interval the interval whose start becomes the node's point
 * @param traits supplies the point interval and the set comparators
 */
public Node(E interval, IntervalTraits<E, K> traits) {
this.point = traits.pointInterval(traits.from(interval));
this.asc = new TreeSet<E>(traits.getAscendingComparator());
this.desc = new TreeSet<E>(traits.getDescendingComparator());
}
/**
 * Moves into this node the intervals held by a direct child that reach this
 * node's point. From the left child, it takes the prefix of the child's
 * {@code desc} set up to this node's point: the intervals whose <tt>to</tt>
 * endpoint reaches it. From the right child, it takes the prefix of the
 * child's {@code asc} set up to this node's point: the intervals whose
 * <tt>from</tt> endpoint reaches back to it. Moved intervals are removed from
 * both of the child's sets.
 * <p>
 * The parameters <tt>a</tt> and <tt>b</tt> are ignored. The current
 * {@code left}/{@code right} children are used instead.
 * <p>
 * NOTE(review): presumably invoked by {@code AvlTree} after a structural
 * change such as a rotation. A child may be left holding no intervals;
 * confirm that {@code AvlTree} tolerates such nodes.
 *
 * @return this node
 */
@Override
public AvlNode<E> mergeCodomains(AvlNode<E> a, AvlNode<E> b) {
if (left != null) {
@SuppressWarnings("unchecked")
Node<E, K> l = (Node<E, K>)left;
moveAll(l.desc, point, l.asc);
}
if (right != null) {
@SuppressWarnings("unchecked")
Node<E, K> r = (Node<E, K>)right;
moveAll(r.asc, point, r.desc);
}
return this;
}
/**
 * Moves the prefix of <tt>moveSrc</tt> that orders at or before
 * <tt>point</tt> into this node's {@code asc} and {@code desc} sets. The
 * moved intervals are removed from both <tt>moveSrc</tt> and <tt>src</tt>,
 * the child's two sets, which hold the same intervals in different orders.
 *
 * @param moveSrc the child's set in which the intervals to move form a prefix
 * @param point the inclusive bound of the prefix (callers pass this node's
 *        point); shadows the field of the same name
 * @param src the child's other set, kept in sync with <tt>moveSrc</tt>
 */
private void moveAll(TreeSet<E> moveSrc, E point, TreeSet<E> src) {
E lastToMove = moveSrc.floor(point);
if (lastToMove == null)
return;
// fast path: the prefix is only the first interval, so skip the headSet view
if (lastToMove == moveSrc.first()) {
moveSrc.remove(lastToMove);
src.remove(lastToMove);
asc.add(lastToMove);
desc.add(lastToMove);
return;
}
// toMove is a live view of moveSrc: clear() removes the moved intervals
// from moveSrc, so it must run after they are copied into asc/desc
NavigableSet<E> toMove = moveSrc.headSet(point, true);
src.removeAll(toMove);
asc.addAll(toMove);
desc.addAll(toMove);
toMove.clear();
}
/**
 * Stores <tt>key</tt> (an interval) at this node by adding it to both sets.
 * Despite the name, no count is kept: re-adding an interval that compares
 * equal to a stored one leaves the sets unchanged.
 */
@Override
public void incrementMultiplicity(E key) {
asc.add(key);
desc.add(key);
}
/**
 * Removes <tt>key</tt> from both sets.
 *
 * @return true if this node holds no intervals afterwards, whether or not
 *         <tt>key</tt> was present
 */
@Override
public boolean decrementMultiplicity(E key) {
asc.remove(key);
desc.remove(key);
return desc.isEmpty();
}
/** Returns the node's point interval, the key the tree compares against. */
@Override
public E getKey() {
return point;
}
/** Returns the point followed by the intervals in ascending order. */
@Override
public String toString() {
return "[" + point + ": " + asc + "]";
}
}
}