【源码阅读】ConcurrentHashMap 1.7

一、为什么引入 ConcurrentHashMap?

  • HashMap 可能会在扩容时的 transfer 操作发生并发问题导致链表循环引用,导致在进行 get 操作时发生死循环。
  • HashTable 可以提供并发功能,但它是用 synchronize 关键字修饰每一个存在并发需求的方法上,也就是给整个 table 都加了锁,在多线程环境下,可能存在所有线程都等正竞争一把锁的情况,这也就造成了效率低下的问题。-

  • Q1:怎么解决/优化上述问题?
  • A1:采用对 table 的不同数据集分别加锁的方案代替锁住整个 table 的数据。
  • Q2:上述方案可行的原理是什么?
  • A2:我们知道 hash 值不同,在 rehash 时并不会造成线程安全问题,所以分别锁住别个数据段是可行的。

二、源码阅读

(1) 底层数据结构

在 JDK1.7 版本中,ConcurrentHashMap 的数据结构是由一个内含多个 HashEntry 组 的 Segment 数组构成。先总览一下 ConcurrentHashMap 几个重要的成员变量:

	//默认的数组大小16(HashMap里的那个数组)
	static final int DEFAULT_INITIAL_CAPACITY = 16;
	
	//扩容因子0.75
	static final float DEFAULT_LOAD_FACTOR = 0.75f;
	 
	//ConcurrentHashMap中的数组
	final Segment<K,V>[] segments
	
	//默认并发标准16
	static final int DEFAULT_CONCURRENCY_LEVEL = 16;
	
	//Segment是ReentrantLock子类,因此拥有锁的操作
	static final class Segment<K,V> extends ReentrantLock implements Serializable {
		//HashMap的那一套,分别是数组、键值对数量、阈值、负载因子
		transient volatile HashEntry<K,V>[] table;
		transient int count;
		transient int threshold;
		final float loadFactor;
		
		Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
			this.loadFactor = lf;
			this.threshold = threshold;
			this.table = tab;
		}
	}
	static final class HashEntry<K,V> {
        final int hash;
        final K key;
        volatile V value;
        volatile HashEntry<K,V> next;
	}
	//segment中HashEntry[]数组最小长度
	static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
	
	//用于定位在segments数组中的位置,下面介绍
	final int segmentMask;
	final int segmentShift;

看下 Segment 数组,它的意义是:将一个 Segment 分割成多个小的 table 来进行加锁,也就是上面的提到的分段锁,而每一个 Segment 元素存储的是 HashEntry 数组,每一个 HashEntry (数组+链表)

有没有觉得 Segment 很像 HashMap 的组成...

static final class Segment<K,V> extends ReentrantLock implements Serializable {

	private static final long serialVersionUID = 2249069246763182397L;
	
	// 和 HashMap 中的 HashEntry 作用一样,真正存放数据的桶
	transient volatile HashEntry<K,V>[] table;
	
	transient int count;
	
	transient int modCount;
	
	transient int threshold;
	
	final float loadFactor;
}

再看下 HashEntry 的内部构造

static final class HashEntry<K,V> {
	final int hash;
	final K key;
	volatile V value;
	volatile HashEntry<K,V> next;
	//其他省略...
}

(2) 构造方法

public ConcurrentHashMap(int initialCapacity,
                               float loadFactor, int concurrencyLevel) {
     if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
         throw new IllegalArgumentException();
     // 最大并发数为 1<<16=65536
     if (concurrencyLevel > MAX_SEGMENTS)
         concurrencyLevel = MAX_SEGMENTS;
     // 2 的sshif次方等于ssize,例:ssize=16,sshift=4;ssize=32,sshif=5
    int sshift = 0;
    //ssize 为segments数组长度,根据concurrentLevel计算得出
    int ssize = 1;
    while (ssize < concurrencyLevel) {
        ++sshift;
        ssize <<= 1;
    }
    //默认值,concurrencyLevel 为 16,sshift 为 4,那么计算出 segmentShift 为 28,segmentMask 为 15
    //segmentShift和segmentMask这两个变量在定位segment位置时会用到,后面会详细讲
    this.segmentShift = 32 - sshift;
    this.segmentMask = ssize - 1;
    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;
    /* initialCapacity设置整个 segm 的初始容量,根据 initialCapacity 计算 Segment 数组的每个位置可以分配的大小
       如 initialCapacity 为 64,那么每个 Segment 可以分到大小为 4 的 HashEntry 数组*/
    int c = initialCapacity / ssize;
    if (c * ssize < initialCapacity)
        ++c;
    // 默认 MIN_SEGMENT_TABLE_CAPACITY 是 2,为什么是 2 呢?因为对于具体的槽上,插入一个元素不会立刻扩容
    int cap = MIN_SEGMENT_TABLE_CAPACITY;
    while (cap < c)
        cap <<= 1;
    //创建segments数组并初始化第一个Segment,其余的Segment延迟初始化
    Segment<K,V> s0 =
        new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                         (HashEntry<K,V>[])new HashEntry[cap]);
    Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
    UNSAFE.putOrderedObject(ss, SBASE, s0);
    this.segments = ss;
}

