|
@@ -0,0 +1,62 @@
|
|
|
+package main
|
|
|
+
|
|
|
+import "fmt"
|
|
|
+
|
|
|
+type disjointSet struct {
|
|
|
+ par []int
|
|
|
+ rank []int
|
|
|
+}
|
|
|
+
|
|
|
+func newDisjointSet(n int) disjointSet {
|
|
|
+ ds := disjointSet{make([]int, n), make([]int, n)}
|
|
|
+ for i := range ds.par {
|
|
|
+ ds.par[i] = i
|
|
|
+ }
|
|
|
+ return ds
|
|
|
+}
|
|
|
+
|
|
|
+func (ds disjointSet) find(i int) int {
|
|
|
+ if ds.par[i] == i {
|
|
|
+ return i
|
|
|
+ }
|
|
|
+ pa := ds.find(ds.par[i])
|
|
|
+ ds.par[i] = pa
|
|
|
+ return pa
|
|
|
+}
|
|
|
+
|
|
|
+func (ds disjointSet) isConnected(i, j int) bool {
|
|
|
+ return ds.find(i) == ds.find(j)
|
|
|
+}
|
|
|
+
|
|
|
+func (ds disjointSet) union(i, j int) bool {
|
|
|
+ i, j = ds.find(i), ds.find(j)
|
|
|
+ if i == j {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ if ds.rank[i] < ds.rank[j] {
|
|
|
+ ds.par[i] = j
|
|
|
+ } else {
|
|
|
+ if ds.rank[i] == ds.rank[j] {
|
|
|
+ ds.rank[i]++
|
|
|
+ }
|
|
|
+ ds.par[j] = i
|
|
|
+ }
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
+func main() {
|
|
|
+ var n int
|
|
|
+ fmt.Print("Number of connections: ")
|
|
|
+ fmt.Scan(&n)
|
|
|
+ ds := newDisjointSet(n)
|
|
|
+ for i := 1; i <= n; i++ {
|
|
|
+ var x, y int
|
|
|
+ fmt.Printf("No.%d connection: ", i)
|
|
|
+ fmt.Scan(&x, &y)
|
|
|
+ ds.union(x, y)
|
|
|
+ }
|
|
|
+ fmt.Print("Check connection: ")
|
|
|
+ var x, y int
|
|
|
+ fmt.Scan(&x, &y)
|
|
|
+ fmt.Println(ds.isConnected(x, y))
|
|
|
+}
|