// SkipList.java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
  9. public class SkipList<T extends Comparable<T>> {
  10. private final ThreadLocalRandom random;
  11. private final int maxLevel;
  12. private final AtomicInteger curLevel;
  13. private final double p;
  14. private final ListNode<T> head;
  15. public SkipList(int maxLevel, double p) {
  16. this.maxLevel = maxLevel;
  17. this.p = p;
  18. this.random = ThreadLocalRandom.current();
  19. this.curLevel = new AtomicInteger(0);
  20. this.head = new ListNode<T>(null, maxLevel);
  21. }
  22. public void add(T key) {
  23. insert(key);
  24. }
  25. public void insert(T key) {
  26. Stack<ListNode<T>> update = new Stack<>();
  27. ListNode<T> curr = head;
  28. for (int level = curLevel.get(); level >= 0; level--) {
  29. while (curr.canMoveForward(key, level)) {
  30. curr = curr.forward.get(level).get();
  31. }
  32. update.add(curr);
  33. }
  34. ListNode<T> next = curr.forward.get(0).get();
  35. if (next != null && key.compareTo(next.key) == 0) {
  36. // Update value
  37. return;
  38. }
  39. ListNode<T> down = null;
  40. int insertLevel = randomLevel();
  41. for (int level = 0; level <= insertLevel; level++) {
  42. curr = update.isEmpty() ? head : update.pop();
  43. ListNode<T> up = insertAfter(key, curr, level);
  44. if (level > 0) {
  45. up.forward.get(level - 1).set(down);
  46. }
  47. down = up;
  48. }
  49. while (true) {
  50. int expectedLevel = curLevel.get();
  51. if (expectedLevel >= insertLevel || curLevel.compareAndSet(expectedLevel, insertLevel)) {
  52. break;
  53. }
  54. }
  55. }
  56. public boolean contains(T key) {
  57. ListNode<T> curr = head;
  58. for (int level = curLevel.get(); level >= 0; level--) {
  59. while (curr.canMoveForward(key, level)) {
  60. curr = curr.forward.get(level).get();
  61. }
  62. }
  63. curr = curr.forward.get(0).get();
  64. return curr != null && key.compareTo(curr.key) == 0;
  65. }
  66. @Override
  67. public String toString() {
  68. StringBuilder sb = new StringBuilder();
  69. sb.append("current level: ").append(curLevel.get()).append("\n");
  70. for (int i = curLevel.get(); i >= 0; i--) {
  71. ListNode<T> curr = head.forward.get(i).get();
  72. while (curr != null) {
  73. sb.append(curr.key.toString()).append(" ");
  74. curr = curr.forward.get(i).get();
  75. }
  76. sb.append("\n");
  77. }
  78. return sb.toString();
  79. }
  80. private ListNode<T> insertAfter(T key, ListNode<T> before, int level) {
  81. ListNode<T> after = new ListNode<>(key, level);
  82. while (true) {
  83. while (before.canMoveForward(key, level)) {
  84. before = before.forward.get(level).get();
  85. }
  86. ListNode<T> next = before.forward.get(level).get();
  87. after.forward.get(level).set(next);
  88. if (before.forward.get(level).compareAndSet(next, after)) {
  89. break;
  90. }
  91. }
  92. return after;
  93. }
  94. private int randomLevel() {
  95. int level = 0;
  96. while (level < maxLevel) {
  97. if (random.nextDouble() < p) {
  98. level++;
  99. } else {
  100. break;
  101. }
  102. }
  103. return level;
  104. }
  105. public static void main(String[] args) throws InterruptedException {
  106. SkipList<Integer> skipList = new SkipList<>(20, 0.2);
  107. // Set<Integer> skipList = new ConcurrentSkipListSet<>();
  108. int threads = 1000;
  109. int loop = 100;
  110. ExecutorService executor = Executors.newCachedThreadPool();
  111. CountDownLatch cd = new CountDownLatch(threads);
  112. for (int i = 0; i < threads; i++) {
  113. final int initJ = i;
  114. executor.submit(() -> {
  115. for (int j = initJ; j < threads * loop; j += threads) {
  116. skipList.add(j);
  117. }
  118. cd.countDown();
  119. });
  120. }
  121. cd.await();
  122. executor.shutdown();
  123. for (int i = 0; i < threads * loop; i++) {
  124. if (!skipList.contains(i)) {
  125. System.out.println("Concurrent error: " + i);
  126. }
  127. }
  128. }
  129. public static class ListNode<T extends Comparable<T>> {
  130. private List<AtomicReference<ListNode<T>>> forward;
  131. private T key;
  132. public ListNode(T key, int level) {
  133. this.key = key;
  134. this.forward = new ArrayList<>(level + 1);
  135. for (int i = 0; i <= level; i++) {
  136. this.forward.add(new AtomicReference<>());
  137. }
  138. }
  139. public boolean canMoveForward(T key, int level) {
  140. return this.forward.get(level).get() != null &&
  141. key.compareTo(this.forward.get(level).get().key) > 0;
  142. }
  143. @Override
  144. public String toString() {
  145. return "ListNode{" +
  146. "forward=" + forward +
  147. ", key=" + key +
  148. '}';
  149. }
  150. }
  151. }