Union Find

Introduction

Dynamic connectivity is a common kind of problem. Given a set of N objects, we want to support two operations:

  • Union command: connect two objects.
  • Find/connected query: is there a path connecting the two objects?

So there are two types of operations:

  • Find query. Check if two objects are in the same component.
  • Union command. Replace components containing two objects with their union.

Problem Statement

First, read an integer N: the objects are the integers 0, 1, ..., N-1.
Second, read an integer T (the number of pairs to be unioned), followed by T lines, each holding one pair. In the sample below, N and T share the first line.
Finally, read an integer TNUM (the number of find queries), followed by TNUM lines, each holding one query pair.

For each union pair, connect the two objects if they are not already connected. Then run a find() query for each query pair: output YES if the two objects are connected, otherwise output NO.
Here is the sample input:

10 11
4 3
3 8
6 5
9 4
2 1
8 9
5 0
7 2
6 1
1 0
6 7
3
2 3
3 4
4 5

Here is the sample output:

NO
YES
NO

Quick-Find

This is the most naive algorithm we could think of. We keep an integer array id[] of length N. Two elements p and q are connected exactly when they have the same id, that is, id[p] == id[q]. So the find operation checks whether p and q have the same id, and the union operation merges the components containing p and q by changing every entry whose id equals id[p] to id[q].
Here is my implementation in C++ code:

quick-find.cpp

#include <iostream>
#include <vector>

using namespace std;

// id[i] is the component identifier of element i. Two elements are
// connected exactly when they carry the same id.
vector<int> id;

// find if 2 elements are connected or not.
// Precondition: p and q are valid indices (callers check with isValid()).
bool find(int p, int q) {
    return id[p] == id[q];
}

// union operation: merge the component containing p into the component
// containing q by relabelling every entry whose id equals id[p] to id[q].
// Each effective union scans the whole array, i.e. O(N).
void unionOp(int p, int q) {
    // Capture both ids before the scan: id[p] itself is overwritten during
    // the loop, so comparing against id[p] directly would stop relabelling
    // part-way through.
    const int pid = id[p];
    const int qid = id[q];
    if (pid == qid) {
        return;  // already in the same component; nothing to relabel
    }
    for (int& entry : id) {
        if (entry == pid) {
            entry = qid;
        }
    }
}

// check if it is a valid input: true when 0 <= p < N.
bool isValid(int p) {
    const int N = static_cast<int>(id.size());
    return p >= 0 && p < N;
}

// Reads N and T, performs T unions, then answers TNUM find queries with
// YES/NO. Only the answers go to stdout; diagnostics go to stderr so they
// never interleave with the answers.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // First line: N (elements are 0..N-1) and T (number of union pairs).
    // The extraction result is checked so N/T are never used uninitialized.
    int N = 0;
    int T = 0;
    if (!(std::cin >> N >> T) || N < 0 || T < 0) {
        std::cerr << "Invalid input: expected non-negative N and T.\n";
        return 1;
    }

    // initialize the id array: every element starts in its own component
    for (int i = 0; i < N; i++) {
        id.push_back(i);
    }

    for (int i = 0; i < T; i++) {
        int p = 0;
        int q = 0;
        if (!(std::cin >> p >> q)) {
            std::cerr << "Invalid input: expected a union pair.\n";
            return 1;
        }
        if (isValid(p) && isValid(q)) {
            unionOp(p, q);
        } else {
            std::cerr << "Invalid pair (" << p << ", " << q
                      << "): elements must be in [0, N-1]; pair skipped.\n";
        }
    }

    int TNUM = 0;
    if (!(std::cin >> TNUM) || TNUM < 0) {
        std::cerr << "Invalid input: expected the number of find queries.\n";
        return 1;
    }
    for (int i = 0; i < TNUM; i++) {
        int p = 0;
        int q = 0;
        if (!(std::cin >> p >> q)) {
            std::cerr << "Invalid input: expected a query pair.\n";
            return 1;
        }
        // Out-of-range elements are answered NO rather than indexing past
        // the end of id (undefined behavior); one answer line per query.
        const bool connected = isValid(p) && isValid(q) && find(p, q);
        std::cout << (connected ? "YES" : "NO") << '\n';
    }
    return 0;
}

The defect of this algorithm is that it is too slow: every union scans the entire id[] array, so N union commands on N objects take O(N²) time in total.

Quick-Union

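Here is a better algorithm. We still keep an integer array id[] of length N, but now id[i] is the parent of i. Following parent links from i eventually reaches the root of i's tree: the element r with id[r] == r.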

  • Find. Check if p and q have the same root.
  • Union. To merge components containing p and q,
    set the id of p’s root to the id of q’s root.
    Quick-Union

Here is my implementation in C++ code:

quick-union.cpp

#include <iostream>
#include <vector>

using namespace std;

// id[k] holds the parent of element k; an element that is its own parent
// is the root of its tree.
vector<int> id;

// find the root of p: climb parent links until reaching a self-parented node.
int root(int p) {
    int node = p;
    for (int up = id[node]; up != node; up = id[node]) {
        node = up;
    }
    return node;
}

// find if 2 elements are connected or not: they are connected exactly when
// their trees share the same root.
bool find(int p, int q) {
    const int rootOfP = root(p);
    const int rootOfQ = root(q);
    return rootOfP == rootOfQ;
}

// union operation: hang the root of p's tree directly beneath the root of
// q's tree. No balancing is performed, so trees may grow tall.
void unionOp(int p, int q) {
    const int newParent = root(q);
    const int oldRoot = root(p);
    id[oldRoot] = newParent;
}

// check if it is a valid input: an element must lie in [0, N).
bool isValid(int p) {
    const int N = id.size();
    return !(p < 0 || p >= N);
}

// Reads N and T, performs T unions, then answers TNUM find queries with
// YES/NO. Only the answers go to stdout; diagnostics go to stderr so they
// never interleave with the answers.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // First line: N (elements are 0..N-1) and T (number of union pairs).
    // The extraction result is checked so N/T are never used uninitialized.
    int N = 0;
    int T = 0;
    if (!(std::cin >> N >> T) || N < 0 || T < 0) {
        std::cerr << "Invalid input: expected non-negative N and T.\n";
        return 1;
    }

    // initialize the id array: every element starts as its own root
    for (int i = 0; i < N; i++) {
        id.push_back(i);
    }

    for (int i = 0; i < T; i++) {
        int p = 0;
        int q = 0;
        if (!(std::cin >> p >> q)) {
            std::cerr << "Invalid input: expected a union pair.\n";
            return 1;
        }
        if (isValid(p) && isValid(q)) {
            unionOp(p, q);
        } else {
            std::cerr << "Invalid pair (" << p << ", " << q
                      << "): elements must be in [0, N-1]; pair skipped.\n";
        }
    }

    int TNUM = 0;
    if (!(std::cin >> TNUM) || TNUM < 0) {
        std::cerr << "Invalid input: expected the number of find queries.\n";
        return 1;
    }
    for (int i = 0; i < TNUM; i++) {
        int p = 0;
        int q = 0;
        if (!(std::cin >> p >> q)) {
            std::cerr << "Invalid input: expected a query pair.\n";
            return 1;
        }
        // Out-of-range elements are answered NO rather than indexing past
        // the end of id (undefined behavior); one answer line per query.
        const bool connected = isValid(p) && isValid(q) && find(p, q);
        std::cout << (connected ? "YES" : "NO") << '\n';
    }
    return 0;
}

But this algorithm has two defects:

  • Trees can get tall.
  • Find is too expensive: it can take N array accesses in the worst case.

Weighted Quick-Union

Here is an improved quick-union algorithm called Weighted Quick-Union.
Rather than always attaching p's tree beneath q's root in union(), as the quick-union algorithm does, we keep track of the size of each tree and always attach the root of the smaller tree beneath the root of the larger one.
Quick-Union

I have implemented this algorithm in both C++ and Java. See below.

Here is my implementation in Java code:

WeightedQuickUnion.java

import java.util.Scanner;

public class WeightedQuickUnion {
    // A single Scanner shared by every read. Creating a new Scanner per read
    // loses input: each Scanner buffers ahead from System.in, so tokens read
    // into a discarded Scanner's buffer are gone (fatal with piped input).
    private static final Scanner SCANNER = new Scanner(System.in);

    private int[] parent; // parent[i] = parent of i (roots point to themselves)
    private int[] size;   // size[i] = number of sites in subtree rooted at i

    // initialization: N singleton trees, one per element 0..N-1
    public WeightedQuickUnion(int N) {
        parent = new int[N];
        size = new int[N];
        for (int i = 0; i < N; i++) {
            parent[i] = i;
            size[i] = 1;
        }
    }

    // find the root of p by following parent links
    public int root(int p) {
        while (p != parent[p]) {
            p = parent[p];
        }
        return p;
    }

    // find if 2 elements are connected or not (same root)
    public boolean find(int p, int q) {
        return root(p) == root(q);
    }

    // union operation: attach the smaller tree beneath the larger tree's root;
    // on a tie, p's root goes beneath q's root
    public void union(int p, int q) {
        int i = root(p);
        int j = root(q);
        if (i == j) {
            return;
        }
        if (size[i] > size[j]) {
            parent[j] = i;
            size[i] += size[j];
        } else {
            parent[i] = j;
            size[j] += size[i];
        }
    }

    // check if it is a valid input: 0 <= p < N
    public boolean isValid(int p) {
        return p >= 0 && p < parent.length;
    }

    public static void main(String[] args) {
        int N = Integer.parseInt(readDataFromConsole("Please input N: "));
        WeightedQuickUnion wqu = new WeightedQuickUnion(N);

        int T = Integer.parseInt(readDataFromConsole("Please input T: "));
        for (int i = 0; i < T; i++) {
            int p = Integer.parseInt(readDataFromConsole("Please input p: "));
            int q = Integer.parseInt(readDataFromConsole("Please input q: "));
            if (wqu.isValid(p) && wqu.isValid(q)) {
                wqu.union(p, q);
            } else {
                System.err.println("Invalid pair (" + p + ", " + q
                        + "): elements must be in [0, N-1]; pair skipped.");
            }
        }

        int TNUM = Integer.parseInt(
                readDataFromConsole("Please input the number of find operations TNUM: "));
        for (int i = 0; i < TNUM; i++) {
            int p = Integer.parseInt(readDataFromConsole("Please input p: "));
            int q = Integer.parseInt(readDataFromConsole("Please input q: "));
            // Out-of-range elements are answered NO instead of throwing
            // ArrayIndexOutOfBoundsException inside root().
            boolean connected = wqu.isValid(p) && wqu.isValid(q) && wqu.find(p, q);
            System.out.println(connected ? "YES" : "NO");
        }
    }

    // Prints the prompt and returns the next whitespace-delimited token.
    private static String readDataFromConsole(String prompt) {
        System.out.print(prompt);
        return SCANNER.next();
    }
}

Here is my implementation in C++ code:

weighted-quick-union.cpp

#include <iostream>
#include <vector>

using namespace std;

// parent[k] is the parent of element k (a root is its own parent);
// sz[r] counts the elements in the tree rooted at r (kept current for roots).
vector<int> parent;
vector<int> sz;

// find the root of p by climbing parent links to a self-parented node.
int root(int p) {
    int node = p;
    for (int next = parent[node]; next != node; next = parent[node]) {
        node = next;
    }
    return node;
}

// find if 2 elements are connected or not: same root means same component.
bool find(int p, int q) {
    const int rootOfP = root(p);
    const int rootOfQ = root(q);
    return rootOfP == rootOfQ;
}

// union operation (weighted): the root of the smaller tree is attached
// beneath the root of the larger one, and the larger root's size grows.
// On a tie, q's root is attached beneath p's root.
void unionOp(int p, int q) {
    const int rootOfP = root(p);
    const int rootOfQ = root(q);
    if (rootOfP != rootOfQ) {
        const bool pIsSmaller = sz[rootOfP] < sz[rootOfQ];
        const int smaller = pIsSmaller ? rootOfP : rootOfQ;
        const int larger = pIsSmaller ? rootOfQ : rootOfP;
        parent[smaller] = larger;
        sz[larger] += sz[smaller];
    }
}

// check if it is a valid input: an element must lie in [0, N).
bool isValid(int p) {
    const int N = parent.size();
    return !(p < 0 || p >= N);
}

// Reads N and T, performs T weighted unions, then answers TNUM find queries
// with YES/NO. Only the answers go to stdout; diagnostics go to stderr so
// they never interleave with the answers.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // First line: N (elements are 0..N-1) and T (number of union pairs).
    // The extraction result is checked so N/T are never used uninitialized.
    int N = 0;
    int T = 0;
    if (!(std::cin >> N >> T) || N < 0 || T < 0) {
        std::cerr << "Invalid input: expected non-negative N and T.\n";
        return 1;
    }

    // initialize the parent array and sz array: every element starts as the
    // root of its own one-element tree
    for (int i = 0; i < N; i++) {
        parent.push_back(i);
        sz.push_back(1);
    }

    for (int i = 0; i < T; i++) {
        int p = 0;
        int q = 0;
        if (!(std::cin >> p >> q)) {
            std::cerr << "Invalid input: expected a union pair.\n";
            return 1;
        }
        if (isValid(p) && isValid(q)) {
            unionOp(p, q);
        } else {
            std::cerr << "Invalid pair (" << p << ", " << q
                      << "): elements must be in [0, N-1]; pair skipped.\n";
        }
    }

    int TNUM = 0;
    if (!(std::cin >> TNUM) || TNUM < 0) {
        std::cerr << "Invalid input: expected the number of find queries.\n";
        return 1;
    }
    for (int i = 0; i < TNUM; i++) {
        int p = 0;
        int q = 0;
        if (!(std::cin >> p >> q)) {
            std::cerr << "Invalid input: expected a query pair.\n";
            return 1;
        }
        // Out-of-range elements are answered NO rather than indexing past
        // the end of parent (undefined behavior); one answer line per query.
        const bool connected = isValid(p) && isValid(q) && find(p, q);
        std::cout << (connected ? "YES" : "NO") << '\n';
    }
    return 0;
}

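With weighting, no node is ever deeper than log₂N, so find() takes O(log N) time. Union also takes O(log N) time, because it performs two root lookups plus constant work. This is much faster than the O(N) worst-case cost of union in quick-find and of find in plain quick-union.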