总结一下,无外乎几步:

  • 确定最大并发数 concurrencyLevel(最大为 2^{16})
  • 根据 concurrencyLevel 确定 sshift 和 ssize
    • ssize 初始值为 1,通过左移得到一个 >= concurrencyLevel 的最小 2 的幂次方数。
    • sshift 表示 sszie 左移的次数。
  • 根据 sshift 算出 segmentShift = 32 - sshift 根据 ssize 算出 segmentMask = ssize - 1,因为 ssize 的值为 2^n,所以减 1 就变成为二进制位中第 n 位以下全是 1 的二进制数。
  • 确定每个槽的初始容量 initialCapacity,但最大不能超过 MAXIMUM_CAPACITY = 2
  • 根据 ssize 和 initialCapacity 算出每个槽中 HashEntry 数组的长度 cap(这也一定要是一个 2 的幂次方数,具体看代码)
  • 根据 loadFactor (0.75)、cap 创建第一个 Segment 对象 s0,这个是有讲究的,后面在说。

(3) put 方法

public V put(K key, V value) {
    Segment<K,V> s;
    if (value == null)	   //1. ConcurrentHashMap不允许valus为空
        throw new NullPointerException();
    int hash = hash(key);  //2. 根据key计算hash值,key也不能为null,否则hash(key)报空指针
    //3. 根据hash值计算在segments数组中的位置
      // hash 是 32 位,无符号右移 segmentShift(28) 位,用剩下的高 4 位,
      // 和 segmentMask(15) 做一次与操作,也就是说数组下标 j 是 hash 值的高 4 位
    int j = (hash >>> segmentShift) & segmentMask; 
    //4. 取第Segment数组的 j 个位置的元素
    if ((s = (Segment<K,V>)UNSAFE.getObject(segments, (j << SSHIFT) + SBASE)) == null) // nonvolatile + in ensureSegment
        s = ensureSegment(j);
    return s.put(key, hash, value, false);
}

总结一下:

  • 调用 hash 方法,算出非空 key 的 32 位整数变量 hash
  • 用 hash 右移 segmentShift 位的结果对 segmentMask 进行按位与 & 操作得到 Segment 数组的下标 j:
    • 如果该位置的 Segment 还没有初始化,就会通过 CAS 操作对位置的 Segment 代用 ensureSegment 方法进行赋值。
    • 否则,直接调用 Segment 的 put 方法进行赋值。

CAS(compare and swap):即比较并交换,CAS 机制当中使用了 3 个基本操作数:

  • 内存地址 V,预期的旧值 A,要修改的新值 B。
  • 更新一个变量的时候,只有当变量的预期的旧值 A 和内存地址 V 当中的实际值相同时,才会将内存地址 V 对应的值修改为 B。

接着看一下 ensureSegment 方法的逻辑

3.1 ensureSegment 方法

