一、为什么引入 ConcurrentHashMap?
- HashMap 可能会在扩容时的 transfer 操作发生并发问题导致链表循环引用,导致在进行 get 操作时发生死循环。
- HashTable 可以提供并发功能,但它是用 synchronize 关键字修饰每一个存在并发需求的方法上,也就是给整个 table 都加了锁,在多线程环境下,可能存在所有线程都等正竞争一把锁的情况,这也就造成了效率低下的问题。-
- Q1:怎么解决/优化上述问题?
- A1:采用对 table 的不同数据集分别加锁的方案代替锁住整个 table 的数据。
- Q2:上述方案可行的原理是什么?
- A2:我们知道 hash 值不同,在 rehash 时并不会造成线程安全问题,所以分别锁住别个数据段是可行的。
二、源码阅读
(1) 底层数据结构
在 JDK1.7 版本中,ConcurrentHashMap 的数据结构是由一个内含多个 HashEntry 组 的 Segment 数组构成。先总览一下 ConcurrentHashMap 几个重要的成员变量:
//默认的数组大小16(HashMap里的那个数组)
static final int DEFAULT_INITIAL_CAPACITY = 16;
//扩容因子0.75
static final float DEFAULT_LOAD_FACTOR = 0.75f;
//ConcurrentHashMap中的数组
final Segment<K,V>[] segments
//默认并发标准16
static final int DEFAULT_CONCURRENCY_LEVEL = 16;
//Segment是ReentrantLock子类,因此拥有锁的操作
static final class Segment<K,V> extends ReentrantLock implements Serializable {
//HashMap的那一套,分别是数组、键值对数量、阈值、负载因子
transient volatile HashEntry<K,V>[] table;
transient int count;
transient int threshold;
final float loadFactor;
Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
this.loadFactor = lf;
this.threshold = threshold;
this.table = tab;
}
}
static final class HashEntry<K,V> {
final int hash;
final K key;
volatile V value;
volatile HashEntry<K,V> next;
}
//segment中HashEntry[]数组最小长度
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
//用于定位在segments数组中的位置,下面介绍
final int segmentMask;
final int segmentShift;
看下 Segment 数组,它的意义是:将一个 Segment 分割成多个小的 table 来进行加锁,也就是上面的提到的分段锁,而每一个 Segment 元素存储的是 HashEntry 数组,每一个 HashEntry (数组+链表)
有没有觉得 Segment 很像 HashMap 的组成...
static final class Segment<K,V> extends ReentrantLock implements Serializable {
private static final long serialVersionUID = 2249069246763182397L;
// 和 HashMap 中的 HashEntry 作用一样,真正存放数据的桶
transient volatile HashEntry<K,V>[] table;
transient int count;
transient int modCount;
transient int threshold;
final float loadFactor;
}
再看下 HashEntry 的内部构造
static final class HashEntry<K,V> {
final int hash;
final K key;
volatile V value;
volatile HashEntry<K,V> next;
//其他省略...
}
(2) 构造方法
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();
// 最大并发数为 1<<16=65536
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;
// 2 的sshif次方等于ssize,例:ssize=16,sshift=4;ssize=32,sshif=5
int sshift = 0;
//ssize 为segments数组长度,根据concurrentLevel计算得出
int ssize = 1;
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;
}
//默认值,concurrencyLevel 为 16,sshift 为 4,那么计算出 segmentShift 为 28,segmentMask 为 15
//segmentShift和segmentMask这两个变量在定位segment位置时会用到,后面会详细讲
this.segmentShift = 32 - sshift;
this.segmentMask = ssize - 1;
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
/* initialCapacity设置整个 segm 的初始容量,根据 initialCapacity 计算 Segment 数组的每个位置可以分配的大小
如 initialCapacity 为 64,那么每个 Segment 可以分到大小为 4 的 HashEntry 数组*/
int c = initialCapacity / ssize;
if (c * ssize < initialCapacity)
++c;
// 默认 MIN_SEGMENT_TABLE_CAPACITY 是 2,为什么是 2 呢?因为对于具体的槽上,插入一个元素不会立刻扩容
int cap = MIN_SEGMENT_TABLE_CAPACITY;
while (cap < c)
cap <<= 1;
//创建segments数组并初始化第一个Segment,其余的Segment延迟初始化
Segment<K,V> s0 =
new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
(HashEntry<K,V>[])new HashEntry[cap]);
Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
UNSAFE.putOrderedObject(ss, SBASE, s0);
this.segments = ss;
}
总结一下,无外乎几步:
- 确定最大并发数 concurrencyLevel(最大为 2^{16})
- 根据 concurrencyLevel 确定 sshift 和 ssize
- ssize 初始值为 1,通过左移得到一个 >= concurrencyLevel 的最小 2 的幂次方数。
- sshift 表示 sszie 左移的次数。
- 根据 sshift 算出 segmentShift = 32 - sshift 根据 ssize 算出 segmentMask = ssize - 1,因为 ssize 的值为 2^n,所以减 1 就变成为二进制位中第 n 位以下全是 1 的二进制数。
- 确定每个槽的初始容量 initialCapacity,但最大不能超过 MAXIMUM_CAPACITY = 2
- 根据 ssize 和 initialCapacity 算出每个槽中 HashEntry 数组的长度 cap(这也一定要是一个 2 的幂次方数,具体看代码)
- 根据 loadFactor (0.75)、cap 创建第一个 Segment 对象 s0,这个是有讲究的,后面在说。
(3) put 方法
public V put(K key, V value) {
Segment<K,V> s;
if (value == null) //1. ConcurrentHashMap不允许valus为空
throw new NullPointerException();
int hash = hash(key); //2. 根据key计算hash值,key也不能为null,否则hash(key)报空指针
//3. 根据hash值计算在segments数组中的位置
// hash 是 32 位,无符号右移 segmentShift(28) 位,用剩下的高 4 位,
// 和 segmentMask(15) 做一次与操作,也就是说数组下标 j 是 hash 值的高 4 位
int j = (hash >>> segmentShift) & segmentMask;
//4. 取第Segment数组的 j 个位置的元素
if ((s = (Segment<K,V>)UNSAFE.getObject(segments, (j << SSHIFT) + SBASE)) == null) // nonvolatile + in ensureSegment
s = ensureSegment(j);
return s.put(key, hash, value, false);
}
总结一下:
- 调用 hash 方法,算出非空 key 的 32 位整数变量 hash
- 用 hash 右移 segmentShift 位的结果对 segmentMask 进行按位与 & 操作得到 Segment 数组的下标 j:
- 如果该位置的 Segment 还没有初始化,就会通过 CAS 操作对位置的 Segment 代用 ensureSegment 方法进行赋值。
- 否则,直接调用 Segment 的 put 方法进行赋值。
CAS(compare and swap):即比较并交换,CAS 机制当中使用了 3 个基本操作数:
- 内存地址 V,预期的旧值 A,要修改的新值 B。
- 更新一个变量的时候,只有当变量的预期的旧值 A 和内存地址 V 当中的实际值相同时,才会将内存地址 V 对应的值修改为 B。
接着看一下 ensureSegment 方法的逻辑
3.1 ensureSegment 方法
private Segment<K,V> ensureSegment(int k) {
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // raw offset
Segment<K,V> seg;
// 在 seg 的第 j 个位置为空的情况下(见put方法)进到该方法内部
// 因为可能存在并发情况,故要检查是否被其他线程先初始化了 seg 的 u 位置,是就先返回
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
// 这里相当于利用seg[0]来初始化了一个"HashMap"
// 可以直接使用"当前"的 seg[0] 处的数组长度和负载因子来初始化 segment[k],省去一些琐碎的计算 cap、lf、thre...
// 为什么说“当前”呢,是因为seg[0]可能被其他线程修改过(扩容等等)
Segment<K,V> proto = ss[0];
int cap = proto.table.length;
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
// 初始化 segment[u] 内部的数组
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
// 再次检查该槽是否被其他线程初始化,尽量减少下面的初始化操作
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))== null) {
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
// 使用保险的自旋:用 CAS 一直检查,直到当前线程成功设值或其他线程成功设值后才退出
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
// 使用 cas 操作来更新值(内存地址ss[u]、预期值null、新值 seg),存在两种情况:
// 1.因为被其他线程抢先操作了(不等于预期值 null),所以更新失败,然后继续循环直到满足预期值为止
// 2.更新成功,break
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}
看完我有点感叹:不得不说高并发设计是多么严格啊,只要我还没有到初始化成功,我就要用 CAS 机制检查我要赋值的位置是否被其他线程先初始化了!
总结一下:
- 根据索引 k 计算出 segment 数组的第 u 个位置
- 检查 segment 数组的第 u 个位置是否已经有值
- 没有,则继续利用 ss[0] 的属性来初始化 seg
- 有则,返回该位置的已存在的值
- 因为中间经历了一些耗时的初始化动作,所以又检查了一遍 ss[u] 是否有值,如果没有,继续按部就班。
- 最后利用 while + if 的 cas 自旋操作一直检查到,cas 成功为止。
执行完 ConcurrentHashMap 的 put 方法,接下来就是执行返回的 Segment 对象的的 put 方法了
3.2 Segment.put()
这个方法作用就是在获取的 segment 对象内部的 HashEntry 数组/链表中放入/插入参数 key、value 等元素
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
// 在往该 segment 对象放入之前,先非阻塞地尝试获取该 segment 对象的锁
HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value);
V oldValue;
try {
HashEntry<K,V>[] tab = table; //获取segment对象内部的数组
int index = (tab.length - 1) & hash; //再利用参数hash求出放置k、v的数组下标
HashEntry<K,V> first = entryAt(tab, index); // 获取数组index位置的链表表头
for (HashEntry<K,V> e = first;;) { //遍历链表
if (e != null) {
K k; //if操作检查是否需要覆盖旧值
if ((k = e.key) == key || (e.hash == hash && key.equals(k))) {
oldValue = e.value;
if (onlyIfAbsent == false) {//是否替换取决于调用者的意愿
e.value = value;
++modCount;
}
break;
}
e = e.next;
}else {// 判断node到底是否为null,这个要看获取锁的过程,不过和这里都没有关系。
// 如果不为 null,使用头插法插在链表表头;如果是null,初始化并设置为链表表头。
if (node != null) node.setNext(first);
else node = new HashEntry<K,V>(hash, key, value, first);
int c = count + 1;
// 超过了该 segment 的阈值,这个 segment 需要扩容
if (c > threshold && tab.length < MAXIMUM_CAPACITY)
rehash(node);
else // 没有达到阈值,将 node 放到数组 tab 的 index 位置
// 注:里面是调用UNSAFE的put方法保证将对象存到内存中,而不是仅仅插在线程的工作空间中
setEntryAt(tab, index, node); // 链表下移
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
unlock(); //解锁
}
return oldValue;
}
对 scanAndLockForPut 方法有一些疑惑,不妨看看这个场景:
现在 Thread1 调用当前 seg[5] 对象的 put 方法存值,假设它可成功拿到锁,根据计算,得出它要存的键值对应该放在 HashEntry[] 的 0 号位置,0 号位置为空,于是新建一个 HashEntry,并通过 setEntryAt() 方法,放在 0 号位置,然而还没等 Thread1 释放锁,系统的时间片切到了 Thread2 ,先画图存档:
此时正好 Thread2 也来存值,通过下标计算,Thread2 被定位到 seg[5] 中 HashEntry[] 的 0 号位置,接下来 Thread2 也调用当前 seg 对象的 put 方法,一开始先尝试获取锁,没有成功 (Thread1 还未释放,没有插入完毕),就会去执行 scanAndLockForPut() 方法:
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
HashEntry<K,V> first = entryForHash(this, hash); //根据hash值获取当前HashEntry数组的对应位置的结点
HashEntry<K,V> e = first;
HashEntry<K,V> node = null;
int retries = -1; // negative while locating node,控制分支
while (!tryLock()) { // 自旋并尝试获取锁
HashEntry<K,V> f; // to recheck first below
if (retries < 0) {
if (e == null) {
if (node == null) // 进到这里说明数组该位置的链表是空的,没有任何元素
node = new HashEntry<K,V>(hash, key, value, null);
retries = 0;
}
else if (key.equals(e.key)) retries = 0;//如果发现key重复证明该位置的值不空
else e = e.next; //否则继续遍历
}
// 重试次数如果超过 MAX_SCAN_RETRIES(单核重试1次多核64次)
else if (++retries > MAX_SCAN_RETRIES) {
lock(); // lock() 是阻塞方法,直到获取锁后返回
break;
}else if ((retries & 1) == 0 && (f = entryForHash(this, hash)) != first) {
//每间隔2次重试就检查链表是否发生被其它线程膝盖,如果是,则重新自旋获取锁
e = first = f; // re-traverse if entry changed
retries = -1;
}
}
return node;
}
虽然 Thread2 没有获取到锁,但它并不是闲着,而是在进入 scanAndLockForPut 方法等待锁的过程中,预先计算好自己要存放的键值对的在 seg[5] 中的相应位置,以便拿到锁时立刻执行赋值操作,达到节省时间的作用。
Q1:那为什么在自旋的时候还要去检查链表是否被改变了呢?- A1:这是因为当 Thread2 确定了插入的位置在 0 号位置,但 Thread1 已经完成插入了,那么此时根据 new HashEntry<K,V>(hash, key, value, first) 计算出来的值可能会造成 hash 冲突,所以要重新咯...
看到这里,其实整体感觉也不难,总结一下:
- 非阻塞地尝试获取 seg 对象的可重入锁
- 这里使用了非阻塞的 tryLock 去获取锁
题外话:另外,如果是阻塞式地获取锁就应该调用可重入锁的 lk.lock() 方法
- 使用头插法插入到合适的位置,位置可能是数组也可能是链表
- 插入完毕还有一些是否需要扩容的检查,下面会讲到。
3.3 Segment.rehash()
private void rehash(HashEntry<K,V> node) {
HashEntry<K,V>[] oldTable = table; //获取原table
int oldCapacity = oldTable.length;
int newCapacity = oldCapacity << 1;//长度取原长的2倍
threshold = (int)(newCapacity * loadFactor);//得到新的阈值
HashEntry<K,V>[] newTable = //创建新table
(HashEntry<K,V>[]) new HashEntry[newCapacity];
int sizeMask = newCapacity - 1; //新的掩码,如从16扩容到32,那么 sizeMask 为 31,对应二进制 ‘000...00011111’
// 遍历原数组,将原数组位置 i 处的链表拆分到新数组位置 i 和 i+oldCap 两个位置
for (int i = 0; i < oldCapacity ; i++) {
//e是链表表头
HashEntry<K,V> e = oldTable[i];
if (e != null) { //不空才进行数据迁移
HashEntry<K,V> next = e.next;
//重定位,假设原数组长度为16,e 在oldTable[3]处,那么idx只可能是3或者是3 + 16 = 19
int idx = e.hash & sizeMask;
if (next == null) //数组的链表只有一个元素
newTable[idx] = e;
else { //循环迁移
HashEntry<K,V> lastRun = e;//e 是链表表头
int lastIdx = idx; //idx是当前链表的头结点 e 的新位置
//该for循环会找到一个lastRun节点,区间[lastRun, end]中的结点的下标都是一样的,所以在新数组的位置是一样的
for (HashEntry<K,V> last = next; last != null; last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) { //只有下标不一样才会更新
lastIdx = k;
lastRun = last;
}
}
newTable[lastIdx] = lastRun; //直接赋值就不用将lastRun后面的所有结点一个一个地插入
//下面的for是迁移lastRun前面的节点
//这些节点可能分配在另一个链表中,也可能分配到上面的那个链表中
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
HashEntry<K,V> n = newTable[k];
newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}
}
}
}
// 将要插入的node结点放到新数组中某个位置链表头部
int nodeIndex = node.hash & sizeMask;
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
table = newTable;
}
rehash 是在加锁的 put 方法中调用的,所以不会产生线程不安全问题。逻辑也比较清楚,总结一下:
- 计算新长度、新阈值、新掩码。
- 遍历老数组,将位置 i 的元素迁移到新数组的两个位置
i / i + oldCap
之一中去,中间涉及到了一些细节,可看上述代码注释。 - 最后就将,要插入的新 node 插到新数组中某个位置链表头部。
好,至此 ConcurrentHashMap 的 put 方法也就讲解完毕,下面到它的 get 方法...
(4) get 方法
get 方法相对比较简单
public V get(Object key) {
Segment<K,V> s; // manually integrate access methods to reduce overhead
HashEntry<K,V>[] tab;
int h = hash(key); // 1.获取hash值
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
//2.根据hash值找到对应的 segment
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null && (tab = s.table) != null) {
//3.从内存中找到segment对象相应位置的数组中的链表(防止并发问题)
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE); e != null; e = e.next) {
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}
(5) size 方法
public int size() {
final Segment < K, V > [] segments = this.segments;
int size;
boolean overflow; // true if size overflows 32 bits
long sum; // 总修改次数
long last = 0 L; // 上一次的总修改次数
int retries = -1; // first iteration isn't retry
try {
for(;;) { //如果遍历次数达到2次以上,证明在第二次遍历时存在并发修改问题,故在第三次遍历时,对每个seg对象加上锁
if(retries++ == RETRIES_BEFORE_LOCK) {//RETRIES_BEFORE_LOCK=2
for(int j = 0; j < segments.length; ++j)
ensureSegment(j).lock();
}
sum = 0 L;
size = 0;
overflow = false;
for(int j = 0; j < segments.length; ++j) {
Segment < K, V > seg = segmentAt(segments, j);
if(seg != null) {
sum += seg.modCount; //得到当前seg对象的修改次数(put、remove)
int c = seg.count; //得到单个seg的大小
if(c < 0 || (size += c) < 0) //标记是否如果产生溢出
overflow = true;
}
}
if(sum == last) //直到和上一次修改总次数得到的总和相等,ConcurrentHashMap 没有被修改过
break;
last = sum;
}
} finally {
if(retries > RETRIES_BEFORE_LOCK) {
for(int j = 0; j < segments.length; ++j)
segmentAt(segments, j).unlock(); //解锁
}
}
return overflow ? Integer.MAX_VALUE : size; //如果发生溢出,返回最大整形
}
- Q1:为什么方法中的 modCount 只增不减,这样设计的目的是什么?
- A1:还是从并发的角度来分析,这样设计的目的是避免一个线程 put 元素和另一个线程 remove 元素后抵消了前面线程的 put 动作的 modCount,进而避免在统计 size 的时候产生死循环问题。
总结一下:
- 前后两次计算出 ConcurrentHashMap 内部的每一个 Segment 对象的 modCount 总和到 sum 和 last 中。
- 如果两次遍历得到的结果不同,即 sum != last 则证明存在线程并发修改,到第三次遍历就会对每一个 seg 对象都加上锁,然后再次遍历,直到 sum = last 退出循环
- 其中需要记录 map 中的元素总数是否发生溢出。
恢复: https://www.javadoop.com/post/hashmap#toc_3 https://www.jianshu.com/p/9c713de7bbdb https://juejin.im/post/5a2f2f7851882554b837823a https://www.cnblogs.com/study-everyday/p/6430462.html#autoid-2-0-0 https://www.cnblogs.com/chengxiao/p/6842045.html