package ch.akuhn.graph2; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.TreeMap; import org.junit.Test; public class MergeFind { int[] parent; public MergeFind(int size) { parent = new int[size]; for (int n = 0; n < parent.length; n++) parent[n] = n; } public int find(int n) { if (parent[n] == n) return n; // Use path compression for speed return parent[n] = find(parent[n]); } public void merge(int a, int b) { if (parent[a] != a) throw new IllegalArgumentException(); if (parent[b] != b) throw new IllegalArgumentException(); parent[a] = b; } public int setCount() { int count = 0; for (int n = 0; n < parent.length; n++) if (parent[n] == n) count++; return count; } public int size() { return parent.length; } public String toString() { Map<Integer, List<Integer>> map = new TreeMap(); for (int n = 0; n < parent.length; n++) { if (!map.containsKey(find(n))) map.put(find(n), new ArrayList()); map.get(find(n)).add(n); } StringBuilder s = new StringBuilder(); for (int each: map.keySet()) { s.append(map.get(each)); } return s.toString(); } public static class Examples { @Test public void shouldCreateDisjointSets() { MergeFind mf = new MergeFind(7); assertEquals(7, mf.size()); assertEquals(7, mf.setCount()); } @Test public void shouldMergeSets() { MergeFind mf = new MergeFind(7); mf.merge(mf.find(4), mf.find(6)); mf.merge(mf.find(5), mf.find(0)); mf.merge(mf.find(3), mf.find(6)); mf.merge(mf.find(1), mf.find(2)); String s = mf.toString(); assertTrue(s.contains("[0, 5]")); assertTrue(s.contains("[1, 2]")); assertTrue(s.contains("[3, 4, 6]")); assertEquals(3, mf.setCount()); } } }