private Segment<K,V> ensureSegment(int k) {
    final Segment<K,V>[] ss = this.segments;
    long u = (k << SSHIFT) + SBASE; // raw offset
    Segment<K,V> seg;
    // 在 seg 的第 j 个位置为空的情况下(见put方法)进到该方法内部
    // 因为可能存在并发情况,故要检查是否被其他线程先初始化了 seg 的 u 位置,是就先返回
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
        // 这里相当于利用seg[0]来初始化了一个"HashMap"
        // 可以直接使用"当前"的 seg[0] 处的数组长度和负载因子来初始化 segment[k],省去一些琐碎的计算 cap、lf、thre...
        // 为什么说“当前”呢,是因为seg[0]可能被其他线程修改过(扩容等等)
        Segment<K,V> proto = ss[0];
        int cap = proto.table.length;
        float lf = proto.loadFactor;
        int threshold = (int)(cap * lf);
        // 初始化 segment[u] 内部的数组
        HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
        // 再次检查该槽是否被其他线程初始化,尽量减少下面的初始化操作
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))== null) { 
            Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
            // 使用保险的自旋:用 CAS 一直检查,直到当前线程成功设值或其他线程成功设值后才退出
            while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
            	// 使用 cas 操作来更新值(内存地址ss[u]、预期值null、新值 seg),存在两种情况:
            	 // 1.因为被其他线程抢先操作了(不等于预期值 null),所以更新失败,然后继续循环直到满足预期值为止
            	 // 2.更新成功,break
                if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                    break;
            }
        }
    }
    return seg;
}

看完我有点感叹:不得不说高并发设计是多么严格啊,只要我还没有到初始化成功,我就要用 CAS 机制检查我要赋值的位置是否被其他线程先初始化了!

总结一下:

  • 根据索引 k 计算出 segment 数组的第 u 个位置
  • 检查 segment 数组的第 u 个位置是否已经有值
    • 没有,则继续利用 ss[0] 的属性来初始化 seg
    • 有则,返回该位置的已存在的值
  • 因为中间经历了一些耗时的初始化动作,所以又检查了一遍 ss[u] 是否有值,如果没有,继续按部就班。
  • 最后利用 while + if 的 cas 自旋操作一直检查到,cas 成功为止。

执行完 ConcurrentHashMap 的 put 方法,接下来就是执行返回的 Segment 对象的的 put 方法了

3.2 Segment.put()

这个方法作用就是在获取的 segment 对象内部的 HashEntry 数组/链表中放入/插入参数 key、value 等元素

final V put(K key, int hash, V value, boolean onlyIfAbsent) {
    // 在往该 segment 对象放入之前,先非阻塞地尝试获取该 segment 对象的锁
    HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value);
    V oldValue;
    try {
        HashEntry<K,V>[] tab = table; 		 //获取segment对象内部的数组
        int index = (tab.length - 1) & hash; //再利用参数hash求出放置k、v的数组下标
        HashEntry<K,V> first = entryAt(tab, index); // 获取数组index位置的链表表头
        for (HashEntry<K,V> e = first;;) {	//遍历链表
            if (e != null) {
                K k;						//if操作检查是否需要覆盖旧值
                if ((k = e.key) == key || (e.hash == hash && key.equals(k))) {	
                    oldValue = e.value;
                    if (onlyIfAbsent == false) {//是否替换取决于调用者的意愿
                        e.value = value;
                        ++modCount;
                    }
                    break;
                }
                e = e.next;	
            }else {// 判断node到底是否为null,这个要看获取锁的过程,不过和这里都没有关系。
                   // 如果不为 null,使用头插法插在链表表头;如果是null,初始化并设置为链表表头。
                if (node != null) node.setNext(first);
                else			  node = new HashEntry<K,V>(hash, key, value, first);
                int c = count + 1;
                	  // 超过了该 segment 的阈值,这个 segment 需要扩容
                if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                    rehash(node);
                else // 没有达到阈值,将 node 放到数组 tab 的 index 位置
                	 // 注:里面是调用UNSAFE的put方法保证将对象存到内存中,而不是仅仅插在线程的工作空间中
                    setEntryAt(tab, index, node);	// 链表下移
                ++modCount;
                count = c;
                oldValue = null;
                break;
            }
        }
    } finally {
        unlock(); //解锁
    }
    return oldValue;
}

对 scanAndLockForPut 方法有一些疑惑,不妨看看这个场景:

现在 Thread1 调用当前 seg[5] 对象的 put 方法存值,假设它可成功拿到锁,根据计算,得出它要存的键值对应该放在 HashEntry[] 的 0 号位置,0 号位置为空,于是新建一个 HashEntry,并通过 setEntryAt() 方法,放在 0 号位置,然而还没等 Thread1 释放锁,系统的时间片切到了 Thread2 ,先画图存档:

在这里插入图片描述

