/* * AvlTreeTest.java * * Copyright (C) 2014 Leo Osvald <leo.osvald@gmail.com> * * This file is part of YOUR PROGRAM NAME. * * YOUR PROGRAM NAME is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * YOUR PROGRAM NAME is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with YOUR PROGRAM NAME. If not, see <http://www.gnu.org/licenses/>. */ package org.sglj.util.struct; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Random; import java.util.TreeSet; import org.junit.Test; import org.sglj.util.struct.AvlTree.AvlNode; import org.sglj.util.struct.AvlTree.InsertResult; import org.sglj.util.struct.AvlTree.RemoveResult; public class AvlTreeTest { static class Node extends AvlNode<Integer> { Integer key; int size; int heightMin; int heightMax; public Node(Integer key) { this.key = key; size = 1; heightMin = heightMax = 0; } @Override public Node mergeCodomains(AvlNode<Integer> a, AvlNode<Integer> b) { size = 1 + (a != null ? ((Node)a).size : 0) + (b != null ? ((Node)b).size : 0); heightMin = 1 + Math.min(a != null ? ((Node)a).heightMin : -1, b != null ? ((Node)b).heightMin : -1); heightMax = 1 + Math.max(a != null ? ((Node)a).heightMax : -1, b != null ? ((Node)b).heightMax : -1); return this; } @Override public Integer getKey() { return key; } @Override public String toString() { return "[" + key + "]"; } Node l() { return (Node)left; } Node r() { return (Node)right; } } static class VerifiedAvlTree extends AvlTree<Integer> { TreeSet<Integer> treeSet = new TreeSet<Integer>(); @Override protected Node createNode(Integer key) { return new Node(key); } @Override public Node getRoot() { return (Node)super.getRoot(); } @Override public boolean add(Integer e) { int oldHeight = getHeight(); boolean ret = super.add(e); int heightDiff = getHeight() - oldHeight; if (heightDiff > 0) assertEquals(1, heightDiff); else if (heightDiff < -1) fail("Insertion cannot decrease height by >1"); assertHeightInvariant(); assertEquals(treeSet.add(e), ret); return ret; } @Override public boolean remove(Object o) { int oldHeight = getHeight(); boolean ret = super.remove(o); int heightDiff = getHeight() - oldHeight; if (heightDiff < 0) assertEquals(-1, heightDiff); else if (heightDiff > 1) fail("Removal cannot increase height by >1"); assertHeightInvariant(); assertEquals(treeSet.remove(o), ret); return ret; } @Override public String toString() { String ret = super.toString(); String exp = treeSet.toString(); boolean ok = ret.equals(exp); if (!ok) { System.out.println("failed toString():\n" + repr(this)); assertEquals(exp, ret); } return ret; } int getHeight() { return getRoot() != null ? getRoot().heightMax : -1; } private void assertHeightInvariant() { Node root = getRoot(); if (root == null) return; assertHeightInvariant(root); // http://en.wikipedia.org/wiki/AVL_tree#Comparison_to_other_structures assertTrue(Math.pow(2.0, (root.heightMax + 0.328) / 1.45) - 2 < root.size); } private static void assertHeightInvariant(Node node) { if (node == null) return; int heightDiff = Math.abs( (node.r() != null ? node.r().heightMax : 0) - (node.l() != null ? node.l().heightMax : 0)); if (heightDiff != 0) assertEquals(1, heightDiff); assertHeightInvariant(node.l()); assertHeightInvariant(node.r()); } } static <E> String repr(AvlTree<E> t) { StringBuilder sb = new StringBuilder(); repr(t.getRoot(), sb); return sb.toString(); } private static void repr(AvlNode<?> node, StringBuilder sb) { sb.append('('); if (node != null) { sb.append(node); sb.append(" <"); repr(node.left, sb); sb.append(" >"); repr(node.right, sb); } sb.append(')'); } static <E> void assertRepr(AvlTree<E> t, String expected) { assertEquals(expected, repr(t)); } @Test public void testEmpty() { VerifiedAvlTree vat = new VerifiedAvlTree(); assertTrue(vat.isEmpty()); assertEquals(0, vat.size()); assertEquals("()", repr(vat)); Iterator<Integer> it = vat.iterator(); assertFalse(it.hasNext()); RemoveResult<Integer> rr = new RemoveResult<Integer>(); rr.value = 17; vat.remove(42, rr); assertNull(rr.removed); assertEquals(0, rr.value); } @Test public void testInsertRemoveRoot() { VerifiedAvlTree vat = new VerifiedAvlTree(); InsertResult<Integer> ir = new InsertResult<Integer>(); ir.value = 17; vat.insert(42, ir); assertFalse(vat.isEmpty()); assertEquals(1, vat.size()); assertEquals((Integer)42, ir.inserted.getKey()); Iterator<Integer> it = vat.iterator(); assertTrue(it.hasNext()); assertEquals((Integer)42, it.next()); assertFalse(it.hasNext()); assertRepr(vat, "([42] <() >())"); RemoveResult<Integer> rr = new RemoveResult<Integer>(); rr.value = 17; vat.remove(42, rr); assertNotNull(rr.removed); assertEquals((Integer)42, rr.removed.getKey()); assertEquals(1, rr.value); vat.add(13); assertEquals(1, vat.size()); assertEquals((Integer)13, vat.iterator().next()); vat.remove(13); assertEquals(0, vat.size()); assertFalse(vat.iterator().hasNext()); } @Test public void testInsertNoRebalance() { VerifiedAvlTree vat = new VerifiedAvlTree(); vat.add(50); vat.add(20); assertRepr(vat, "([50] <([20] <() >()) >())"); assertEquals(2, vat.getRoot().size); assertEquals(0, vat.getRoot().heightMin); assertEquals(1, vat.getRoot().heightMax); vat.add(80); assertRepr(vat, "([50] <([20] <() >()) >([80] <() >()))"); assertEquals(1, vat.getRoot().heightMin); assertEquals(1, vat.getRoot().heightMax); vat.add(25); assertRepr(vat, "([50] <([20] <() >([25] <() >())) >([80] <() >()))"); assertEquals(2, vat.getRoot().heightMax); assertEquals(2, vat.getRoot().l().size); } @Test public void testRemoveNoRebalance() { VerifiedAvlTree vat = new VerifiedAvlTree(); // remove a root with only the right child vat.add(40); vat.add(60); vat.remove(40); assertRepr(vat, "([60] <() >())"); assertEquals(1, vat.getRoot().size); // remove a node with only the left child vat.add(80); vat.add(41); vat.add(30); assertEquals(2, vat.getRoot().heightMax); vat.remove(41); assertRepr(vat, "([60] <([30] <() >()) >([80] <() >()))"); assertEquals(3, vat.getRoot().size); assertEquals(1, vat.getRoot().heightMin); assertEquals(1, vat.getRoot().heightMax); // remove a node with 2 children whose successor is right-left grandchild vat.add(20); vat.add(51); vat.add(90); vat.add(70); vat.add(31); assertRepr(vat, "([60]" + " <([30] <([20] <() >()) >([51] <([31] <() >()) >()))" + " >([80] <([70] <() >()) >([90] <() >())))"); assertEquals(4, vat.getRoot().l().size); assertEquals(0, vat.getRoot().l().r().heightMin); vat.remove(30); assertRepr(vat, "([60]" + " <([31] <([20] <() >()) >([51] <() >()))" + " >([80] <([70] <() >()) >([90] <() >())))"); assertEquals(1, vat.getRoot().l().r().size); assertEquals(1, vat.getRoot().l().heightMin); assertEquals(1, vat.getRoot().l().heightMax); assertEquals(1, vat.getRoot().r().heightMin); assertEquals(1, vat.getRoot().r().heightMax); assertEquals(3, vat.getRoot().l().size); assertEquals(7, vat.getRoot().size); // remove a leaf node which is a left child vat.remove(70); assertRepr(vat, "([60]" + " <([31] <([20] <() >()) >([51] <() >()))" + " >([80] <() >([90] <() >())))"); assertEquals(0, vat.getRoot().r().heightMin); assertEquals(1, vat.getRoot().r().heightMax); assertEquals(2, vat.getRoot().r().size); // remove a node with 2 children whose successor is right child assertEquals(3, vat.getRoot().l().size); vat.remove(31); assertRepr(vat, "([60]" + " <([51] <([20] <() >()) >())" + " >([80] <() >([90] <() >())))"); assertEquals(2, vat.getRoot().l().size); assertEquals(0, vat.getRoot().l().heightMin); assertEquals(1, vat.getRoot().l().heightMax); assertEquals(1, vat.getRoot().heightMin); // remove a node with only the right child assertEquals(1, vat.getRoot().r().r().size); vat.remove(80); assertRepr(vat, "([60]" + " <([51] <([20] <() >()) >())" + " >([90] <() >()))"); assertEquals(1, vat.getRoot().r().size); assertEquals(0, vat.getRoot().r().heightMax); // remove a leaf node which is a right child vat.remove(51); assertRepr(vat, "([60] <([20] <() >()) >([90] <() >()))"); assertEquals(1, vat.getRoot().heightMin); // remove the root with 2 children vat.remove(60); assertRepr(vat, "([90] <([20] <() >()) >())"); assertEquals(0, vat.getRoot().heightMin); } @Test public void testRotateLeft() { VerifiedAvlTree vat = new VerifiedAvlTree(); for (int key : new int[]{1, 2, 3}) vat.add(key); assertRepr(vat, "([2] <([1] <() >()) >([3] <() >()))"); vat.add(4); vat.add(5); assertRepr(vat, "([2] <([1] <() >()) >([4] <([3] <() >()) >([5] <() >())))"); vat.add(6); assertRepr(vat, "([4] <([2] <([1] <() >()) >([3] <() >())) >([5] <() >([6] <() >())))"); } @Test public void testRotateRight() { VerifiedAvlTree vat = new VerifiedAvlTree(); for (int key : new int[]{-1, -2, -3}) vat.add(key); assertRepr(vat, "([-2] <([-3] <() >()) >([-1] <() >()))"); vat.add(-4); vat.add(-5); assertRepr(vat, "([-2] <([-4] <([-5] <() >()) >([-3] <() >())) >([-1] <() >()))"); vat.add(-6); assertRepr(vat, "([-4] <([-5] <([-6] <() >()) >()) >([-2] <([-3] <() >()) >([-1] <() >())))"); } @Test public void testRotateLeftRight() { VerifiedAvlTree vat = new VerifiedAvlTree(); for (int key : new int[]{2, 8, 5}) vat.add(key); assertRepr(vat, "([5] <([2] <() >()) >([8] <() >()))"); for (int key : new int[]{1, 4, 6, 9, 3}) vat.add(key); assertRepr(vat, "([5] <([2] <([1] <() >()) >([4] <([3] <() >()) >())) >([8] <([6] <() >()) >([9] <() >())))"); vat.remove(1); assertRepr(vat, "([5] <([3] <([2] <() >()) >([4] <() >())) >([8] <([6] <() >()) >([9] <() >())))"); } @Test public void testRotateRightLeft() { VerifiedAvlTree vat = new VerifiedAvlTree(); for (int key : new int[]{8, 4, 2}) vat.add(key); assertRepr(vat, "([4] <([2] <() >()) >([8] <() >()))"); vat.add(6); vat.add(3); vat.add(1); vat.add(9); vat.add(7); assertRepr(vat, "([4] <([2] <([1] <() >()) >([3] <() >())) >([8] <([6] <() >([7] <() >())) >([9] <() >())))"); vat.remove(9); assertRepr(vat, "([4] <([2] <([1] <() >()) >([3] <() >())) >([7] <([6] <() >()) >([8] <() >())))"); } private static void permute(String prefix, String str, List<int[]> perms) { int n = str.length(); if (n == 0) { int[] a = new int[prefix.length()]; for (int i = 0; i < a.length; ++i) a[i] = prefix.charAt(i) - '0'; perms.add(a); } else { for (int i = 0; i < n; i++) permute(prefix + str.charAt(i), str.substring(0, i) + str.substring(i+1, n), perms); } } static List<int[]> permute(int maxN) { List<int[]> perms = new ArrayList<int[]>(); StringBuilder sb = new StringBuilder(); for (int i = 0; i <= maxN; ++i) { sb.append((char)('0' + i)); permute("", sb.toString(), perms); } return perms; } @Test public void testSearchAllPermutations() { final int n = 6; for (int[] keys : permute(n)) { VerifiedAvlTree vat = new VerifiedAvlTree(); boolean[] contained = new boolean[n + 1]; for (int key : keys) { vat.add(key); contained[key] = true; } for (int key = 0; key < n; ++key) assertEquals(contained[key], vat.contains(key)); } } @Test public void testInsertAllPermutations() { for (int[] keys : permute(7)) { VerifiedAvlTree vat = new VerifiedAvlTree(); for (int key : keys) vat.add(key); } } @Test public void testRemoveAllPermutations() { for (int[] keys : permute(7)) { for (int toRemove : keys) { VerifiedAvlTree vat = new VerifiedAvlTree(); for (int key : keys) vat.add(key); vat.remove(toRemove); } } } @Test public void testStressRandom() { final int sizeAvg = 100; final int coordMax = 400; final Random r = new Random(); VerifiedAvlTree vat = new VerifiedAvlTree(); for (int itr = 0; itr < 10000; ++itr) { boolean add; int sizeDiff = vat.size() - sizeAvg; if (sizeDiff >= 10) add = false; else if (sizeDiff < -10) add = true; else add = r.nextBoolean(); if (add) { int x = r.nextInt(coordMax); //System.out.println("+ " + x); vat.add(x); } else { int x; do { x = r.nextInt(coordMax); } while (!vat.remove(x)); //System.out.println("- " + x); } vat.toString(); } } }