Java7 ForkJoinPool 的使用以及原理

在JDK7中新增了ForkJoinPool。ForkJoinPool採用分治+work-stealing的思想。可以讓我們很方便地將一個大任務拆散成小任務並行地執行提高CPU的使用率

ForkJoinPool & ForkJoinTask 概述:

  • ForkJoinTask:我們要使用 ForkJoin 框架,必須首先創建一個 ForkJoin 任務。它提供在任務中執行 fork() 和 join() 操作的機制,通常情況下我們不需要直接繼承 ForkJoinTask 類,而只需要繼承它的子類,ForkJoin 框架提供了以下兩個子類:
    • RecursiveAction:用於沒有返回結果的任務。
    • RecursiveTask :用於有返回結果的任務。
  • ForkJoinPool :ForkJoinTask 需要通過 ForkJoinPool 來執行,任務分割出的子任務會添加到當前工作線程所維護的雙端隊列中,進入隊列的頭部。當一個工作線程的隊列裏暫時沒有任務時,它會隨機從其他工作線程的隊列的尾部獲取一個任務。

如何充分利用多核CPU,計算很大數組中所有整數的和?

剖析

  • 單線程相加?

    我們最容易想到就是單線程相加,一個for循環搞定。

  • 線程池相加?

    如果進一步優化,我們會自然而然地想到使用線程池來分段相加,最後再把每個段的結果相加。

  • 其它?

    Yes,就是我們今天的主角——ForkJoinPool,但是它要怎麼實現呢?似乎沒怎麼用過哈^^

三種實現

OK,剖析完了,我們直接來看三種實現,不墨跡,直接上菜。

/**

 * 計算1億個整數的和

 */

public class ForkJoinPoolTest01
{
    public static void main(String[] args) throws ExecutionException, InterruptedException {

// 構造數據

        int length = 100000000;
        long[] arr = new long[length];
        for (int i = 0; i < length; i++) {
            arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);
        }
// 單線程
        singleThreadSum(arr);
// ThreadPoolExecutor線程池
        multiThreadSum(arr);
// ForkJoinPool線程池
        forkJoinSum(arr);
    }

    private static void singleThreadSum(long[] arr) {
        long start = System.currentTimeMillis();
        long sum = 0;
        for (int i = 0; i < arr.length; i++) {
// 模擬耗時
            sum += (arr[i]/ 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3);
        }
        System.out.println("sum: " + sum);
        System.out.println("single thread elapse: " + (System.currentTimeMillis() - start));
    }

    private static void multiThreadSum(long[] arr) throws ExecutionException, InterruptedException {
        long start = System.currentTimeMillis();
        int count = 8;
        ExecutorService threadPool = Executors.newFixedThreadPool(count);
        List<Future<Long>> list = new ArrayList<>();
        for(int i = 0; i < count; i++) {
            int num = i;
// 分段提交任務
            Future<Long> future = threadPool.submit(() -> {
                long sum = 0;
                for (int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) {
                    try {
// 模擬耗時
                        sum += (arr[j]/ 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3);
                    }catch (Exception e) {
                        e.printStackTrace();
                    }
                }
                return sum;
            });
            list.add(future);
        }
// 每個段結果相加
        long sum = 0;
        for(Future<Long> future : list) {
            sum += future.get();
        }
        System.out.println("sum: " + sum);
        System.out.println("multi thread elapse: " + (System.currentTimeMillis() - start));
    }



    private static void forkJoinSum(long[] arr) throws ExecutionException, InterruptedException {
        long start = System.currentTimeMillis();
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
// 提交任務
        ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(
                new SumTask(arr, 0, arr.length));
// 獲取結果
        Long sum = forkJoinTask.get();
        forkJoinPool.shutdown();
        System.out.println("sum: " + sum);
        System.out.println("fork join elapse: " + (System.currentTimeMillis() - start));
    }

    private static class SumTask extends
            RecursiveTask<Long> {
        private long[] arr;
        private int from;
        private int to;
        
        public SumTask(long[] arr, int from, int to) {
            this.arr = arr;
            this.from = from;
            this.to = to;
        }
        
        @Override
        protected Long compute() {
// 小於1000的時候直接相加,可靈活調整
            if (to - from <= 1000) {
                long sum = 0;
                for (int i = from; i < to; i++) {
// 模擬耗時
                    sum += (arr[i]/ 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3 / 3 * 3);

                }
                return sum;
            }
            
// 分成兩段任務

            int middle = (from + to) / 2;
            SumTask left = new SumTask(arr, from, middle);
            SumTask right = new SumTask(arr, middle, to);
// 提交左邊的任務
            left.fork();
// 右邊的任務直接利用當前線程計算,節約開銷
            Long rightResult = right.compute();
// 等待左邊計算完畢
            Long leftResult = left.join();
// 返回結果
            return
                    leftResult + rightResult;
        }
    }
}

