Java ConcurrentSkipListMap

线程安全的跳表实现——Java ConcurrentSkipListMap 原理解析 - 知乎 (zhihu.com)

前言

LevelDB、RockDB以及Pebble等高性能键值对存储引擎都采用SkipList来作为自己的内存数据结构。相比使用红黑树,SkipList具有以下优点

  1. 数据结构比红黑树更简单
  2. 具备更好的并发优势。红黑树在插入式可能涉及到整颗树的rebalance,而SkipList可以仅在某个局部进行操作,要锁定的节点更少,从而实现更好的性能。本文基于JAVA JUC中实现的ConcurrentSkipList,来介绍一种线程安全的SkipList的无锁实现

跳表基础

为了更好地了解 ConcurrentSkipListMap 的实现,我们将它的核心逻辑抽离出来,先理解它的跳表部分实现,再看它在并发访问上做了哪些优化。

此处示例代码结构尽可能保持与 java.util.concurrent.ConcurrentSkipListMap 一致,仅对部分代码命名做出修改,以便后续对并发部分进行讲解,因此可能会包含一些不符合常见代码规范的部分。

ConcurrentSkipListMap 整体的数据结构如下图所示:

ConcurrentSkipList中定义了以下三种节点

  1. 数据节点Node:实际保存数据的节点,包含Map的Key/Value,并使用next指针串成一个链表节点
  2. 索引节点Index:用于快速查找,内部包含数据节点的应用。Right指针指向下一个索引节点,down指向下一层的索引节点
  3. 头结点HeadIndex:Index类的子类,相比Index多了一个Level字段,标识当前的SkipList的层数,具体代码定义如下
// 此处简化了一些,要求 Key 实现了 Comparable 接口。
// 实际的 ConcurrentSkipListMap 允许传入自定义的 Key Comparator。
@NotThreadSafe
public class SkipListMap<K extends Comparable<K>, V> {  
    private static class Node<K, V> {  
        final K key;  
        private Object value;  
        private Node<K, V> next;  
        ...
    }  

    private static class Index<K, V> {  
        Node<K, V> node;  
        Index<K, V> down;  
        Index<K, V> right;  
        ...
    }  

    private static class HeadIndex<K, V> extends Index<K, V> {  
        int level;
        ...  
    }
}

初始化

在初始化时,会初始化跳表的头节点,形成如下图所示的数据结构

对应代码如下

private HeadIndex<K, V> head;
private final static Object BASE_HEADER = new Object();  

/**
 * 初始化数据节点:Key=null, Value=BASE_HEADER,初始化 level=1
 * 其他指针全部初始化为 null
 */
public SkipListMap() {  
    head = new HeadIndex<K, V>(new Node<K, V>(null, BASE_HEADER, null),  
        null, null, 1);  
}

SkipList的头结点不直接存储数据,仅作为跳表的头使用。

插入

数据的插入我们以最上面已经插入过 7 条数据的跳表为例。现有的数据 Key 分别为 {1, 3, 5, 7, 9, 11, 13},我们要插入一条 Key 为 10 的记录。

寻找插入位置,插入数据节点

在寻找插入位时,需要先从head节点开始,找到离待插入位置最近的索引,逻辑如下

  1. 向右查找:如果当前节点的right节点不为空且right节点的key小于插入数据的key则向右查找
  2. 向下查找:如果不能向右查找(right位null或者right.node.key大于待插入的key),则向下或向右移动一层
  3. 当无法向下或者向右移动时,返回该索引对应数据节点Node

如上图所示,在访问到第3部的Index节点时,有由于它的右侧节点的值为11>10,无法向右侧移动;同时它的down节点为null,无法向下移动,因此返回它对应的数据节点,即key=7的节点(图中标红部分)

注意:因为Index节点内部包含的是Node节点的引用,而不是使用down指针指向Node节点,因此在第3步对应的节点的down指针为null。

对应代码如下

public Node<K, V> findPredecessor(K key) {  
    if (key == null) {  
        throw new NullPointerException();  
    }  
    Index<K, V> current = head, right = current.right;  
    while (true) {  
        // 向右查找
        if (right != null) {  
            Node<K, V> rightNode = right.node;  
            // right.node.key 小于待插入key,继续向右寻找  
            if (key.compareTo(rightNode.key) > 0) {  
                current = right;  
                right = right.right;  
                continue;            
            }  
        }  
        // 不能继续向右寻找,向下移动一层
        if (current.down == null) {  
            // 不能继续向下寻找,说明到达第一层,可以直接返回  
            return current.node;  
        }  
        current = current.down;  
        right = current.right;  
    }  
}

利用 Index 查找到 Key = 7 的节点数据后,继续在底层数据节点使用 next 指针向右寻找,直到找到实际插入位置的前驱节点,即 Key=9 的数据,并将其插入到链表中:

这部分代码如下

public void put(K key, V value) {  
    Node<K, V> nodeToInsert = null; 
    // 利用索引查找前驱节点 
    for (Node<K,V> prev = findPredecessor(key), next = prev.next;;) {  
        int cmp = Integer.MIN_VALUE;
        // 在数据节点链表中移动到待插入位置
        while (next != null && (cmp = key.compareTo(next.key)) > 0) {  
            prev = next;  
            next = next.next;  
        }  
        // 遇到相等的值,将其覆盖  
        if (cmp == 0) {  
            next.value = value;  
            return;        
        }  
        // 插入节点
        nodeToInsert = new Node<>(key, value);  
        nodeToInsert.next = next;  
        prev.next = nodeToInsert;  
        break;    
    }
    ...
}

构建索引按照概率提升Level

