Why I Finally Learned Union-Find (After Avoiding It for Years)
I used to skip Union-Find problems in coding interviews. The name sounded intimidating, and I could usually brute-force graph connectivity with BFS/DFS. Then I hit a problem during a mock interview: “Find the number of connected components in an undirected graph with 100,000 nodes and edge insertions in real-time.” My DFS solution timed out spectacularly.
That’s when I realized Union-Find isn’t just another data structure—it’s the only practical way to handle dynamic connectivity at scale.
The Core Problem: Who’s Connected to Whom?
Here’s the setup: you have N isolated nodes. You start connecting them with edges, one at a time. At any point, you need to answer: “Are nodes A and B in the same connected component?”
Brute force? Run BFS/DFS every time someone asks. That’s O(V + E) per query. With thousands of queries, you’re toast.
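To make that cost concrete, here's roughly what the brute-force approach looks like — a sketch, with `connected_bfs` as an illustrative name (not from any library):

```python
from collections import deque

def connected_bfs(n, edges, a, b):
    """Naive approach: rebuild adjacency and run BFS for every single query.
    Each call costs O(V + E) -- fine once, painful over thousands of queries."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    seen = {a}
    queue = deque([a])
    while queue:
        node = queue.popleft()
        if node == b:
            return True
        for nxt in adj[node]:
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return False
```

Every query repeats the whole traversal from scratch; nothing is remembered between calls.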
Union-Find solves this in nearly O(1) per operation. Not theoretically O(1)—it’s O(α(N)), where α is the inverse Ackermann function. In practice, α(N) ≤ 4 for any N you’ll ever see. So yeah, basically constant time.
The Mental Model: Forest of Trees
Think of each connected component as a tree. Every node points to its parent. The root of the tree is the “representative” of that component.
When you ask “Are A and B connected?”, you climb up from A to its root, climb up from B to its root, and check if they’re the same. That’s the Find operation.
When you connect A and B, you merge their trees by making one root point to the other. That’s the Union operation.
Simple, right? The devil’s in the optimizations.
My First Attempt (That Worked But Sucked)
Here’s the naive version I coded up first:
```python
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))  # each node is its own parent initially

    def find(self, x):
        # climb up until we find the root
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.parent[root_x] = root_y  # attach x's tree under y's root

    def connected(self, x, y):
        return self.find(x) == self.find(y)
```
This works. It’s correct. But it’s also O(N) worst-case per operation if you build a degenerate tree (imagine chaining 0→1→2→3→…→N-1). Every find() becomes a linear scan.
I tested this on LeetCode problem 547 (Number of Provinces) with N=200. Passed. Then I tried a stress test with N=100,000 nodes arranged in a worst-case chain. 8 seconds for 100,000 queries.
Not good enough.
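If you want to reproduce that blow-up yourself, here's a sketch of the kind of stress test I ran — `NaiveUnionFind` and `chain_stress` are my names for it, and the exact timings will vary by machine:

```python
import time

class NaiveUnionFind:
    """The unoptimized version: no rank, no path compression."""
    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, x):
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.parent[rx] = ry

def chain_stress(n, queries=1000):
    """Build the worst-case chain 0 -> 1 -> ... -> n-1, then time repeated
    find(0) calls, each of which walks the entire chain."""
    uf = NaiveUnionFind(n)
    for i in range(n - 1):
        uf.union(i, i + 1)  # naive union keeps extending one long chain
    start = time.perf_counter()
    for _ in range(queries):
        uf.find(0)
    return time.perf_counter() - start
```

On the chain, every `find(0)` costs O(N), so the total work is O(N × queries) — that's where the seconds go.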
Optimization 1: Union by Rank
The problem is unbalanced trees. Solution: always attach the shorter tree under the taller one. Track each tree’s “rank” (roughly its height).
```python
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n  # initially all trees have rank 0

    def find(self, x):
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return
        # attach smaller-rank tree under larger-rank tree
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1  # only increment when ranks tie
```
This guarantees tree height ≤ log(N). Much better. My stress test dropped to 0.3 seconds.
But we can do even better.
Optimization 2: Path Compression (The Magic Trick)
Here’s the insight that blew my mind: during find(x), you’re climbing from x to the root. Why not rewire every node along that path to point directly to the root? Next time you query any of those nodes, it’s a one-hop lookup.
The recursive version is elegant:
```python
def find(self, x):
    if self.parent[x] != x:
        self.parent[x] = self.find(self.parent[x])  # recursive path compression
    return self.parent[x]
```
That’s it. One line. This flattens the tree almost completely over time.
Warning: The recursive version can hit Python’s recursion limit (default 1000) on degenerate chains. In coding interviews, this is a real risk—large test cases can create deep chains before path compression kicks in. There are two ways to handle this:
Option A: Set the recursion limit explicitly (quick fix, but not always allowed in interviews):
```python
import sys
sys.setrecursionlimit(200001)  # set higher than max N
```
Option B (Recommended): Use iterative path compression—no recursion limit issues at all:
```python
def find(self, x):
    # Step 1: find the root
    root = x
    while self.parent[root] != root:
        root = self.parent[root]
    # Step 2: compress the path
    while self.parent[x] != root:
        next_parent = self.parent[x]
        self.parent[x] = root
        x = next_parent
    return root
```
There’s also path halving, a simpler iterative variant that achieves the same amortized complexity:
```python
def find(self, x):
    while self.parent[x] != x:
        self.parent[x] = self.parent[self.parent[x]]  # skip one level
        x = self.parent[x]
    return x
```
Path halving does fewer pointer updates per call than full path compression, but over many operations both achieve O(α(N)) amortized. Path halving is my go-to for interviews because it’s iterative, concise, and safe.
Combine union by rank + path compression (any variant), and you get O(α(N)) per operation. My stress test now runs in 0.08 seconds for 100,000 operations.
The Final Implementation
Here’s the version I use in interviews now. I use iterative path compression to avoid recursion limit issues:
```python
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.count = n  # number of disjoint sets (useful for counting components)

    def find(self, x):
        # iterative path compression (no recursion limit issues)
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[x] != root:
            next_parent = self.parent[x]
            self.parent[x] = root
            x = next_parent
        return root

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False  # already connected
        # union by rank
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1
        self.count -= 1  # merged two sets
        return True

    def connected(self, x, y):
        return self.find(x) == self.find(y)

    def get_count(self):
        return self.count  # number of connected components
```
```python
# Example: Number of Provinces (LeetCode 547)
def findCircleNum(isConnected):
    n = len(isConnected)
    uf = UnionFind(n)
    for i in range(n):
        for j in range(i + 1, n):
            if isConnected[i][j] == 1:
                uf.union(i, j)
    return uf.get_count()

# Test
graph = [
    [1, 1, 0],
    [1, 1, 0],
    [0, 0, 1],
]
print(findCircleNum(graph))  # Output: 2 (two separate components)
```
Time complexity: O(N² × α(N)) for the nested loop (checking all pairs). Space: O(N).
Union by Rank vs. Union by Size
You’ll see two variants in the wild: union by rank and union by size. Both achieve O(α(N)) amortized complexity, but they differ in a practical way.
Union by rank tracks the upper bound of tree height. After path compression, rank no longer reflects the true height—it’s just a heuristic. You can’t use it to answer “how big is this component?”
Union by size tracks the actual number of nodes in each component. This is useful when interview problems ask follow-up questions like “what’s the size of the largest component?” or “return the component sizes.” Here’s what it looks like:
```python
class UnionFindBySize:
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n  # each node starts as a component of size 1

    def find(self, x):
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[x] != root:
            next_parent = self.parent[x]
            self.parent[x] = root
            x = next_parent
        return root

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        # attach smaller tree under larger tree
        if self.size[root_x] < self.size[root_y]:
            self.parent[root_x] = root_y
            self.size[root_y] += self.size[root_x]
        else:
            self.parent[root_y] = root_x
            self.size[root_x] += self.size[root_y]
        return True

    def get_size(self, x):
        return self.size[self.find(x)]  # size of x's component
```
Use union by size when you need component sizes. Use union by rank when you don’t—it’s marginally simpler.
Edge Cases That Burned Me
- **Self-loops:** Connecting a node to itself. Your `union()` should check `if root_x == root_y` and return early. I forgot this once and decremented `count` incorrectly.
- **Duplicate edges:** Calling `union(a, b)` multiple times. Not a bug, but wasteful. Return `False` if already connected to signal "no-op."
- **0-indexed vs 1-indexed:** Some problems give you 1-indexed nodes. I've debugged "index out of range" errors more times than I'd like to admit. Always check the problem constraints.
- **Empty input:** If the graph has 0 nodes or 0 edges, make sure your implementation handles it gracefully. `UnionFind(0)` should work without errors.
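A quick sanity-check harness covering these cases — the class here is a compact restatement of the final implementation above, inlined only so the snippet runs standalone:

```python
class UnionFind:
    # Compact restatement of the final implementation, so this snippet
    # is self-contained. Same API: union() returns False on a no-op.
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.count = n

    def find(self, x):
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[x] != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False  # self-loop or duplicate edge: count must NOT change
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx  # ensure rx is the higher-rank root
        self.parent[ry] = rx
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        self.count -= 1
        return True

# Self-loop: union(2, 2) is a no-op and count stays put
uf = UnionFind(5)
assert uf.union(2, 2) is False and uf.count == 5

# Duplicate edge: the second union(0, 1) returns False, count unchanged
assert uf.union(0, 1) is True and uf.count == 4
assert uf.union(0, 1) is False and uf.count == 4

# Empty input: UnionFind(0) constructs without errors
assert UnionFind(0).count == 0
```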
Real-World Use Case: Kruskal’s Minimum Spanning Tree
This is where Union-Find really shines. Kruskal’s algorithm finds the minimum spanning tree (MST) of a weighted graph:
- Sort all edges by weight (ascending)
- For each edge (u, v, weight):
  - If u and v are NOT connected, add this edge to the MST and union(u, v)
  - Otherwise, skip it (it would create a cycle)
- Stop when you've added N-1 edges
Union-Find is perfect here because it prevents cycles in O(α(N)) time.
```python
def kruskal_mst(n, edges):
    """
    n: number of nodes (0 to n-1)
    edges: list of (u, v, weight) tuples
    Returns: list of edges in the MST and its total weight
    """
    uf = UnionFind(n)
    edges.sort(key=lambda x: x[2])  # sort by weight
    mst = []
    total_weight = 0
    for u, v, weight in edges:
        if uf.union(u, v):  # True means a merge happened (not already connected)
            mst.append((u, v, weight))
            total_weight += weight
            if len(mst) == n - 1:  # MST complete
                break
    return mst, total_weight

# Example graph
edges = [
    (0, 1, 4),
    (0, 2, 3),
    (1, 2, 1),
    (1, 3, 2),
    (2, 3, 4),
]
mst, weight = kruskal_mst(4, edges)
print(f"MST edges: {mst}")        # [(1, 2, 1), (1, 3, 2), (0, 2, 3)]
print(f"Total weight: {weight}")  # 6
```
Time complexity: O(E log E) for sorting + O(E × α(V)) for Union-Find operations. Space: O(V).
The sorting dominates, so overall it’s O(E log E). Note that O(E log E) = O(E log V) since E ≤ V², meaning log E ≤ 2 log V = O(log V). This equivalence comes up in interviews, so it’s worth knowing.
Without Union-Find’s near-constant cycle detection, you’d need DFS per edge—that’s O(E × (V + E)), which is brutal for dense graphs.
When NOT to Use Union-Find
Union-Find doesn’t track paths, only connectivity. If you need the actual shortest path between nodes, use Dijkstra or BFS. Union-Find just tells you “yes, they’re reachable” or “no, they’re not.”
For static graphs where you just need connected components once (no dynamic edge additions), a simple DFS/BFS is O(V + E) and perfectly fine. Union-Find’s advantage is handling incremental edge additions efficiently.
For small graphs (N < 100), BFS/DFS is simpler and fast enough. Don’t overcomplicate things.
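For that static, one-shot case, the simpler alternative is a single traversal — a sketch, with `count_components_dfs` as an illustrative name:

```python
def count_components_dfs(n, edges):
    """One-shot component count for a static graph: O(V + E) total,
    no Union-Find machinery needed."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    seen = [False] * n
    components = 0
    for start in range(n):
        if seen[start]:
            continue
        components += 1
        stack = [start]  # iterative DFS to sidestep recursion limits
        seen[start] = True
        while stack:
            node = stack.pop()
            for nxt in adj[node]:
                if not seen[nxt]:
                    seen[nxt] = True
                    stack.append(nxt)
    return components
```

One pass, done. It only breaks down when edges keep arriving and you'd have to re-run it per query.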
Practice Problems
Once you’ve got the template down, try these LeetCode problems in order of increasing difficulty:
- 547. Number of Provinces — Direct application of Union-Find for connected components
- 684. Redundant Connection — Find the edge that creates a cycle (Union-Find returns False)
- 200. Number of Islands — Can be solved with Union-Find on a 2D grid (though BFS/DFS is more natural)
- 1971. Find if Path Exists in Graph — Basic connectivity check
- 721. Accounts Merge — Union-Find on strings, a great intermediate challenge
- 1584. Min Cost to Connect All Points — Kruskal’s MST in action
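The "union returns False" trick from problem 684 is worth seeing once. Here's a sketch as a standalone function (LeetCode wraps `findRedundantConnection` in a `Solution` class; the compact `UF` class is inlined only so the snippet runs on its own):

```python
class UF:
    # Minimal Union-Find with path halving, just for this sketch.
    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False  # this edge would close a cycle
        self.parent[rx] = ry
        return True

def findRedundantConnection(edges):
    """LeetCode 684: nodes are 1-indexed, so size the array n + 1.
    The first edge whose union() returns False is the redundant one."""
    uf = UF(len(edges) + 1)
    for u, v in edges:
        if not uf.union(u, v):
            return [u, v]
```

A tree on n nodes has exactly n-1 edges, so with n edges the first cycle-closing one is the answer.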
The One Thing I Wish I’d Known Earlier
Path compression is non-intuitive the first time you see it. It feels like you’re mutating the tree mid-query, which seems… wrong? But it’s not. The tree structure doesn’t matter—only the root matters. Flattening it makes future queries faster without breaking correctness.
Once that clicked, everything else fell into place.
Use Union-Find for dynamic connectivity, cycle detection, and MST. If you’re doing 1000+ connectivity queries on a graph, this is the tool. For one-off “are these two nodes connected?” questions, BFS is fine.