package com.freetymekiyan.algorithms.level.medium; import com.freetymekiyan.algorithms.utils.Utils; import com.freetymekiyan.algorithms.utils.Utils.TreeNode; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import java.util.ArrayDeque; import java.util.Deque; /** * Given a binary search tree, write a function kthSmallest to find the kth smallest element in it. * <p> * Note: * You may assume k is always valid, 1 ≤ k ≤ BST's total elements. * <p> * Follow up: * What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How * would you optimize the kthSmallest routine? * <p> * Hint: * <p> * Try to utilize the property of a BST. * What if you could modify the BST node's structure? * The optimal runtime complexity is O(height of BST). * <p> * Tags: Binary Search Tree * Similar Problems: (M) Binary Tree Inorder Traversal * <p> * Answers: * If the BST is modified often, and I can modify the BST node's structure, I can add the left subtree node count to the * each node to know its ranking. Then use binary search to find the target node. */ public class KthSmallestElementInABst { private KthSmallestElementInABst k; private int res; private int count; /** * Recursive solution with in-order traversal helper. */ public int kthSmallest(TreeNode root, int k) { count = k; traverse(root); return res; } private void traverse(TreeNode node) { if (node.left != null) { traverse(node.left); } count--; if (count == 0) { res = node.val; return; } if (node.right != null) { traverse(node.right); } } /** * Iterasive solution with stack. */ public int kthSmallestB(TreeNode root, int k) { Deque<TreeNode> stack = new ArrayDeque<>(); int count = k; while (!stack.isEmpty() || root != null) { if (root != null) { stack.push(root); root = root.left; } else { root = stack.pop(); count--; if (count == 0) { return root.val; } root = root.right; } } return -1; } /** * Binary search for left subtree node count. * * For BST, the # of nodes of left subtree is actually the node's ranking. */ public int kthSmallestC(TreeNode root, int k) { int count = countNodes(root.left); if (k <= count) { return kthSmallest(root.left, k); } else if (k > count + 1) { return kthSmallest(root.right, k - 1 - count); // 1 is counted as current node } return root.val; } /** * Count how many nodes in this subtree rooted from n. * If we can modify the data structure, we can save the count with each node. */ private int countNodes(TreeNode n) { if (n == null) { return 0; } return 1 + countNodes(n.left) + countNodes(n.right); } @Before public void setUp() { k = new KthSmallestElementInABst(); } @Test public void testExamples() { TreeNode root = Utils.buildBinaryTree(new Integer[]{1}); Assert.assertEquals(1, k.kthSmallest(root, 1)); Assert.assertEquals(1, k.kthSmallestB(root, 1)); Assert.assertEquals(1, k.kthSmallestC(root, 1)); root = Utils.buildBinaryTree(new Integer[]{2, 1}); Assert.assertEquals(1, k.kthSmallest(root, 1)); Assert.assertEquals(1, k.kthSmallestB(root, 1)); Assert.assertEquals(1, k.kthSmallestC(root, 1)); root = Utils.buildBinaryTree(new Integer[]{1, -1, 2, null, null, null, 3}); Assert.assertEquals(3, k.kthSmallest(root, 4)); Assert.assertEquals(3, k.kthSmallestB(root, 4)); Assert.assertEquals(3, k.kthSmallestC(root, 4)); } @After public void tearDown() { k = null; } }