在这里的SkipList中,SkipList提升一层概率是1/2。java这里使用了一个比较聪明的办法:随机生成一恶搞32位的整形数,从这个数最低位开始向最高位遍历,如果这一位位1,则将SkipList提升一层。这样,这个SkipList的Level提升概率为1/2,且最大层数位31.

在这个过程中,随机数只生成了一次,并且计算操作都是位运算,使得这个过程耗时较短。但这种做法也丧失了一定灵活性,比如无法配置SkipList的最大层数,并且1/2在实践中并不一定是是最佳的概率。比如一些键值数据库的内存SkipList就是使用自然对数1/e作为提升概率,而非1/2.

提升概率低可能会使SkipList总体层数较少,影响查找性能;提升概率过高需要移动更多的指针节点,影响SkipList的写入性能(特别是在并发场景下,更多的指针自动意味着更多的冲突)

此处插入节点 level 的计算逻辑如下:

public void put(K key, V value) {  
    ...

    // Java 的 Random 对象为了保证线程安全,使用了 CAS 等手段,可能会影响一定的性能,在实际的 
    // ConcurrentSkipListMap 中使用了线程独立的随机数生成 ThreadLocalRandom 来生成随机数
    int rnd = random.nextInt(Integer.MAX_VALUE);  
    // 0x80000001(最高位和最低位为1),此处保证生成的数必须是正偶数
    if ((rnd & 0x80000001) != 0) {  
        return;  
    }  
    int level = 1;  
    while (((rnd >>>= 1) & 1) != 0) {  
        ++level;  
    }
    ...
}

计算完成新的level值后情况分为俩种

情况一

levelhead.level,新的 level 小于 SkipList 的层数,只需要为插入节点建立 levelIndex 索引即可,如下图所示,假设新计算的 level 值为 1:

情况二

level > head.level,新的 level 大于等于 SkipList 的层数,则需要更新 SkipList 整体的层数,具体操作为:

  • 为待插入节点建立 levelIndex 索引,并将这些索引记录下来;
  • head 节点上层 level - head.levelIndex 索引,并将这些节点的 right 指针指向上一步建立的在同一层的 Index 节点(并且限制了每次只能将 SkipList 提升一层);
  • 更新头节点;
  • 如下图所示,假设新的 level 大于等于 2,则新的 leve 值限制为 3:

需要注意的是,虽然旧的head节点实际对应仍然是HeadIndex对象,但在之后的使用中都只会将其当成普通Index对象进行处理,因此此处将看做普通的Index节点.

这部分对应的代码如下:

public void put(K key, V value) {  
    ...

    int maxLevel = head.level;  
    Index<K, V> indexToInsert = null;  
    // 第一种情况
    if (level <= maxLevel) {  
        for (int i = 0; i < level; ++i) {  
            // 注意第二个参数,遍历 level 次,每次都向上创建一个 Index
            // new 的 Index 的 down 指针指向上一个 Index
            indexToInsert = new Index<>(nodeToInsert, indexToInsert, null);  
        }  
    } else { // 第二种情况
        // 最多向上扩展一层  
        level = maxLevel + 1;  
        Index<K, V>[] indexes = new Index[level+1];  
        for (int i = 1; i <= level; ++i) {  
            indexToInsert = new Index<>(nodeToInsert, indexToInsert, null);  
            // 记录新创建的Index,下标为该Index对应的层数
            indexes[i] = indexToInsert;  
        }  
        HeadIndex<K, V> newHead = head;  
        Node<K, V> oldBase = head.node;  
        int oldLevel = head.level;  
        for (int j = oldLevel+1; j <= level; ++j) {  
            // head.right 指向该层新创建的 Index
            newHead = new HeadIndex<>(oldBase, newHead, indexes[j], j);  
        }  
        // 更新head指针
        head = newHead;  
        // 更新level为旧head对应的层数,用于后面right指针的构建
        indexToInsert = indexes[level = oldLevel];
    }
    ...
}

连接剩余的right指针

再上面索引建立完成之后,可以看到新节点Key=10以及它左侧索引的right指针并没有指向正确的位置.

如上图所示,因此我们还需要将索引指向正确的位置.具体做法就是从head开始,遍历每一层,找到新插入的Index节点的前驱节点,然后更新它的right指针,如下图:

代码如下

public void put(K key, V value) { 
    ...
        // 注意 level > head.level 的情况下
        // level 的值最后被更新为 oldHead.level
        indexToInsert = indexes[level = oldLevel];
    }
    ...
    // 要构建right指针的层数
    // 如果新节点level小于等于原来的level,则就在这个新节点的最高level开始构建;
    // 如果新节点level大于原本level的情况,则从旧head那一层开始插入
    // 即:insertionLevel = min(oldHead.level, newNode.level)
    int insertionLevel = level;  
    // 从 head.level 开始逐渐向下层遍历,找到需要构建right指针的位置
    int currentLevel = head.level;  
    for (Index<K, V> currentIndex = head, rightIndex = currentIndex.right, newIndex = indexToInsert;;) {  
        // 向右移动到本次插入的节点,寻找新索引的前驱节点
        if (rightIndex != null) {  
            if (key.compareTo(rightIndex.node.key) > 0) {  
                currentIndex = rightIndex;  
                rightIndex = rightIndex.right;  
                continue;        
            }  
        }  
        // 到达要重建right指针的层数
        if (currentLevel == insertionLevel) {  
            currentIndex.right = newIndex;  
            newIndex.right = rightIndex;  
            // 这一层重建完成,将要重建的层数减一
            insertionLevel--;
            newIndex = newIndex.down;
            // 重建完成,直接退出  
            if (insertionLevel == 0) {  
                return;  
            }  
        }  
        // 向下移动一层
        currentLevel--;
        currentIndex = currentIndex.down;  
        rightIndex = currentIndex.right;  
    }
}