如果不加“都 /3*3/3*3/3*3/3*3/3*3了一頓操作” ,實際上計算1億個整數相加,單線程是最快的,我的電腦大概是100ms左右,使用線程池反而會變慢。

所以,爲了演示ForkJoinPool的牛逼之處,把每個數都 /3*3/3*3/3*3/3*3/3*3了一頓操作,用來模擬計算耗時。

來看結果:

sum: 107352457433800662

single thread elapse: 789

sum: 107352457433800662

multi thread elapse: 228

sum: 107352457433800662

fork join elapse: 189

 

可以看到,ForkJoinPool相對普通線程池還是有很大提升的。

分治法

  • 基本思想

    把一個規模大的問題劃分爲規模較小的子問題,然後分而治之,最後合併子問題的解得到原問題的解。

  • 步驟

    (1)分割原問題:

    (2)求解子問題:

    (3)合併子問題的解爲原問題的解。

在分治法中,子問題一般是相互獨立的,因此,經常通過遞歸調用算法來求解子問題。

  • 典型應用場景

    (1)二分搜索

    (2)大整數乘法

    (3)Strassen矩陣乘法

    (4)棋盤覆蓋

    (5)歸併排序

    (6)快速排序

    (7)線性時間選擇

    (8)漢諾塔

ForkJoinPool繼承體系

ForkJoinPool是 java 7 中新增的線程池類,它的繼承體系如下:

ForkJoinPool和ThreadPoolExecutor都是繼承自AbstractExecutorService抽象類,所以它和ThreadPoolExecutor的使用幾乎沒有多少區別,除了任務變成了ForkJoinTask以外。

這裏又運用到了一種很重要的設計原則——開閉原則——對修改關閉,對擴展開放。

可見整個線程池體系一開始的接口設計就很好,新增一個線程池類,不會對原有的代碼造成干擾,還能利用原有的特性。

ForkJoinTask

兩個主要方法

  • fork()

    fork()方法類似於線程的Thread.start()方法,但是它不是真的啓動一個線程,而是將任務放入到工作隊列中。

  • join()

    join()方法類似於線程的Thread.join()方法,但是它不是簡單地阻塞線程,而是利用工作線程運行其它任務。當一個工作線程中調用了join()方法,它將處理其它任務,直到注意到目標子任務已經完成了。

三個子類

  • RecursiveAction

    無返回值任務。

  • RecursiveTask

    有返回值任務。

  • CountedCompleter

    無返回值任務,完成任務後可以觸發回調。

ForkJoinPool內部原理

ForkJoinPool內部使用的是“工作竊取”算法實現的。

(1)每個工作線程都有自己的工作隊列WorkQueue;

(2)這是一個雙端隊列,它是線程私有的;

(3)ForkJoinTask中fork的子任務,將放入運行該任務的工作線程的隊頭,工作線程將以LIFO的順序來處理工作隊列中的任務;

(4)爲了最大化地利用CPU,空閒的線程將從其它線程的隊列中“竊取”任務來執行;

(5)從工作隊列的尾部竊取任務,以減少競爭;

(6)雙端隊列的操作:push()/pop()僅在其所有者工作線程中調用,poll()是由其它線程竊取任務時調用的;

(7)當只剩下最後一個任務時,還是會存在競爭,是通過CAS來實現的;

ForkJoinTask fork 方法

fork() 做的工作只有一件事,既是把任務推入當前工作線程的工作隊列裏。可以參看以下的源代碼:

    public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }

 

ForkJoinTask join 方法

join() 的工作則複雜得多,也是 join() 可以使得線程免於被阻塞的原因——不像同名的 Thread.join()

  1. 檢查調用 join() 的線程是否是 ForkJoinThread 線程。如果不是(例如 main 線程),則阻塞當前線程,等待任務完成。如果是,則不阻塞。
  2. 查看任務的完成狀態,如果已經完成,直接返回結果。
  3. 如果任務尚未完成,但處於自己的工作隊列內,則完成它。
  4. 如果任務已經被其他的工作線程偷走,則竊取這個小偷的工作隊列內的任務(以 FIFO 方式),執行,以期幫助它早日完成欲 join 的任務。
  5. 如果偷走任務的小偷也已經把自己的任務全部做完,正在等待需要 join 的任務時,則找到小偷的小偷,幫助它完成它的任務。
  6. 遞歸地執行第5步。

