segment_tree.go 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. package main
// SegmentTree is a sum segment tree stored in a flat slice of length 2*n,
// where n is the number of leaves. Leaf i is stored at index n+i, every
// internal node i in [1, n) holds st[2*i] + st[2*i+1], and index 0 is not
// part of the tree.
type SegmentTree []int
  4. // NewSegmentTree ...
  5. func NewSegmentTree(n int) SegmentTree {
  6. var st SegmentTree = make([]int, 2*n)
  7. for i := n; i < 2*n; i++ {
  8. st[i] = 1
  9. }
  10. for i := n - 1; 0 < i; i-- {
  11. st[i] = st[2*i] + st[2*i+1]
  12. }
  13. return st
  14. }
  15. // Update ...
  16. func (st SegmentTree) Update(i, val int) {
  17. i += len(st) / 2
  18. st[i] = val
  19. for i > 1 {
  20. i /= 2
  21. st[i] = st[2*i] + st[2*i+1]
  22. }
  23. }
  24. // Search ...
  25. func (st SegmentTree) Search(beg, val int) (end int) {
  26. val %= st[1] // Deal with the loop
  27. if val == 0 { // If val is zero, set val to st[1]
  28. val = st[1]
  29. }
  30. n := len(st) / 2
  31. slotCnt := st.Count(beg, n)
  32. if slotCnt < val { // If slot cnt is not enough, start from the begining
  33. val -= slotCnt
  34. beg = 0
  35. }
  36. beg, end = beg+n, beg+n
  37. for st[end] < val {
  38. if end%2 == 1 { // If is right child, move to left child
  39. val -= st[end]
  40. end++
  41. }
  42. end /= 2
  43. } // Find the interval that contains end from bottom to top
  44. for end < n { // While end is not leaf
  45. lchild, rchild := end*2, end*2+1
  46. if st[lchild] < val {
  47. val -= st[lchild]
  48. end = rchild // Find in right tree
  49. } else { // st[lchild] == val
  50. end = lchild // Find in left tree
  51. }
  52. }
  53. return end - n
  54. }
  55. // Count ...
  56. func (st SegmentTree) Count(left, right int) (sum int) {
  57. n := len(st) / 2
  58. left, right = left+n, right+n
  59. for left < right {
  60. if left%2 == 1 {
  61. sum += st[left]
  62. left++
  63. }
  64. if right%2 == 1 { // right is not included, so --
  65. right--
  66. sum += st[right]
  67. }
  68. left, right = left/2, right/2
  69. }
  70. return
  71. }