Fork/Join框架是Java7提供的一個用於並行執行任務的框架,是一個把大任務分割成若干個小任務,最終彙總每個小任務結果後得到大任務結果的框架。
Fork就是把一個大任務切分爲若干子任務並行的執行,Join就是合併這些子任務的執行結果,最後得到這個大任務的結果。比如處理100個任務,可以分割成20個子任務,每個子任務分別處理5個,最終彙總這20個子任務的結果。
工作竊取算法
工作竊取(work-stealing)算法是指某個線程從其他隊列裏竊取任務來執行。那麼,爲什麼需要使用工作竊取算法呢?假如我們需要做一個比較大的任務,
可以把這個任務分割爲若干互不依賴的子任務,爲了減少線程間的競爭,把這些子任務分別放到不同的隊列裏,併爲每個隊列創建一個單獨的線程來執行隊列裏的任務,
線程和隊列一一對應。比如A線程負責處理A隊列裏的任務。但是,有的線程會先把自己隊列裏的任務幹完,而其他線程對應的隊列裏還有任務等待處理。
幹完活的線程與其等着,不如去幫其他線程幹活,於是它就去其他線程的隊列裏竊取一個任務來執行。而在這時它們會訪問同一個隊列,
所以爲了減少竊取任務線程和被竊取任務線程之間的競爭,通常會使用雙端隊列,被竊取任務線程永遠從雙端隊列的頭部拿任務執行,
而竊取任務的線程永遠從雙端隊列的尾部拿任務執行。
工作竊取算法的優點:充分利用線程進行並行計算,減少了線程間的競爭。
工作竊取算法的缺點:在某些情況下還是存在競爭,比如雙端隊列裏只有一個任務時。並且該算法會消耗了更多的系統資源,比如創建多個線程和多個雙端隊列。
ForkJoin框架的設計
分割任務
首先我們需要有一個fork類來把大任務分割成子任務,有可能子任務還是很大,所以還需要不停地分割,直到分割出的子任務足夠小
執行任務併合並結果
分割的子任務分別放在雙端隊列裏,然後幾個啓動線程分別從雙端隊列裏獲取任務執行。子任務執行完的結果都統一放在一個隊列裏,啓動一個線程從隊列裏拿數據,然後合併這些數據
Fork/Join使用兩個類來完成以上兩件事情
1. ForkJoinTask:我們要使用ForkJoin框架,必須首先創建一個ForkJoin任務。它提供在任務中執行fork()和join()操作的機制。通常情況下,我們不需要直接繼承ForkJoinTask類,只需要繼承它的子類,Fork/Join框架提供了以下兩個子類
- RecursiveAction:用於沒有返回結果的任務
- RecursiveTask:用於有返回結果的任務
2. ForkJoinPool:ForkJoinTask需要通過ForkJoinPool來執行。任務分割出的子任務會添加到當前工作線程所維護的雙端隊列中,進入隊列的頭部。當一個工作線程的隊列裏暫時沒有任務時,它會隨機從其他工作線程的隊列的尾部獲取一個任務
ForkJoin Demo演示
假設有個複雜的批量自動化任務要分割爲單個子任務去執行,跑完全部子任務後要彙總每一個任務的結果到一個集合中統一返回
自動化任務類
package com.brian.mutilthread.forkjoin.service; import lombok.extern.slf4j.Slf4j; import java.util.*; import java.util.concurrent.RecursiveTask; @Slf4j public class AutomationTask extends RecursiveTask<List<Map<String, String>>> { static List<Map<String, String>> resultList; private List<String> list; private int start; private int end; static { resultList = new ArrayList<>(8); } public AutomationTask(List<String> list, int start, int end) { this.list = list; this.start = start; this.end = end; if (resultList.size() >= 8) { resultList.clear(); } } @Override protected List<Map<String, String>> compute() { if ((end - start) < 1) { log.info("=== {} === {}-{}", Thread.currentThread().getName(), start, list.get(start)); Map<String, String> result = new HashMap<>(); result.put("region", list.get(start)); try {
// 任務2的定義 ServiceTask serviceTask = uuid -> { int rad = (int) (Math.random() * 100); if (rad > 80) { throw new Exception("getTransitions exception"); } Thread.sleep(rad); return UUID.randomUUID().toString().replace("-", ""); }; // step 1 String parameter = serviceTask.getParameter(list.get(start)); // step 2 String task = serviceTask.createTask(parameter); // step 3 String transitions = serviceTask.getTransitions(task); // step 4 String s = serviceTask.transferStatus(transitions); result.put("status", s); } catch (Exception e) { result.put("error", e.toString()); } resultList.add(result); } else { int middle = (start + end) / 2; AutomationTask leftTask = new AutomationTask(list, start, middle); log.info("=== {} === fork left {}-{}", Thread.currentThread().getName(), start, middle); AutomationTask rightTask = new AutomationTask(list, middle + 1, end); log.info("=== {} === fork right {}-{}", Thread.currentThread().getName(), middle + 1, end); leftTask.fork(); rightTask.compute(); List<Map<String, String>> leftList = leftTask.join(); log.info("=== {} === leftList {}", Thread.currentThread().getName(), leftList.toArray()); } return resultList; } }
複雜的業務類
模擬一個複雜的業務,中途可能會出現異常
package com.brian.mutilthread.forkjoin.service; import java.util.UUID; @FunctionalInterface public interface ServiceTask { // 1 default String getParameter(String region) throws Exception { int rad = (int) (Math.random() * 100); if (rad > 80) { throw new Exception("getParameter exception"); } Thread.sleep(rad); return UUID.randomUUID().toString().replace("-", ""); } // 2 String createTask(String uuid) throws Exception; // 3 default String getTransitions(String jid) throws Exception { int rad = (int) (Math.random() * 100); if (rad > 80) { throw new Exception("getTransitions exception"); } Thread.sleep(rad); return UUID.randomUUID().toString().replace("-", ""); } // 4 default String transferStatus(String tid) throws Exception { int rad = (int) (Math.random() * 100); if (rad > 80) { throw new Exception("transferStatus exception"); } Thread.sleep(rad); return UUID.randomUUID().toString().replace("-", ""); } }
測試類
package com.brian.mutilthread.forkjoin.controller; import com.brian.mutilthread.forkjoin.service.AutomationTask; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.util.StopWatch; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; import reactor.core.publisher.Flux; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.Future; import java.util.stream.Collectors; @RestController @Slf4j public class ForkJoinTestController { @Autowired public ForkJoinPool forkJoinPool; @GetMapping("/getForkJoinResult") public Flux<Map<String, String>> getForkJoinResult(@RequestParam(value = "country", defaultValue = "CN,HK,JP,KR,SG,TH,TW,ER") String country) throws ExecutionException, InterruptedException { String[] countries = country.split(","); if (countries.length < 8) { return Flux.empty(); } StopWatch stopWatch = new StopWatch(); stopWatch.start(); AutomationTask computeTask = new AutomationTask(Arrays.stream(countries).collect(Collectors.toList()), 0, countries.length - 1); Future<List<Map<String, String>>> results = forkJoinPool.submit(computeTask); if(computeTask.isCompletedAbnormally()){ log.info("<><><><><><><><><><> automationTask exception: {}", computeTask.getException()); } List<Map<String, String>> res = results.get(); log.info(">>>>>>>>>>>>>>>>>>>>result size : {}", res.size()); Flux<Map<String, String>> mapFlux = Flux.fromIterable(res); stopWatch.stop(); log.info(">>>>>>>>>>>total handle time: {} ms", stopWatch.getTotalTimeMillis()); return mapFlux; } }
ForkJoinTask在執行的時候可能會拋出異常,但是我們沒辦法在主線程裏直接捕獲異常,所以ForkJoinTask提供了isCompletedAbnormally()方法來檢查任務是否已經拋出異常或已經被取消,並且可以通過ForkJoinTask的getException方法獲取異常。如上面的測試類中有個如下的代碼片段
if(computeTask.isCompletedAbnormally()){ log.info("<><><><><><><><><><> automationTask exception: {}", computeTask.getException()); }
getException方法返回Throwable對象,如果任務被取消了則返回CancellationException。如果任務沒有完成或者沒有拋出異常則返回null
ForkJoin框架的原理
ForkJoinPool類
//繼承AbstractExecutorService 類 public class ForkJoinPool extends AbstractExecutorService{ //任務隊列數組,存儲了所有任務隊列,包括內部隊列和外部隊列 volatile WorkQueue[] workQueues; // main registry //一個靜態常量,ForkJoinPool 提供的內部公用的線程池 static final ForkJoinPool common; //默認的線程工廠類 public static final ForkJoinWorkerThreadFactory defaultForkJoinWorkerThreadFactory; }
ForkJoinWorkerThread類
//繼承Thread 類 public class ForkJoinWorkerThread extends Thread { //線程工作的線程池,即此線程所屬的線程池 final ForkJoinPool pool; // 線程的內部隊列 final ForkJoinPool.WorkQueue workQueue; //..... }
ForkJoinPool中線程的創建
默認的線程工廠類,ForkJoinPool 中的線程是由默認的線程工廠類 defaultForkJoinWorkerThreadFactory
創建的
//默認的工廠類 public static final ForkJoinWorkerThreadFactory defaultForkJoinWorkerThreadFactory; defaultForkJoinWorkerThreadFactory = new DefaultForkJoinWorkerThreadFactory();
defaultForkJoinWorkerThreadFactory
創建線程的方法 newThread()
,其實就是傳入當前的線程池,直接創建
/** * Default ForkJoinWorkerThreadFactory implementation; creates a * new ForkJoinWorkerThread using the system class loader as the * thread context class loader. */ private static final class DefaultForkJoinWorkerThreadFactory implements ForkJoinWorkerThreadFactory { private static final AccessControlContext ACC = contextWithPermissions( new RuntimePermission("getClassLoader"), new RuntimePermission("setContextClassLoader")); public final ForkJoinWorkerThread newThread(ForkJoinPool pool) { return AccessController.doPrivileged( new PrivilegedAction<>() { public ForkJoinWorkerThread run() { return new ForkJoinWorkerThread( pool, ClassLoader.getSystemClassLoader()); }}, ACC); } }
ForkJoinWorkerThread 的構造方法
protected ForkJoinWorkerThread(ForkJoinPool pool) { // Use a placeholder until a useful name can be set in registerWorker super("aForkJoinWorkerThread"); //線程工作的線程池,即創建這個線程的線程池 this.pool = pool; //註冊線程到線程池中,並返回此線程的內部任務隊列 this.workQueue = pool.registerWorker(this); }
創建一個工作線程,最後一步還要註冊到其所屬的線程池中, registerWorker這裏不展開了
ForkJoinTask的fork()方法
public final ForkJoinTask<V> fork() { Thread t; //判斷是否是一個工作線程 if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) //加入到內部隊列中 ((ForkJoinWorkerThread)t).workQueue.push(this); else//由common線程池來執行任務 ForkJoinPool.common.externalPush(this); return this; }
fork()
方法先判斷當前線程(調用fork()
來提交任務的線程)是不是一個 ForkJoinWorkerThread
的工作線程,如果是,則將任務加入到內部隊列中,否則,由 ForkJoinPool
提供的內部公用的線程池common線程池
來執行這個任務。我們可以在普通線程池中直接調用 fork()
方法來提交任務到一個默認提供的線程池中。這將非常方便。假如,你要在程序中處理大任務,需要分治編程,但你僅僅只處理一次,以後就不會用到,而且任務不算太大,不需要設置特定的參數,那麼你肯定不想爲此創建一個線程池,這時默認的提供的線程池將會很有用。
ForkJoinTask的join()方法
public final V join() { int s; if ((s = doJoin() & DONE_MASK) != NORMAL) reportException(s); return getRawResult();//直接返回結果 }
private int doJoin() { int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w; return //如果完成,直接返回s (s = status) < 0 ? s : //沒有完成,判斷是不是池中的 ForkJoinWorkerThread 工作線程 ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ? //如果是池中線程,執行這裏 (w = (wt = (ForkJoinWorkerThread)t).workQueue). tryUnpush(this) && (s = doExec()) < 0 ? s : wt.pool.awaitJoin(w, this, 0L) : //如果不是池中的線程池,則執行這裏 externalAwaitDone(); }
join()方法有執行一個重要的方法doJoin(), 當dojoin()
方法發現任務沒有完成且當前線程是池中線程時,執行了 tryUnpush()
方法。tryUnpush()
方法嘗試去執行此任務:如果要join的任務正好在當前任務隊列的頂端,那麼pop出這個任務,然後調用 doExec() 讓當前線程去執行這個任務
final boolean tryUnpush(ForkJoinTask<?> t) { ForkJoinTask<?>[] a; int s; if ((a = array) != null && (s = top) != base && U.compareAndSwapObject (a, (((a.length - 1) & --s) << ASHIFT) + ABASE, t, null)) { U.putOrderedInt(this, QTOP, s); return true; } return false; }
final int doExec() { int s; boolean completed; if ((s = status) >= 0) { try { completed = exec(); } catch (Throwable rex) { return setExceptionalCompletion(rex); } if (completed) s = setCompletion(NORMAL); } return s; }
如果任務不是處於隊列的頂端,那麼就會執行 awaitJoin()
方法
final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) { int s = 0; if (task != null && w != null) { ForkJoinTask<?> prevJoin = w.currentJoin; U.putOrderedObject(w, QCURRENTJOIN, task); CountedCompleter<?> cc = (task instanceof CountedCompleter) ? (CountedCompleter<?>)task : null; for (;;) { if ((s = task.status) < 0)//如果任務完成了,跳出死循環 break; if (cc != null)//當前任務是CountedCompleter類型,則嘗試從任務隊列中獲取當前任務的派生子任務來執行; helpComplete(w, cc, 0); else if (w.base == w.top || w.tryRemoveAndExec(task))//如果當前線程的內部隊列爲空,或者成功完成了任務,幫助某個線程完成任務。 helpStealer(w, task); if ((s = task.status) < 0)//任務完成,跳出死循環 break; long ms, ns; if (deadline == 0L) ms = 0L; else if ((ns = deadline - System.nanoTime()) <= 0L) break; else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L) ms = 1L; if (tryCompensate(w)) { task.internalWait(ms); U.getAndAddLong(this, CTL, AC_UNIT); } } U.putOrderedObject(w, QCURRENTJOIN, prevJoin); } return s; }
/** * Tries to locate and execute tasks for a stealer of the given * task, or in turn one of its stealers, Traces currentSteal -> * currentJoin links looking for a thread working on a descendant * of the given task and with a non-empty queue to steal back and * execute tasks from. The first call to this method upon a * waiting join will often entail scanning/search, (which is OK * because the joiner has nothing better to do), but this method * leaves hints in workers to speed up subsequent calls. * * @param w caller * @param task the task to join */ private void helpStealer(WorkQueue w, ForkJoinTask<?> task) { WorkQueue[] ws = workQueues; int oldSum = 0, checkSum, m; if (ws != null && (m = ws.length - 1) >= 0 && w != null && task != null) { do { // restart point checkSum = 0; // for stability check ForkJoinTask<?> subtask; WorkQueue j = w, v; // v is subtask stealer descent: for (subtask = task; subtask.status >= 0; ) { for (int h = j.hint | 1, k = 0, i; ; k += 2) { if (k > m) // can't find stealer break descent; if ((v = ws[i = (h + k) & m]) != null) { if (v.currentSteal == subtask) { j.hint = i; break; } checkSum += v.base; } } for (;;) { // help v or descend ForkJoinTask<?>[] a; int b; checkSum += (b = v.base); ForkJoinTask<?> next = v.currentJoin; if (subtask.status < 0 || j.currentJoin != subtask || v.currentSteal != subtask) // stale break descent; if (b - v.top >= 0 || (a = v.array) == null) { if ((subtask = next) == null) break descent; j = v; break; } int i = (((a.length - 1) & b) << ASHIFT) + ABASE; ForkJoinTask<?> t = ((ForkJoinTask<?>) U.getObjectVolatile(a, i)); if (v.base == b) { if (t == null) // stale break descent; if (U.compareAndSwapObject(a, i, t, null)) { v.base = b + 1; ForkJoinTask<?> ps = w.currentSteal; int top = w.top; do { U.putOrderedObject(w, QCURRENTSTEAL, t); t.doExec(); // clear local tasks too } while (task.status >= 0 && w.top != top && (t = w.pop()) != null); U.putOrderedObject(w, QCURRENTSTEAL, ps); if (w.base != w.top) return; // can't further help } } } } } while (task.status >= 0 && oldSum != (oldSum = checkSum)); } }
上面的helpStealer()方法,原則是你幫助我執行任務,我也幫你執行任務。
1.遍歷奇數下標,如果發現隊列對象currentSteal放置的剛好是自己要找的任務,則說明自己的任務被該隊列a的owner線程偷來執行 2.如果隊列a隊列中有任務,則從隊尾(base)取出執行; 3.如果發現隊列b隊列爲空,則根據它正在join的任務,在拓撲找到相關的隊列B去偷取任務執行。在執行的過程中要注意,我們應該完整的把任務完成
參考鏈接:
2. Fork/Join 框架-設計與實現(翻譯自論文《A Java Fork/Join Framework》原作者 Doug Lea)