dengxinyi 6 years ago
parent
commit
c8430174a9
1 changed files with 62 additions and 0 deletions
  1. 62 0
      oj/algorithms/unionfind/unionfind.go

+ 62 - 0
oj/algorithms/unionfind/unionfind.go

@@ -0,0 +1,62 @@
+package main
+
+import "fmt"
+
+type disjointSet struct {
+	par  []int
+	rank []int
+}
+
+func newDisjointSet(n int) disjointSet {
+	ds := disjointSet{make([]int, n), make([]int, n)}
+	for i := range ds.par {
+		ds.par[i] = i
+	}
+	return ds
+}
+
+func (ds disjointSet) find(i int) int {
+	if ds.par[i] == i {
+		return i
+	}
+	pa := ds.find(ds.par[i])
+	ds.par[i] = pa
+	return pa
+}
+
+func (ds disjointSet) isConnected(i, j int) bool {
+	return ds.find(i) == ds.find(j)
+}
+
+func (ds disjointSet) union(i, j int) bool {
+	i, j = ds.find(i), ds.find(j)
+	if i == j {
+		return false
+	}
+	if ds.rank[i] < ds.rank[j] {
+		ds.par[i] = j
+	} else {
+		if ds.rank[i] == ds.rank[j] {
+			ds.rank[i]++
+		}
+		ds.par[j] = i
+	}
+	return true
+}
+
+func main() {
+	var n int
+	fmt.Print("Number of connections: ")
+	fmt.Scan(&n)
+	ds := newDisjointSet(n)
+	for i := 1; i <= n; i++ {
+		var x, y int
+		fmt.Printf("No.%d connection: ", i)
+		fmt.Scan(&x, &y)
+		ds.union(x, y)
+	}
+	fmt.Print("Check connection: ")
+	var x, y int
+	fmt.Scan(&x, &y)
+	fmt.Println(ds.isConnected(x, y))
+}