package trees.lockbased; import java.util.AbstractMap; import java.util.Comparator; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; import contention.abstractions.CompositionalMap; import contention.abstractions.CompositionalMap.Vars; import contention.abstractions.MaintenanceAlg; /** * The contention-friendly tree implementation of map * as described in: * * T. Crain, V. Gramoli and M. Ryanla. * A Contention-Friendly Binary Search Tree. * Euro-Par 2013. * * @author Tyler Crain * * @param <K> * @param <V> */ public class LockBasedFriendlyTreeMap<K, V> extends AbstractMap<K, V> implements CompositionalMap<K, V>, MaintenanceAlg { static final boolean useFairLocks = false; static final boolean allocateOutside = true; // we encode directions as characters static final char Left = 'L'; static final char Right = 'R'; final V DELETED = (V) new Object(); private class MaintenanceThread extends Thread { LockBasedFriendlyTreeMap<K, V> map; MaintenanceThread(LockBasedFriendlyTreeMap<K, V> map) { this.map = map; } public void run() { map.doMaintenance(); } } private class MaintVariables { long propogations = 0, rotations = 0; } private final MaintVariables vars = new MaintVariables(); private static class Node<K, V> { K key; class BalanceVars { volatile int localh, lefth, righth; } final BalanceVars bal = new BalanceVars(); volatile V value; volatile Node<K, V> left; volatile Node<K, V> right; final ReentrantLock lock; volatile boolean removed; Node(final K key, final V value) { this.key = key; this.value = value; this.removed = false; this.lock = new ReentrantLock(useFairLocks); this.right = null; this.left = null; this.bal.localh = 1; this.bal.righth = 0; this.bal.lefth = 0; } Node(final K key, final int localh, final int lefth, final int righth, final V value, final Node<K, V> left, final Node<K, V> right) { this.key = key; this.bal.localh = localh; this.bal.righth = righth; this.bal.lefth = lefth; this.value = value; this.left = left; this.right = right; this.lock = new ReentrantLock(useFairLocks); this.removed = false; } void setupNode(final K key, final int localh, final int lefth, final int righth, final V value, final Node<K, V> left, final Node<K, V> right) { this.key = key; this.bal.localh = localh; this.bal.righth = righth; this.bal.lefth = lefth; this.value = value; this.left = left; this.right = right; this.removed = false; } Node<K, V> child(char dir) { return dir == Left ? left : right; } Node<K, V> childSibling(char dir) { return dir == Left ? right : left; } void setChild(char dir, Node<K, V> node) { if (dir == Left) { left = node; } else { right = node; } } void updateLocalh() { this.bal.localh = Math.max(this.bal.lefth + 1, this.bal.righth + 1); } } // state private final Node<K, V> root = new Node<K, V>(null, null); private Comparator<? super K> comparator; volatile boolean stop = false; private MaintenanceThread mainThd; // used in the getSize function int size; private long structMods = 0; // Constructors public LockBasedFriendlyTreeMap() { // temporary this.startMaintenance(); } public LockBasedFriendlyTreeMap(final Comparator<? super K> comparator) { // temporary this.startMaintenance(); this.comparator = comparator; } // What is this? private Comparable<? super K> comparable(final Object key) { if (key == null) { throw new NullPointerException(); } if (comparator == null) { return (Comparable<? super K>) key; } return new Comparable<K>() { final Comparator<? super K> _cmp = comparator; @SuppressWarnings("unchecked") public int compareTo(final K rhs) { return _cmp.compare((K) key, rhs); } }; } @Override public boolean containsKey(Object key) { if (get(key) == null) { return false; } return true; } public boolean contains(Object key) { if (get(key) == null) { return false; } return true; } void finishCount(int nodesTraversed) { Vars vars = counts.get(); vars.getCount++; vars.nodesTraversed += nodesTraversed; } @Override public V get(final Object key) { Node<K, V> next, current; next = root; final Comparable<? super K> k = comparable(key); int rightCmp; int nodesTraversed = 0; while (true) { current = next; if (current.key == null) { rightCmp = -100; } else { rightCmp = k.compareTo(current.key); } if (rightCmp == 0) { V value = current.value; if (value == DELETED) { if (TRAVERSAL_COUNT) { finishCount(nodesTraversed); } return null; } if (TRAVERSAL_COUNT) { finishCount(nodesTraversed); } return value; } if (rightCmp <= 0) { next = current.left; } else { next = current.right; } if (TRAVERSAL_COUNT) { nodesTraversed++; } if (next == null) { if (TRAVERSAL_COUNT) { finishCount(nodesTraversed); } return null; } } } @Override public V remove(final Object key) { Node<K, V> next, current; next = root; final Comparable<? super K> k = comparable(key); int rightCmp; V value; while (true) { current = next; if (current.key == null) { rightCmp = -100; } else { rightCmp = k.compareTo(current.key); } if (rightCmp == 0) { if (current.value == DELETED) { return null; } current.lock.lock(); if (!current.removed) { break; } else { current.lock.unlock(); } } if (rightCmp <= 0) { next = current.left; } else { next = current.right; } if (next == null) { if (rightCmp != 0) { return null; } // this only happens if node is removed, so you take the // opposite path // this should never be null System.out.println("Going right"); next = current.right; } } value = current.value; if (value == DELETED) { current.lock.unlock(); return null; } else { current.value = DELETED; current.lock.unlock(); // System.out.println("delete"); return value; } } @Override public V putIfAbsent(K key, V value) { int rightCmp; Node<K, V> next, current; next = root; final Comparable<? super K> k = comparable(key); Node<K, V> n = null; // int traversed = 0; V val; while (true) { current = next; // traversed++; if (current.key == null) { rightCmp = -100; } else { rightCmp = k.compareTo(current.key); } if (rightCmp == 0) { val = current.value; if (val != DELETED) { // System.out.println(traversed); return val; } current.lock.lock(); if (!current.removed) { break; } else { current.lock.unlock(); } } if (rightCmp <= 0) { next = current.left; } else { next = current.right; } if (next == null) { if (n == null && allocateOutside) { n = new Node<K, V>(key, value); } current.lock.lock(); if (!current.removed) { if (rightCmp <= 0) { next = current.left; } else { next = current.right; } if (next == null) { break; } else { current.lock.unlock(); } } else { current.lock.unlock(); // maybe have to check if the other one is still null before // going the opposite way? // YES!! We do this! if (rightCmp <= 0) { next = current.left; } else { next = current.right; } if (next == null) { if (rightCmp > 0) { next = current.left; } else { next = current.right; } } } } } val = current.value; if (rightCmp == 0) { if (val == DELETED) { current.value = value; current.lock.unlock(); // System.out.println("insert"); // System.out.println(traversed); return null; } else { current.lock.unlock(); return val; } } else { if (!allocateOutside) { n = new Node<K, V>(key, value); } if (rightCmp <= 0) { current.left = n; } else { current.right = n; } current.lock.unlock(); // System.out.println(traversed); // System.out.println("insert"); return null; } } @Override public Set<java.util.Map.Entry<K, V>> entrySet() { // TODO Auto-generated method stub return null; } // maintenance boolean removeNode(Node<K, V> parent, char direction) { Node<K, V> n, child; // can get before locks because only maintenance removes nodes if (parent.removed) return false; n = direction == Left ? parent.left : parent.right; if (n == null) return false; // get the locks n.lock.lock(); parent.lock.lock(); if (n.value != DELETED) { n.lock.unlock(); parent.lock.unlock(); return false; } if ((child = n.left) != null) { if (n.right != null) { n.lock.unlock(); parent.lock.unlock(); return false; } } else { child = n.right; } if (direction == Left) { parent.left = child; } else { parent.right = child; } n.left = parent; n.right = parent; n.removed = true; n.lock.unlock(); parent.lock.unlock(); // System.out.println("removed a node"); // need to update balance values here if (direction == Left) { parent.bal.lefth = n.bal.localh - 1; } else { parent.bal.righth = n.bal.localh - 1; } parent.updateLocalh(); return true; } int rightRotate(Node<K, V> parent, char direction, boolean doRotate) { Node<K, V> n, l, lr, r, newNode; if (parent.removed) return 0; n = direction == Left ? parent.left : parent.right; if (n == null) return 0; l = n.left; if (l == null) return 0; if (l.bal.lefth - l.bal.righth < 0 && !doRotate) { // should do a double rotate return 2; } if (allocateOutside) { newNode = new Node<K, V>(null, null); } parent.lock.lock(); n.lock.lock(); l.lock.lock(); lr = l.right; r = n.right; if (allocateOutside) { newNode.setupNode(n.key, Math.max(1 + l.bal.righth, 1 + n.bal.righth), l.bal.righth, n.bal.righth, n.value, lr, r); } else { newNode = new Node<K, V>(n.key, Math.max(1 + l.bal.righth, 1 + n.bal.righth), l.bal.righth, n.bal.righth, n.value, lr, r); } l.right = newNode; n.removed = true; if (direction == Left) { parent.left = l; } else { parent.right = l; } l.lock.unlock(); n.lock.unlock(); parent.lock.unlock(); // need to update balance values l.bal.righth = newNode.bal.localh; l.updateLocalh(); if (direction == Left) { parent.bal.lefth = l.bal.localh; } else { parent.bal.righth = l.bal.localh; } parent.updateLocalh(); if (STRUCT_MODS) { vars.rotations++; counts.get().structMods++; } // System.out.println("right rotate"); return 1; } int leftRotate(Node<K, V> parent, char direction, boolean doRotate) { Node<K, V> n, r, rl, l, newNode; if (parent.removed) return 0; n = direction == Left ? parent.left : parent.right; if (n == null) return 0; r = n.right; if (r == null) return 0; if (r.bal.lefth - r.bal.righth > 0 && !doRotate) { // should do a double rotate return 3; } if (allocateOutside) { newNode = new Node<K, V>(null, null); } parent.lock.lock(); n.lock.lock(); r.lock.lock(); rl = r.left; l = n.left; if (allocateOutside) { newNode.setupNode(n.key, Math.max(1 + r.bal.lefth, 1 + n.bal.lefth), n.bal.lefth, r.bal.lefth, n.value, l, rl); } else { newNode = new Node<K, V>(n.key, Math.max(1 + r.bal.lefth, 1 + n.bal.lefth), n.bal.lefth, r.bal.lefth, n.value, l, rl); } r.left = newNode; // temp (Need to fix this!!!!!!!!!!!!!!!!!!!!) n.right = parent; n.left = parent; n.removed = true; if (direction == Left) { parent.left = r; } else { parent.right = r; } r.lock.unlock(); n.lock.unlock(); parent.lock.unlock(); // need to update balance values r.bal.righth = newNode.bal.localh; r.updateLocalh(); if (direction == Left) { parent.bal.lefth = r.bal.localh; } else { parent.bal.righth = r.bal.localh; } parent.updateLocalh(); if (STRUCT_MODS) { vars.rotations++; counts.get().structMods++; } // System.out.println("left rotate"); return 1; } boolean propagate(Node<K, V> node) { Node<K, V> lchild, rchild; lchild = node.left; rchild = node.right; if (lchild == null) { node.bal.lefth = 0; } else { node.bal.lefth = lchild.bal.localh; } if (rchild == null) { node.bal.righth = 0; } else { node.bal.righth = rchild.bal.localh; } node.updateLocalh(); if (STRUCT_MODS) vars.propogations++; if (Math.abs(node.bal.righth - node.bal.lefth) >= 2) return true; return false; } boolean performRotation(Node<K, V> parent, char direction) { int ret; Node<K, V> node; ret = singleRotation(parent, direction, false, false); if (ret == 2) { // Do a LRR node = direction == Left ? parent.left : parent.right; ret = singleRotation(node, Left, true, false); if (ret > 0) { if (singleRotation(parent, direction, false, true) > 0) { // System.out.println("LRR"); } } } else if (ret == 3) { // Do a RLR node = direction == Left ? parent.left : parent.right; ret = singleRotation(node, Right, false, true); if (ret > 0) { if (singleRotation(parent, direction, true, false) > 0) { // System.out.println("RLR"); } } } if (ret > 0) return true; return false; } int singleRotation(Node<K, V> parent, char direction, boolean leftRotation, boolean rightRotation) { int bal, ret = 0; Node<K, V> node, child; node = direction == Left ? parent.left : parent.right; bal = node.bal.lefth - node.bal.righth; if (bal >= 2 || rightRotation) { // check reiable and rotate child = node.left; if (child != null) { if (node.bal.lefth == child.bal.localh) { ret = rightRotate(parent, direction, rightRotation); } } } else if (bal <= -2 || leftRotation) { // check reliable and rotate child = node.right; if (child != null) { if (node.bal.righth == child.bal.localh) { ret = leftRotate(parent, direction, leftRotation); } } } return ret; } boolean recursivePropagate(Node<K, V> parent, Node<K, V> node, char direction) { Node<K, V> left, right; if (node == null) return true; left = node.left; right = node.right; if (!node.removed && node.value == DELETED && (left == null || right == null) && node != this.root) { if (removeNode(parent, direction)) { return true; } } if (stop) { return true; } if (!node.removed) { if (left != null) { recursivePropagate(node, left, Left); } if (right != null) { recursivePropagate(node, right, Right); } } if (stop) { return true; } // no rotations for now if (!node.removed && node != this.root) { if (propagate(node)) { this.performRotation(parent, direction); } } return true; } public boolean stopMaintenance() { this.stop = true; try { this.mainThd.join(); } catch (InterruptedException e) { // TODO Auto-generated catch block e.printStackTrace(); } return true; } public boolean startMaintenance() { this.stop = false; mainThd = new MaintenanceThread(this); mainThd.start(); return true; } boolean doMaintenance() { while (!stop) { recursivePropagate(this.root, this.root.left, Left); } if (STRUCT_MODS) this.structMods += counts.get().structMods; System.out.println("Propogations: " + vars.propogations); System.out.println("Rotations: " + vars.rotations); return true; } // not thread safe public int getSize() { this.size = 0; recursiveGetSize(root.left); return size; } void recursiveGetSize(Node<K, V> node) { if (node == null) return; if (node.removed) { // System.out.println("Shouldn't find removed nodes in the get size function"); } if (node.value != DELETED) { this.size++; } recursiveGetSize(node.left); recursiveGetSize(node.right); } public int numNodes() { this.size = 0; ConcurrentHashMap<Integer, Node<K, V>> map = new ConcurrentHashMap<Integer, Node<K, V>>(); recursiveNumNodes(root.left, map); return size; } void recursiveNumNodes(Node<K, V> node, ConcurrentHashMap<Integer, Node<K, V>> map) { if (node == null) return; if (node.removed) { // System.out.println("Shouldn't find removed nodes in the get size function"); } Node<K, V> n = map.putIfAbsent((Integer) node.key, node); if (n != null) { System.out.println("Error: " + node.key); } this.size++; recursiveNumNodes(node.left, map); recursiveNumNodes(node.right, map); } public int getBalance() { int lefth = 0, righth = 0; if (root.left == null) return 0; lefth = recursiveDepth(root.left.left); righth = recursiveDepth(root.left.right); return lefth - righth; } int recursiveDepth(Node<K, V> node) { if (node == null) { return 0; } int lefth, righth; lefth = recursiveDepth(node.left); righth = recursiveDepth(node.right); return Math.max(lefth, righth) + 1; } @Override public void clear() { this.stopMaintenance(); this.resetTree(); this.startMaintenance(); return; } private void resetTree() { this.structMods = 0; this.vars.propogations = 0; this.vars.rotations = 0; root.left = null; } @Override public int size() { return this.getSize(); } public long getStructMods() { return structMods; } }