连接剩余的right指针

在上面索引建立完成之后,可以看到新节点(Key=10)以及它左侧索引的 right 指针并没有指向正确的位置。

如图所示,因此我们还需要指向正确的位置。具体的做法就是从head开始,遍历每一层,找到新插入的Index节点的前驱节点,然后更新它的right指针

代码如下

public void put(K key, V value) { 
    ...
        // 注意 level > head.level 的情况下
        // level 的值最后被更新为 oldHead.level
        indexToInsert = indexes[level = oldLevel];
    }
    ...
    // 要构建right指针的层数
    // 如果新节点level小于等于原来的level,则就在这个新节点的最高level开始构建;
    // 如果新节点level大于原本level的情况,则从旧head那一层开始插入
    // 即:insertionLevel = min(oldHead.level, newNode.level)
    int insertionLevel = level;  
    // 从 head.level 开始逐渐向下层遍历,找到需要构建right指针的位置
    int currentLevel = head.level;  
    for (Index<K, V> currentIndex = head, rightIndex = currentIndex.right, newIndex = indexToInsert;;) {  
        // 向右移动到本次插入的节点,寻找新索引的前驱节点
        if (rightIndex != null) {  
            if (key.compareTo(rightIndex.node.key) > 0) {  
                currentIndex = rightIndex;  
                rightIndex = rightIndex.right;  
                continue;        
            }  
        }  
        // 到达要重建right指针的层数
        if (currentLevel == insertionLevel) {  
            currentIndex.right = newIndex;  
            newIndex.right = rightIndex;  
            // 这一层重建完成,将要重建的层数减一
            insertionLevel--;
            newIndex = newIndex.down;
            // 重建完成,直接退出  
            if (insertionLevel == 0) {  
                return;  
            }  
        }  
        // 向下移动一层
        currentLevel--;
        currentIndex = currentIndex.down;  
        rightIndex = currentIndex.right;  
    }
}

读取

在实现了 put() 方法后,get() 方法的实现就比较简单了。直接调用 findPredecessor() 方法找到离待查找节点最近的索引,之后遍历进行查找即可,代码如下:

public V get(K key) {  
    // 寻找前驱节点
    for (Node<K, V> prev = findPredecessor(key), next = prev.next;;) {  
        if (next == null) {  
            return null;  
        }  
        int compare = key.compareTo(next.key);  
         if (compare == 0) {  
            V result = (V) next.value;  
            return result;  
         }  
         // 对应 key 不存在,直接退出
         if (compare < 0) {  
             break;  
         }  
         next = next.next;  
    }  
    return null;  
}

删除

删除也比较类似。 首先找到待删除的数据节点,移除数据节点:

public V remove(K key) {  
    for (Node<K, V> prev = findPredecessor(key), next = prev.next;;) {  
        if (next == null) {  
            return null;  
        }  
        Object value = next.value;  
        int compare = key.compareTo(next.key);  
        // 对应的 Key 不存在
        if (compare < 0) {  
            return null;  
        }  
        // 移动到要删除的节点位置
        if (compare > 0) {  
            prev = next;  
            next = next.next;  
            continue;        
        }  
        // 将节点标记为删除状态,便于后续删除对应的索引  
        next.value = null;  
        // 删除 node 节点  
        prev.next = next.next;  
        // 删除该数据节点对应的索引
        cleanIndex(key);  
        if (head.right == null) {  
            reduceLevel();  
        }  
        return (V) value;  
    }  
}

删除数据节点后,需要删除数据节点对应的索引。这里 SkipListMap 的做法是从 head 节点开始,逐个遍历索引,如果某个索引对应数据的 value 为 null,则删除这个索引:

// 整体查找逻辑和findPredecessor()相同,只是添加了删除流程
// 在 ConcurrentSkipListMap 中实际上该逻辑就是由findPredecessor()方法实现的
// 这里为了作为区分单独实现了一个方法
private void cleanIndex(K key) {  
    for (Index<K, V> currentIndex = head, rightIndex = currentIndex.right;;) {  
        if (rightIndex != null) {  
            Node<K, V> rightNode = rightIndex.node;  
            if (rightNode.value == null) {  
                currentIndex.right = rightIndex.right;  
                rightIndex = currentIndex.right;  
                continue;            
            }  
            if (key.compareTo(rightNode.key) > 0) {  
                currentIndex = rightIndex;  
                rightIndex = rightIndex.right;  
                continue;            
            }  
        }  
        if (currentIndex.down == null) {  
            return;  
        }  
        currentIndex = currentIndex.down;  
        rightIndex = currentIndex.right;  
    }  
}

在删除索引后,如果最上层已经没有索引了,需要降低跳表的层数,即将 head 向下移动一层:

// ConcurrentSkipListMap 需要在最上面三层都为空的情况下,才会将整体 level 减少一层;
// 这是为了最大程度地减少层数降低可能带来的数据丢失,具体参考后面线程安全部分的讲解
// 这里为了保证一致也采用相同的设计
private void reduceLevel() {  
    HeadIndex<K, V> down;  
    HeadIndex<K, V> down2Level;  
    if (head.level > 3 &&  
            (down = (HeadIndex<K, V>) head.down) != null &&  
            (down2Level = (HeadIndex<K, V>) down.down) != null &&  
            down2Level.right == null &&  
            down.right == null &&  
            head.right == null) {  
        head = down;  
    }  
}

并发问题

对于上面的 SkipList,我们用 10 个线程并发写入 1 ~ 100 这 100 数字,最终得到的 SkipList 如下:

