unionfind.go 1022 B

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. package main
  2. import "fmt"
  // disjointSet is an array-backed union-find structure over the elements
  // 0..len(par)-1.
  //
  // par[i] is the parent of element i; an element that is its own parent is
  // the representative (root) of its set. rank[i] only grows when two roots of
  // equal rank are merged, and it is only consulted for roots.
  //
  // The methods use value receivers. That works because every copy of the
  // struct shares the same backing arrays, but it means the slices must never
  // be reallocated (e.g. by append) inside a method.
  type disjointSet struct {
  	par, rank []int
  }
  7. func newDisjointSet(n int) disjointSet {
  8. ds := disjointSet{make([]int, n), make([]int, n)}
  9. for i := range ds.par {
  10. ds.par[i] = i
  11. }
  12. return ds
  13. }
  14. func (ds disjointSet) find(i int) int {
  15. if ds.par[i] == i {
  16. return i
  17. }
  18. pa := ds.find(ds.par[i])
  19. ds.par[i] = pa
  20. return pa
  21. }
  22. func (ds disjointSet) isConnected(i, j int) bool {
  23. return ds.find(i) == ds.find(j)
  24. }
  25. func (ds disjointSet) union(i, j int) bool {
  26. i, j = ds.find(i), ds.find(j)
  27. if i == j {
  28. return false
  29. }
  30. if ds.rank[i] < ds.rank[j] {
  31. ds.par[i] = j
  32. } else {
  33. if ds.rank[i] == ds.rank[j] {
  34. ds.rank[i]++
  35. }
  36. ds.par[j] = i
  37. }
  38. return true
  39. }
  40. func main() {
  41. var n int
  42. fmt.Print("Number of connections: ")
  43. fmt.Scan(&n)
  44. ds := newDisjointSet(n)
  45. for i := 1; i <= n; i++ {
  46. var x, y int
  47. fmt.Printf("No.%d connection: ", i)
  48. fmt.Scan(&x, &y)
  49. ds.union(x, y)
  50. }
  51. fmt.Print("Check connection: ")
  52. var x, y int
  53. fmt.Scan(&x, &y)
  54. fmt.Println(ds.isConnected(x, y))
  55. }