將上述流程畫成序列圖的話就是這個樣子:

 

ForkJoinPool.submit 方法

    public static void main(String[] args) throws InterruptedException {
        // 創建包含Runtime.getRuntime().availableProcessors()返回值作爲個數的並行線程的ForkJoinPool
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        // 提交可分解的PrintTask任務
        forkJoinPool.submit(new MyRecursiveAction(0, 1000));

        while (!forkJoinPool.isTerminated()) {
            forkJoinPool.awaitTermination(2, TimeUnit.SECONDS);
        }
        // 關閉線程池
        forkJoinPool.shutdown();
    }

其實除了前面介紹過的每個工作線程自己擁有的工作隊列以外,ForkJoinPool 自身也擁有工作隊列,這些工作隊列的作用是用來接收由外部線程(非 ForkJoinThread 線程)提交過來的任務,而這些工作隊列被稱爲 submitting queue 。

submit() 和 fork() 其實沒有本質區別,只是提交對象變成了 submitting queue 而已(還有一些同步,初始化的操作)。submitting queue 和其他 work queue 一樣,是工作線程”竊取“的對象,因此當其中的任務被一個工作線程成功竊取時,就意味着提交的任務真正開始進入執行階段。

 

ForkJoinPool最佳實踐

(1)最適合的是計算密集型任務;

(2)在需要阻塞工作線程時,可以使用ManagedBlocker;

(3)不應該在RecursiveTask的內部使用ForkJoinPool.invoke()/invokeAll();

總結

(1)ForkJoinPool特別適合於“分而治之”算法的實現;

(2)ForkJoinPool和ThreadPoolExecutor是互補的,不是誰替代誰的關係,二者適用的場景不同;

(3)ForkJoinTask有兩個核心方法——fork()和join(),有三個重要子類——RecursiveAction、RecursiveTask和CountedCompleter;

(4)ForkjoinPool內部基於“工作竊取”算法實現;

(5)每個線程有自己的工作隊列,它是一個雙端隊列,自己從隊列頭存取任務,其它線程從尾部竊取任務;

(6)ForkJoinPool最適合於計算密集型任務,但也可以使用ManagedBlocker以便用於阻塞型任務;

(7)RecursiveTask內部可以少調用一次fork(),利用當前線程處理,這是一種技巧;

 

使用場景補充:

 遍歷系統所有文件,得到系統中文件的總數。

思路

通過遞歸的方法。任務在遍歷中如果發現文件夾就創建新的任務讓線程池執行,將返回的文件數加起來,如果發現文件則將計數加一,最終將該文件夾下的文件數返回。

代碼實現

 

    CountingTask countingTask = new CountingTask(Environment.getExternalStorageDirectory());
    forkJoinPool.invoke(countingTask);

    class CountingTask extends RecursiveTask<Integer> {
        private File dir;

        public CountingTask(File dir) {
            this.dir = dir;
        }

        @Override
        protected Integer compute() {
            int count = 0;

            File files[] = dir.listFiles();
            if(files != null){
                for (File f : files){
                    if(f.isDirectory()){
                        // 對每個子目錄都新建一個子任務。
                        CountingTask countingTask = new CountingTask(f);
                        countingTask.fork();
                        count += countingTask.join();

                    }else {
                        Log.d("tag" , "current path = "+f.getAbsolutePath());
                        count++;
                    }
                }
            }


            return count;
        }
    }         

上面的需求,如果我們用普通的線程池該如何完成?

如果我們使用newFixedThreadPool,當核心線程的路徑下都有子文件夾時,它們會將路徑下的子文件夾拋給任務隊列,最終變成所有的核心線程都在等待子文件夾的返回結果,從而造成死鎖。最終任務無法完成。

如果我們使用newCachedThreadPool,依然用上面的思路可以完成任務。但是每次子文件夾就會創建一個新的工作線程,這樣消耗過大。

因此,在這樣的情況下,ForkJoinPool的work-stealing的方式就體現出了優勢。每個任務分配的子任務也由自己執行,只有自己的任務執行完成時,纔會去執行別的工作線程的任務。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章