Java SynchronizedSet 線程不安全之坑

一、前言

一般而言,想要構造出線程安全的 Set,我們會使用 Collections.synchronizedSet 方法,如下所示。

Set<User> set = Collections.synchronizedSet(new HashSet<>());

但這並不意味着,你可以安全的使用該集合的任何方法,如果沒有仔細的瞭解過其實現的話,一不小心就會踩進坑中。

最近我在使用該集合的 stream 方法時發現了線程不安全問題,都是血的教訓啊,下面寫個Case 來複現下吧。

二、問題引出

2.1 輔助類

本 Case 牽扯到的所有輔助類如下:

public class ThreadPoolUtils {
    private static final long KEEP_ALIVE_TIME = 60L;

    private static Logger log = LogManager.getLogger(ThreadPoolUtils.class);

    public static ThreadPoolExecutor poolExecutor(int core, int max, Object... name) {
        ThreadFactory factory = Objects.nonNull(name) ?
                new ThreadFactoryBuilder().setNameFormat(Joiner.on(" ").join(name)).build() :
                new ThreadFactoryBuilder().build();

        return new ThreadPoolExecutor(core, max, KEEP_ALIVE_TIME, TimeUnit.SECONDS, new LinkedBlockingQueue<>(), factory,
                new ThreadPoolExecutor.AbortPolicy());
    }

    public static void sleep(long timeout, TimeUnit unit) {
        try {
            unit.sleep(timeout);
        } catch (InterruptedException e) {
            log.info("ThreadPoolUtils#sleep error, timeout: {}", timeout, e);
        }
    }
}
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class User {
    private long userId;

    private long timestamp;

    @Override
    public boolean equals(Object object) {
        if(object instanceof User) {
            User user = (User)object;
            return user.getUserId() == ((User) object).getUserId();
        }
        return false;
    }

    @Override
    public int hashCode() {
        int result = 17;
        result = 31 * result + (int) (userId ^ (userId >>> 32));
        result = 31 * result + (int) (timestamp ^ (timestamp >>> 32));
        return result;
    }
}

2.2 測試類

Case 想要達到如下效果:

  1. 線程 A 不停地往 Set 中添加元素。
  2. 線程 B 不停地對 Set 做 Stream 操作。
  3. 線程 B 在 Stream 的執行過程中間,線程 A 必須要進行添加操作。

爲了達到這個效果,線程 B 在 Stream 過程中,增加了 Filter 並在其中 Sleep 10ms,確保在這段 Sleep 過程中,線程 A 會進行添加操作。

public class SynchronizedSetTest {
    private Set<User> set = Collections.synchronizedSet(new HashSet<>());
    private static Logger log = LogManager.getLogger(SynchronizedSetTest.class);

    public static void main(String[] args) {
        new SynchronizedSetTest().testStream();
    }

    public void testStream() {
        ThreadPoolExecutor executor = ThreadPoolUtils.poolExecutor(2, 2, "synchronizedSet-test-pool");
        executor.execute(this::add);
        executor.execute(this::stream);
    }

    public void add() {
        while (true) {
            int size = RandomUtils.nextInt(1, 10);
            IntStream.range(0, size).forEach(e -> {
                set.add(random());
                log.info("SynchronizedSetTest#add size: {}", set.size());
                ThreadPoolUtils.sleep(10, TimeUnit.MILLISECONDS);
            });
        }
    }

    public void stream() {
        while (true) {
            List<User> userList = set.stream()
                    .filter(e -> {
                        ThreadPoolUtils.sleep(10, TimeUnit.MILLISECONDS);
                        log.info("SynchronizedSetTest#stream filter check...");
                        return System.currentTimeMillis() - e.getTimestamp() > 30L;
                    }).collect(Collectors.toList());
            log.info("SynchronizedSetTest#stream heartBeat...");
        }
    }

    private User random() {
        return User.builder().userId(RandomUtils.nextLong(1, 100000)).timestamp(System.currentTimeMillis()).build();
    }
}

運行程序,剛運行就拋錯了:

00:33:28.179 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#add size: 1
00:33:28.188 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#stream filter check...
00:33:28.189 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#stream heartBeat...
00:33:28.191 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#add size: 2
00:33:28.199 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#stream filter check...
00:33:28.201 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#add size: 3
00:33:28.209 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#stream filter check...
00:33:28.211 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#add size: 4
00:33:28.219 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#stream filter check...
00:34:55.316 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#add size: 5
00:34:55.327 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#add size: 6
Exception in thread "synchronizedSet-test-pool" java.util.ConcurrentModificationException
	at java.util.HashMap$KeySpliterator.forEachRemaining(HashMap.java:1561)
	at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:482)
	at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:472)
	at java.util.stream.ReduceOps$ReduceOp.evaluateSequential(ReduceOps.java:708)
	at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234)
	at java.util.stream.ReferencePipeline.collect(ReferencePipeline.java:499)
	at jit.wxs.SynchronizedSetTest.stream(SynchronizedSetTest.java:54)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
Disconnected from the target VM, address: '127.0.0.1:62588', transport: 'socket'
00:34:55.340 [synchronizedSet-test-pool] INFO  jit.wxs.SynchronizedSetTest - SynchronizedSetTest#add size: 7

ConcurrentModificationException 這個異常如果對集合比較瞭解的話,是很熟悉的的。當我們對 ArrayList 迭代過程中進行添加/刪除操作,就會報這個錯誤,錯誤原因就是 Collection 底層的 modCount 導致的。

下面描述下兩個線程剛剛的執行情況:

  1. 線程 A 開始添加元素,添加第一個,此時集合大小爲 1。
  2. 線程 B 開始 Stream 操作,執行到 Filter 時,被 Sleep 住。
  3. 線程 A 在線程 B Sleep 期間,一直添加元素。
  4. 線程 B Filter 執行完畢,執行最後 Collect() 操作。
  5. 根據上面的異常棧,得知線程 B Collect() 最後會調用 HashMap 的 forEachRemaining 方法。

之所以調用 HashMap,是因爲 HashSet 就是 HashMap 的一種特殊實現。

java.util.HashMap.KeySpliterator#forEachRemaining 進行 Debug,如下圖所示。此時 Set 的 modCount 已經更新到了 4(這是沒問題的,因爲線程 A 一直在添加,添加了 4 次),然而線程 B 的 mc仍然爲開始 Stream 時的 1,因此拋出了異常。

三、源碼查看

Collections.synchronizedSet 創建了 SynchronizedSet對象,構造方法又調用了父類 SynchronizedCollection

看到 SynchronizedCollection 後就一切都明白了。首先把當前對象作爲同步對象,因此加了對象鎖的方法都是線程安全的,沒有加 synchronized 修飾的方法就都是非線程安全的,使用過程中必須手動加同步塊:

  • itearator()
  • spliterator()
  • stream()
  • parallelStream()

這邊有一個有意思的地方,forEach()是線程安全,而 itearator() 不是線程安全,stream().forEach() 也不是線程安全的。不同的遍歷方式線程安全與否也不一樣,不太明白 JDK 是怎麼考慮這樣設計的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章