此时正好 Thread2 也来存值,通过下标计算,Thread2 被定位到 seg[5] 中 HashEntry[] 的 0 号位置,接下来 Thread2 也调用当前 seg 对象的 put 方法,一开始先尝试获取锁,没有成功 (Thread1 还未释放,没有插入完毕),就会去执行 scanAndLockForPut() 方法:

private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
    HashEntry<K,V> first = entryForHash(this, hash); //根据hash值获取当前HashEntry数组的对应位置的结点
    HashEntry<K,V> e = first;
    HashEntry<K,V> node = null;
    int retries = -1;	  // negative while locating node,控制分支
   
    while (!tryLock()) {  // 自旋并尝试获取锁
        HashEntry<K,V> f; // to recheck first below
        if (retries < 0) {
            if (e == null) {
                if (node == null) // 进到这里说明数组该位置的链表是空的,没有任何元素
                    node = new HashEntry<K,V>(hash, key, value, null);
                retries = 0;
            }
            else if (key.equals(e.key)) retries = 0;//如果发现key重复证明该位置的值不空
            else						e = e.next; //否则继续遍历
        }
        // 重试次数如果超过 MAX_SCAN_RETRIES(单核重试1次多核64次)
        else if (++retries > MAX_SCAN_RETRIES) {
            lock();			// lock() 是阻塞方法,直到获取锁后返回
            break;
        }else if ((retries & 1) == 0 && (f = entryForHash(this, hash)) != first) {				
        	//每间隔2次重试就检查链表是否发生被其它线程膝盖,如果是,则重新自旋获取锁
            e = first = f; 	// re-traverse if entry changed
            retries = -1;
        }
    }
    return node;
}

虽然 Thread2 没有获取到锁,但它并不是闲着,而是在进入 scanAndLockForPut 方法等待锁的过程中,预先计算好自己要存放的键值对的在 seg[5] 中的相应位置,以便拿到锁时立刻执行赋值操作,达到节省时间的作用。

  • Q1:那为什么在自旋的时候还要去检查链表是否被改变了呢?
  • A1:这是因为当 Thread2 确定了插入的位置在 0 号位置,但 Thread1 已经完成插入了,那么此时根据 new HashEntry<K,V>(hash, key, value, first) 计算出来的值可能会造成 hash 冲突,所以要重新咯...

看到这里,其实整体感觉也不难,总结一下:

  • 非阻塞地尝试获取 seg 对象的可重入锁
    • 这里使用了非阻塞的 tryLock 去获取锁
    • 题外话:另外,如果是阻塞式地获取锁就应该调用可重入锁的 lk.lock() 方法
  • 使用头插法插入到合适的位置,位置可能是数组也可能是链表
  • 插入完毕还有一些是否需要扩容的检查,下面会讲到。

3.3 Segment.rehash()

private void rehash(HashEntry<K,V> node) {
    HashEntry<K,V>[] oldTable = table; //获取原table
    int oldCapacity = oldTable.length;
    int newCapacity = oldCapacity << 1;//长度取原长的2倍
    threshold = (int)(newCapacity * loadFactor);//得到新的阈值
    HashEntry<K,V>[] newTable =		   //创建新table
        (HashEntry<K,V>[]) new HashEntry[newCapacity];
    int sizeMask = newCapacity - 1;	   //新的掩码,如从16扩容到32,那么 sizeMask 为 31,对应二进制 ‘000...00011111’
    // 遍历原数组,将原数组位置 i 处的链表拆分到新数组位置 i 和 i+oldCap 两个位置
    for (int i = 0; i < oldCapacity ; i++) {
        //e是链表表头
        HashEntry<K,V> e = oldTable[i];
        if (e != null) {		 //不空才进行数据迁移
            HashEntry<K,V> next = e.next;
            //重定位,假设原数组长度为16,e 在oldTable[3]处,那么idx只可能是3或者是3 + 16 = 19
            int idx = e.hash & sizeMask;
            if (next == null)  			   //数组的链表只有一个元素
                newTable[idx] = e;
            else { 						   //循环迁移
                HashEntry<K,V> lastRun = e;//e 是链表表头
                int lastIdx = idx;		   //idx是当前链表的头结点 e 的新位置
                //该for循环会找到一个lastRun节点,区间[lastRun, end]中的结点的下标都是一样的,所以在新数组的位置是一样的
                for (HashEntry<K,V> last = next; last != null; last = last.next) {
                    int k = last.hash & sizeMask;
                    if (k != lastIdx) {	   //只有下标不一样才会更新
                        lastIdx = k;
                        lastRun = last;
                    }
                }
                newTable[lastIdx] = lastRun; //直接赋值就不用将lastRun后面的所有结点一个一个地插入
                //下面的for是迁移lastRun前面的节点
                //这些节点可能分配在另一个链表中,也可能分配到上面的那个链表中
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    HashEntry<K,V> n = newTable[k];
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }
    // 将要插入的node结点放到新数组中某个位置链表头部
    int nodeIndex = node.hash & sizeMask;
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    table = newTable;
}

rehash 是在加锁的 put 方法中调用的,所以不会产生线程不安全问题。逻辑也比较清楚,总结一下:

  • 计算新长度、新阈值、新掩码。
  • 遍历老数组,将位置 i 的元素迁移到新数组的两个位置 i / i + oldCap 之一中去,中间涉及到了一些细节,可看上述代码注释。
  • 最后就将,要插入的新 node 插到新数组中某个位置链表头部。

好,至此 ConcurrentHashMap 的 put 方法也就讲解完毕,下面到它的 get 方法...

(4) get 方法

get 方法相对比较简单

public V get(Object key) {
    Segment<K,V> s; // manually integrate access methods to reduce overhead
    HashEntry<K,V>[] tab;
    int h = hash(key); 	// 1.获取hash值
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
    //2.根据hash值找到对应的 segment
    if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null && (tab = s.table) != null) {
        //3.从内存中找到segment对象相应位置的数组中的链表(防止并发问题)
        for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE); e != null; e = e.next) {
            K k;
            if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                return e.value;
        }
    }
    return null;
}