// 索引节点(只展示索引对应的Key)
level 5: 95 
level 4: 51 60 73 95 
level 3: 8 37 38 46 51 59 60 73 95 
level 2: 0 0 5 8 14 16 17 26 31 33 ...... 63 66 68 73 81 95 96 
level 1: 0 0 1 5 8 9 12 14 15 16 17 17 18 20 21 23 25 26 28 31 33 ...... 97 
// 数据节点(展示Key和value)
level 0: {0:0, 1:1, 2:2, 3:3, 4:4, 5:5}

可以看到最终的输出结果完全不符合预期。 发生这个问题的主要原因是 SkipListMap 的 put() 方法并不是原子的。每次 put() 操作都需要对 IndexNode 等多个节点的指针进行移动操作,在移动时,可能被其他线程介入,导致之前的写入被覆盖,我们以下面这段代码为例:

nodeToInsert = new Node<>(key, value);  
nodeToInsert.next = next;  
prev.next = nodeToInsert;

我们暂时不考虑 Java 的指令重排序以及 CPU 缓存与内存的延迟同步问题,假设有下面两个线程按如下顺序执行:

Thread 1Thread 2
nodeToInsert = new Node()
nodeToInsert = new Node()
nodeToInsert.next = next
nodeToInsert.next = next
prev.right=nodeToInsert
prev.right=nodeToInsert

可以看到前驱节点最终指向的是 Thread 2 创建的节点,而 Thread 1 的写入操作被覆盖了。除了以上并发写入的场景,在并发读写的场景下,读线程可能会读取到一些处于中间态的数据(比如调用了 remove() 方法但还没有完全删除的数据)。因此,我们上面实现的 SkipListMap 是线程不安全的。

悲观锁

悲观锁就是一种避免冲突的办法,我们同一时间只允许一个线程进入 put() 等方法,比如使用 Java 的 synchronized 关键字:

CompletableFuture.supplyAsync(() -> {  
    for (int j = 0; j < 100; j++) {  
        synchronized (Main.class) {  
            skipListMap.put(j, j);  
        }  
    }  
    return null;  
});

这种锁总是假设最坏的情况,认为每次自己拿数据的时候别人都会修改,所以共享资源每次只允许一个线程使用。但在使用上面这种悲观锁的情况下,同一时间只能有一个线程可以读取跳表的内容,会极大程度地影响跳表的性能。

当然,我们也可以使用读写锁来改善并发读的性能(比如 Java 的 ReentrantReadWriteLock),这种锁允许多个线程同时进行读操作,但在并发读写的情况下,读写操作还是会被其他的写操作阻塞。

还有一个最大的问题在于,调表这种数据结构在写入时修改存在一定的局部性。不同于红黑树等树状结构在插入时平衡操作可能涉及整个树,跳表在插入或删除时只涉及一小部分的指针变动。比如以下情况,我们向SkipList同事插入俩个节点:

在上图中,我们用不同颜色标识除了插入时可能受到影响的节点,可以看到这俩个节点在插入时对彼此没有影响,在这种情况下,跳表是可以并发写入的。这也是RocksDB、Pebble等数据库选用跳表作为内存数据库的主要原因,跳表具有更好的并发性能。如果使用悲观锁,将无法有效的利用跳表这种特性。

为了保证上面这种并发写入的特性,我们是否可以让每一个节点都持有一个悲观锁,同时只让一个线程对某个节点进行修改?

或许这个方案可行,但是为每个节点都维持一个锁会造成很大的开销。这样每次写操作都涉及多次加锁和解锁操作,而加锁和解锁涉及内核态和用户态之间的转换,会造成较大的性能损失。因此,再Java的ConcurrentSkipListMap以及RocksDB等数据库中,使用乐观锁来解决上面提到的问题。

乐观锁

乐观锁和悲观锁相反,它认为在大多数情况下都不会出现冲突,它在修改数据前会进行冲突的检测,如果发生冲突的话就不修改数据。假设我们要将一个值为 a 的变量修改为 b,乐观锁的做法如下:

  1. 获取并记录待修改的值 a
  2. 计算更新后的值 b
  3. 比较当前内存对应位置的值和 a 是否相等,如果相等则将其赋值为 b;否则不进行修改。

现代 CPU 几乎都提供了指令来实现以上操作,我们称为 CAS(Compare And Swap 或 Compare And Set)。CAS 操作是原子的,因此 CAS 是一种常见的实现乐观锁的方法。

让我们回到跳表的实现上,我们在修改跳表对应的指针时,就可以使用 CAS 算法,来保证自己的修改不会被其他线程篡改,比如在插入数据节点时,我们就可以使用下面的代码:

while(true) {
    // 寻找到待插入位置
    Node<K, V> prevNode = findPredecessor(key), nextNode = prevNode.next;
    // newNode 是新建的节点,不会存在冲突问题
    Node<K, V> newNode = new Node<>();
    newNode.next = nextNode;
    // 在插入时,如果prevNode.next还是nextNode,说明这个指针没被其他线程修改,不存在冲突
    // 可以将next指针设置为newNode
    if (CompareAndSet(prevNode.next, nextNode, newNode) {
        break;
    }
    // 如果修改失败,返回到for循环开头重试修改
}

可以看到,由于乐观锁并没有阻塞线程的执行,只是进行了冲突检测,因此在不存在冲突的情况下可以实现多个节点的鬓发写入,让跳表有更好的性能。

同时,使用这种方式时,写入操作不会阻塞跳表的读取操作,再写入比较频繁的场景下,对系统的读读性能的影响也比较小。

当然CAS实现乐观锁这种方式也存在一定的缺陷,比如:

  1. 如果数据冲突比较严重,现成可能会不断重试,陷入空转,占用大量CPU资源;
  2. ABA问题

总体来看,跳表数据这种数据结构具有较好的局部性,修改时只会涉及整个跳表中的一小部分数据,并发写入时发生冲突的概率较小,在数据量较大时尤其如此;而 ABA 问题不会影响跳表的功能正确性,因此跳表是一种非常适合使用 CAS 来实现无锁的数据结构。

无锁跳表实现

删除

在ConcurrentSkipListMap中,删除操作比较特殊,除了使用CAS修改指针以外,它还是用了一些其他机制来防止出错。假设我们有b、n和f三个数据节点,形成链表b->n->f->g,如果我们仅仅使用CAS操作删除节点时 ,可能会发生以下问题

场景一多线程并发删除

假设线程1要删除节点n,线程2要删除节点f(f的后继节点为g),它们会各自调用CAS方法,如下:

  • 线程 1:CAS(b.next, n, f)
  • 线程 2:CAS(n.next, f, g)

假设线程1先执行,线程2再执行,由于线程1的CAS操作并没有修改节点n的next指针,因此线程2的CAS操作也会成功,最终形成如下所试点链表

最终得到的链表为 b -> f -> g,线程 2 的修改操作失败了。

场景二:多线程并发写入

假设线程 1 要删除节点 n,线程 2 要在节点 n 后面插入一个新节点 g,它们会各自调用以下 CAS 操作:

线程 1:CAS(b.next, n, f)

线程 2:CAS(n.next, f, g),此处节点 g 在初始化时已经将 next 指针指向了 f

我们还是假设线程 1 先执行(实际上它们的执行顺序不影响实际结果),由于线程 1 没有修改节点 nnext 指针,因此线程 2 的 CAS 操作也能执行成功,形成如下链表:

最终形成链表遍历的结果为 b -> f,线程 2 插入的数据节点最终丢失了。

根据上面俩种异常情况,可以知道删除操作除了需要保证前驱节点修改的并发安全以外,也要防止这个过程中待删除节点的next指针被修改,才能保证删除的正确性。

为此ConcurrentSkipListMap采用插入maker节点的方式来防止其他线程修改待删除节点的next指针。ConcurrentSkipListMap的数据节点删除主要分为三步。

  1. 使用 CAS 操作将 n 的 value 设置为 null,这样其他线程在使用 get() 方法获取节点 n 的值时,获取到的结果也为 null,调用方会认为这个值已经被删除;但是此时其他的写入或者删除操作可能会继续修改 nnext 指针;

使用 CAS 操作使节点 nnext 指针指向一个新节点,我们称这个节点为 marker 节点。这个节点的 key 是 null,value 指向自身。存在 marker 节点时,其他任何一个线程想要通过 CAS 操作修改节点 nnext 指针时,都不可能成功(具体示例参考扩展阅读),从而防止在并发写入的情况下出现删除错误:

最后使用 CAS 操作将 bnext 指针指向 f,之后 nmarker 节点会被 GC 自动回收。

由于在删除的时候我们插入了一个 marker 节点,这个节点可能会影响其他节点的遍历、插入等操作。比如此时另一个线程想在节点 n 后面插入一个节点,它会调用 CompareAndSet(n.next, f, newNode),但由于此时 n.next 的值为 marker 节点,就会 CAS 失败导致重试;并且由于 marker 节点的 key 为 null,它还会影响其他线程的遍历过程。

为了避免这些问题,我们可以让其他线程帮忙删除对应的节点。当其他线程在遍历数据节点的过程中发现某个数据节点的 value 为 null,它就会帮忙执行删除流程。比如在 get() 方法中就存在这样一段代码:

private V doGet(Object key) {  
    Comparator<? super K> cmp = comparator;  
    outer: for (;;) {  
        for (Node<K,V> b = findPredecessor(key, cmp), n = b.next;;) {  
            Object v; int c;  
            ...
            // 查找节点的过程中发现节点 n 的 value 为 null
            // 帮助执行删除流程
            if ((v = n.value) == null) {    // n is deleted  
                n.helpDelete(b, f);  
                break;            
            }  
            ...
        }  
    }  
    return null;  
}

helpDelete() 方法的代码如下:

// helpDelete 是数据节点内部的实现方法,这里的this就是上面的节点 n
static final class Node<K,V> {
    ...
    void helpDelete(Node<K,V> b, Node<K,V> f) {  
        // 在多线程情况下节点n的前驱或者后继节点可能又被修改,需要再次确认
        if (f == this.next && this == b.next) {  
            // 如果后继节点的value不等于自身,说明它不是marker节点
            if (f == null || f.value != f) 
                // 帮忙插入marker节点
                casNext(f, new Node<K,V>(f));  
            else       
                // marker节点已经插入,帮忙删除节点n     
                b.casNext(this, f.next);  
        }  
    }
    ...
}

理解上面的流程后,ConcurrentSkipListMap 的 remove() 方法实现就很简单了,整体方法如下:

final V doRemove(Object key, Object value) {  
    if (key == null)  
        throw new NullPointerException();  
    Comparator<? super K> cmp = comparator;  
    // 两层循环
    // "break;" 表示退出内部循环,进行重试
    // "break outer;" 表示退出外层循环,结束方法 
    outer: for (;;) {  
        for (Node<K,V> b = findPredecessor(key, cmp), n = b.next;;) {  
            Object v; int c; 
            // 移动到链表结尾,直接退出 
            if (n == null)  
                break outer;  
            Node<K,V> f = n.next;  
            // 再次检查,如果b.next被修改过,就进行重试
            if (n != b.next)                    
                break;  
            // 遇到被删除节点,帮忙删除,然后进行重试
            if ((v = n.value) == null) {        
                n.helpDelete(b, f);  
                break;            
            }  
            // 检查前驱节点是否被删除,如果前驱节点被删除就进行重试
            if (b.value == null || v == n)
                break;  
            // key < n.key,已经越过要删除的节点,直接退出
            if ((c = cpr(cmp, key, n.key)) < 0)  
                break outer;  
            // key > n.key,还没有到达删除位置,继续向右移动
            if (c > 0) {  
                b = n;  
                n = f;  
                continue;            
            }  
            // remove 方法中value参数恒为 null,此处可忽略 
            // ConcurrentSkipListMap 的 doRemove 方法还支持传值
            // 节点的key、value和参数的key、value都相同才进行删除
            // 主要是用于一些其他的map相关操作
            if (value != null && !value.equals(v))  
                break outer;  
            // 到达待删除节点,执行删除的第一步:将节点n的值设置为null作为标记
            if (!n.casValue(v, null))  
                break;  
            // 执行删除的第二和第三步,插入marker节点,并将b的next指针指向f
            if (!n.appendMarker(f) || !b.casNext(n, f))  
                // 如果失败了,使用 findNode 遍历链表,重新删除数据节点和对应的索引
                findNode(key);                  // retry via findNode  
            else {  
                // 使用 findPredecessor 方法清除索引
                findPredecessor(key, cmp);      // clean index  
                // 最顶层没有数据,尝试减少层数
                if (head.right == null)  
                    tryReduceLevel();  
            }  
            @SuppressWarnings("unchecked") V vv = (V)v;  
            return vv;  
        }  
    }  
    return null;  
}

在上面的删除过程中,还需要介绍一下 findNode()findPredecessor()tryReduceLevel() 方法的实现。

由于 findNode() 方法也调用了 findPredecessor() 方法,我们先介绍 findPredecessor() 方法。它的作用就是遍历索引,查找离 key 前驱索引对应的数据节点,并在遍历的过程中删除被标记删除节点的索引。实际功能和我们基础链表中的 cleanIndex() 方法相同,如下:

private Node<K,V> findPredecessor(Object key, Comparator<? super K> cmp) {  
    if (key == null)  
        throw new NullPointerException(); // don't postpone errors  
    for (;;) {  
        // 和基础链表相同,从上层逐渐向下查找到离key最近的索引对应的数据节点
        for (Index<K,V> q = head, r = q.right, d;;) {  
            if (r != null) {  
                Node<K,V> n = r.node;  
                K k = n.key;  
                // 如果这个索引对应的数据节点被删除
                if (n.value == null) {  
                    // 使用 CAS 操作将索引 r 移除,即q.next = r.next
                    // 移除失败的话就进行重试
                    if (!q.unlink(r))  
                        break;           // restart 
                    // 移除成功,更新right索引位置,继续查找数据节点 
                    r = q.right;         // reread r  
                    continue;  
                }  
                // 后面为查找流程,和基础链表流程相同
                if (cpr(cmp, key, k) > 0) {  
                    q = r;  
                    r = r.right;  
                    continue;                
                }  
            }  
            if ((d = q.down) == null)  
                return q.node;  
            q = d;  
            r = d.right;  
        }  
    }  
}

findNode() 方法的主要效果是遍历跳表,查找某个 key 对应的数据节点。并且在这个过程中将标记被删除的数据节点删除:

private Node<K,V> findNode(Object key) {  
    if (key == null)  
        throw new NullPointerException(); // don't postpone errors  
    Comparator<? super K> cmp = comparator;  
    outer: for (;;) {  
        // 遍历节点
        for (Node<K,V> b = findPredecessor(key, cmp), n = b.next;;) {  
            Object v; int c;  
            if (n == null)  
                break outer;  
            Node<K,V> f = n.next;  
            if (n != b.next)                // inconsistent read  
                break;  
            // 如果某个数据节点被标记为删除,则将其删除
            if ((v = n.value) == null) {    // n is deleted  
                n.helpDelete(b, f);  
                break;            
            }  
            // 前驱节点被删除,则进行重试
            if (b.value == null || v == n)  // b is deleted  
                break;  
            // 它还实现了值的查找功能,提供给 Map 中的其他方法调用,
            // 但 `put()` 方法主要是用它来删除所有已经被删除的节点
            if ((c = cpr(cmp, key, n.key)) == 0)  
                return n;  
            if (c < 0)  
                break outer;  
            b = n;  
            n = f;  
        }  
    }  
    return null;  
}

然后是 tryReduceLevel 方法。减少跳表层数这个操作即使使用了 CAS 操作,也是非线程安全的。如果 head 节点在向下一层的过程中,其他线程向跳表插入了新的节点,并且在最上层构建了索引,在 head 向下移动之后这个索引就会丢失。虽然这样不会影响数据的正确性,但可能会影响跳表整体的性能。因此,为了降低这种问题发生的概率,ConcurrentSkipListMap 在跳表最上面 3 层都为空的情况下,才会尝试将整体层数减少一层。

private void tryReduceLevel() {  
    HeadIndex<K,V> h = head;  
    HeadIndex<K,V> d;  
    HeadIndex<K,V> e;  
    if (h.level > 3 &&  
        (d = (HeadIndex<K,V>)h.down) != null &&  
        (e = (HeadIndex<K,V>)d.down) != null &&  
        e.right == null &&  
        d.right == null &&  
        h.right == null &&  
        // 使用CAS操作将跳表整体减少一层
        casHead(h, d) && // try to set  
        // 减少一层之后发现这一层已经存在数据
        h.right != null) // recheck  
        // 恢复到开始状态
        casHead(d, h);   // try to backout  
}

插入

ConcurrentSkipListMap 的插入也和我们实现的基础链表一样,插入操作可以分为以下三步: 1. 遍历寻找插入位置,插入数据节点; 2. 按照概率提升层数,为新的数据节点构建索引; 3. 连接剩余的 right 指针。

插入数据节点

private V doPut(K key, V value, boolean onlyIfAbsent) {  
    Node<K,V> z;             // added node  
    if (key == null)  
        throw new NullPointerException();  
    Comparator<? super K> cmp = comparator;  
    // 两层循环,分别用来重试和结束操作
    outer: for (;;) {  
        // 使用findPredecessor查找前驱节点,顺便删除需要被删除的索引
        for (Node<K,V> b = findPredecessor(key, cmp), n = b.next;;) {  
            if (n != null) {  
                Object v; int c;  
                Node<K,V> f = n.next;  
                if (n != b.next)               // inconsistent read  
                    break;  
                // 遍历过程中遇到已经被删除的节点,则帮助插入marker节点或删除该数据节点
                if ((v = n.value) == null) {   // n is deleted  
                    n.helpDelete(b, f);  
                    break;                
                }
                // 前驱节点已经被删除,进行重试  
                if (b.value == null || v == n) // b is deleted  
                    break;  
                // 还未移动到待插入位置,继续向右移动
                if ((c = cpr(cmp, key, n.key)) > 0) {  
                    b = n;  
                    n = f;  
                    continue;                
                }  
                // 存在重复值,仅将其值覆盖
                if (c == 0) {  
                    if (onlyIfAbsent || n.casValue(v, value)) {  
                        @SuppressWarnings("unchecked") V vv = (V)v;  
                        return vv;  
                    }  
                    // 如果覆盖值失败,进行重试
                    break; // restart if lost race to replace value  
                }  
                // else c < 0; fall through  
            }
            // 创建新数据节点
            z = new Node<K,V>(key, value, n);  
            // 将其插入到节点b和n之间,如果插入失败就进行重试
            if (!b.casNext(n, z))  
                break;         // restart if lost race to append to b  
            break outer;
        } // 内层for循环结束
    } // 外层for循环结束
    ...
    return null;
}

可以看到插入数据节点的操作和基础链表基本相同,只是涉及修改操作时使用了 CAS 操作,并且在遍历查找的过程中加入了一致性检查和帮助删除的操作。

一致性检查指插入之前需要检查对应的前驱和后继节点的状态是否符合预期,包括:

  • 插入之前前驱节点 bnext 指针必须指向后继节点 n
  • 前驱节点 b 必须没有被删除。

如果数据不一致可能导致新插入的数据丢失

构建索引

构建索引的过程和普通 SkipList 基本类似,只是将修改操作换成了 CAS,并添加了一部分一致性检测,读者可以自行阅读以下代码:

// 使用 ThreadLocalRandom 提升随机数性能
int rnd = ThreadLocalRandom.nextSecondarySeed();
if ((rnd & 0x80000001) == 0) { // test highest and lowest bits  
    int level = 1, max;  
    while (((rnd >>>= 1) & 1) != 0)  
        ++level;  
    Index<K,V> idx = null;  
    HeadIndex<K,V> h = head;  
    if (level <= (max = h.level)) {  
        for (int i = 1; i <= level; ++i)  
            idx = new Index<K,V>(z, idx, null);  
    }  
    else { // try to grow by one level  
        level = max + 1; // hold in array and later pick the one to use  
        @SuppressWarnings("unchecked")Index<K,V>[] idxs =  
            (Index<K,V>[])new Index<?,?>[level+1];  
        for (int i = 1; i <= level; ++i)  
            idxs[i] = idx = new Index<K,V>(z, idx, null);  
        for (;;) {  
            h = head;  
            int oldLevel = h.level;  
            // 一致性检测,其他节点的插入导致SKipList层数已经增长过了
            // 则放弃更新头节点
            if (level <= oldLevel) // lost race to add level  
                break;  
            HeadIndex<K,V> newh = h;  
            Node<K,V> oldbase = h.node;  
            for (int j = oldLevel+1; j <= level; ++j)  
                newh = new HeadIndex<K,V>(oldbase, newh, idxs[j], j);  
            if (casHead(h, newh)) {  
                h = newh;  
                idx = idxs[level = oldLevel];  
                break;            
            }  
        }  
    }
    ...
}

连接剩余的right指针

和基础 SkipList 实现一样,这里需要从 head 节点所在的层数开始,遍历每一层,查找新插入节点的前驱节点。在并发条件下,头节点 head 可能随时发生变化,因此需要用局部变量 h 保存本次插入时对应的头结点,从这个节点开始遍历(在上一步构建索引的时候已将对应值更新到局部变量 h

// 从新插入节点的最高层开始重建 right 指针
splice: for (int insertionLevel = level;;) {  
    // 头结点的层数
    int j = h.level;  
    // t 为本次新插入的最上层索引
    for (Index<K,V> q = h, r = q.right, t = idx;;) {  
        // 遍历到跳表的末尾
        if (q == null || t == null)  
            break splice;  
        if (r != null) {  
            Node<K,V> n = r.node;  
            // compare before deletion check avoids needing recheck  
            int c = cpr(cmp, key, n.key);  
            // 删除检查,如果某个节点已经被标记为删除,则帮助删除
            if (n.value == null) {  
                // 删除失败,重试
                if (!q.unlink(r))  
                    break;  
                // 继续遍历
                r = q.right;  
                continue;            
            }  
            // 遍历向右移动
            if (c > 0) {  
                q = r;  
                r = r.right;  
                continue;           
             }  
        }
    ...
}

逐层向下遍历,直到到达 insertionLevel 时,就可以开始连接 right 指针了:

// 到达要连接right指针的层数
if (j == insertionLevel) {  
    // 连接right指针,失败就进行重试
    if (!q.link(r, t))  
        break; // restart  
    // 如果这次新插入的节点被另一个线程删除,则使用 findNode 方法将其删除,并停止插入
    if (t.node.value == null) {  
        findNode(key);  
        break splice;  
    }  
    // 到达最底层,结束操作
    if (--insertionLevel == 0)  
        break splice;  
}  
// right指针连接成功,连接下一层,将 t 指针向下移动
if (--j >= insertionLevel && j < level)  
    t = t.down;  
q = q.down;  
r = q.right;  
}

可以看到在连接 right 指针的过程中,可能存在本次插入的节点被另一个线程删除的情况,此时我们会调用 findNode() 方法遍历整个跳表,将所有被标记为删除的节点删除。

查找

查找方法的主要逻辑和基础跳表类似,并且逻辑和上面的 findNode() 逻辑基本相同,读者可以自行阅读以下代码:

private V doGet(Object key) {  
    if (key == null)  
        throw new NullPointerException();  
    Comparator<? super K> cmp = comparator;  
    outer: for (;;) {  
        for (Node<K,V> b = findPredecessor(key, cmp), n = b.next;;) {  
            Object v; int c;  
            if (n == null)  
                break outer;  
            Node<K,V> f = n.next;  
            if (n != b.next)                // inconsistent read  
                break;  
            if ((v = n.value) == null) {    // n is deleted  
                n.helpDelete(b, f);  
                break;            }  
            if (b.value == null || v == n)  // b is deleted  
                break;  
            if ((c = cpr(cmp, key, n.key)) == 0) {  
                @SuppressWarnings("unchecked") V vv = (V)v;  
                return vv;  
            }  
            if (c < 0)  
                break outer;  
            b = n;  
            n = f;  
        }  
    }  
    return null;  
}

总结

ConcurrentSkipListMap 的核心机制就是 CAS 操作。当它在试图修改某个节点的指针时,会使用局部变量保存该指针的原始状态,并利用 CAS 操作进行修改,如果修改失败,则使用 for 循环进行重试。对于删除这种较为危险的操作,ConcurrentSkipListMap 会使用标记的方式,先阻止其他线程对待删除节点的进一步操作,再对该节点进行删除,从而降低删除的危险性。通过标记的方式,也可以让其他线程在操作时帮助进行删除,避免阻塞其他操作,进一步提升写性能。

整体来看,CAS 操作相比互斥锁开销较低,并且能推动开发者更细粒度地对临界区进行控制,减少对其他线程的阻塞,但由于它使用 for 循环进行重试操作,有可能占用一定的 CPU 资源。因此 CAS 操作更适用于使用多核 CPU,并且临界区较小(并发操作共享变量的事件较短)的场景。当然,这些原则都不是绝对的,在实际业务场景中,还是需要通过详细的性能和压力测试选择最适合的锁机制。

扩展阅读

使用 marker 节点后的并发写分析

对于删除这一章节我们分析到的异常情况,假如我们使用了 marker 节点,是否还会出现类似问题呢? 并发删除: 对于链表 b -> n -> f -> g,线程 1 要删除节点 n,线程 2 要删除节点 f。假设它们按照如下顺序执行删除操作:

线程 1线程 2
CAS(n.value, valueN, null)
CAS(f.value, valueF, null)
CAS(n.next, f, markerN)
CAS(f.next, g, markerF)
CAS(b.next, n, f)
CAS(n.next, f, g)

当线程 2 在执行 CAS(n.next, f, g) 操作修改节点 n 的 next 指针时,线程 1 已经将它的 next 指针改为了 markerN 节点,因此这个 CAS 操作会失败。失败后线程 2 会重新返回方法入口,重新进行删除流程。由于此时节点 f 已经被标记为删除状态,它可能被线程 2 或其他线程调用 helpDelete() 方法删除。

对于其他异常情况,读者可以自行分析它的正确性。

无锁链表的其他删除节点实现方式

标志位

在上面的分析中,我们知道引入 marker 节点的主要作用是标记待删除节点的 next 指针,防止它被其他线程篡改。因此除了使用 marker 节点以外,也可以使用单独的标志位来标识 next 指针不能被篡改。 使用标志位这种方法可能需要更多的空间,并且在写入时不仅需要操作节点指针,还需要对标志位进行判断,会引入额外的复杂度。

DCAS

除了标志位这种方法以外,我们还有另一个思路。既然 CAS 能原子地实现先比较再修改的功能,我们是否也可以有另一个原子操作,保证在前驱节点 b 和待删除指针 n 都没被修改过的情况下,再执行删除操作。在实际应用中,我们将这种操作称为 DCAS(Double Compare And Swap)。 假设我们使用 DCAS 来对链表 b -> n -> f 中的节点 n 执行删除操作,对应的 DCAS 操作如下:

DCAS(b.next, n.next, n, f, f);

它对应的含义是:当节点 bnext 指针指向节点 n,且节点 nnext 指针指向节点 f 时,将节点 b 和节点 nnext 指针对应的内存空间指向 f。需要注意的是,DCAS 是直接对内存空间进行操作,此时节点 b 和节点 nnext 指向同一个内存空间,再使用 DCAS 进行修改时,它们会同时被修改。就可以防止上面的异常情况。我们还可以分析一下 DCAS 是如何避免并发删除导致的丢失问题的。

但由于当前主流 CPU 并没有广泛支持 DCAS 操作,并且用 DCAS 操作实现无锁数据结构也不一定比单纯使用 CAS 操作要更加简单,在论文 DCAS is not a silver bullet for nonblocking algorithm design 一文中有更详细的阐述。

参考资料

Last modification:January 29, 2024
如果觉得我的文章对你有用,请随意赞赏