(5) size 方法

public int size() {
    final Segment < K, V > [] segments = this.segments;
    int size;
    boolean overflow; // true if size overflows 32 bits
    long sum; 		  // 总修改次数
    long last = 0 L;  // 上一次的总修改次数
    int retries = -1; // first iteration isn't retry
    try {
        for(;;) { //如果遍历次数达到2次以上,证明在第二次遍历时存在并发修改问题,故在第三次遍历时,对每个seg对象加上锁
            if(retries++ == RETRIES_BEFORE_LOCK) {//RETRIES_BEFORE_LOCK=2
                for(int j = 0; j < segments.length; ++j) 
                	ensureSegment(j).lock();
            }
            sum = 0 L;
            size = 0;
            overflow = false;
            for(int j = 0; j < segments.length; ++j) {
                Segment < K, V > seg = segmentAt(segments, j);
                if(seg != null) {
                    sum += seg.modCount; //得到当前seg对象的修改次数(put、remove)
                    int c = seg.count;	 //得到单个seg的大小
                    if(c < 0 || (size += c) < 0) //标记是否如果产生溢出
                    	overflow = true;
                }
            }
            if(sum == last) 			//直到和上一次修改总次数得到的总和相等,ConcurrentHashMap 没有被修改过
            	break;	
            last = sum;
        }
    } finally {
        if(retries > RETRIES_BEFORE_LOCK) {
            for(int j = 0; j < segments.length; ++j) 
            	segmentAt(segments, j).unlock(); //解锁
        }
    }
    return overflow ? Integer.MAX_VALUE : size;	 //如果发生溢出,返回最大整形
}
  • Q1:为什么方法中的 modCount 只增不减,这样设计的目的是什么?
  • A1:还是从并发的角度来分析,这样设计的目的是避免一个线程 put 元素和另一个线程 remove 元素后抵消了前面线程的 put 动作的 modCount,进而避免在统计 size 的时候产生死循环问题。

总结一下:

  • 前后两次计算出 ConcurrentHashMap 内部的每一个 Segment 对象的 modCount 总和到 sum 和 last 中。
  • 如果两次遍历得到的结果不同,即 sum != last 则证明存在线程并发修改,到第三次遍历就会对每一个 seg 对象都加上锁,然后再次遍历,直到 sum = last 退出循环
  • 其中需要记录 map 中的元素总数是否发生溢出。

恢复: https://www.javadoop.com/post/hashmap#toc_3 https://www.jianshu.com/p/9c713de7bbdb https://juejin.im/post/5a2f2f7851882554b837823a https://www.cnblogs.com/study-everyday/p/6430462.html#autoid-2-0-0 https://www.cnblogs.com/chengxiao/p/6842045.html

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章