前言
今天看到了一道面試題
一千萬個數,如何高效求和?
看到這個題中的“高效求和”,第一反應想到了JDK1.8提供的LongAdder
類的設計思想,就是分段求和再彙總。也就是開啓多個線程,每個線程負責計算一部分,所以線程都計算完成後再彙總。整個過程大致如下:
思路已經有了,接下來就開始愉快的編碼吧
測試環境
- win10系統
- 4核4線程CPU
- JDK1.8
- com.google.guava.guava-25.1-jre.jar
- lombok
實例
由於題目對一千萬個數沒有明確定義是什麼數,所以暫定爲int類型的隨機數。爲了對比效率,博主實現了單線程版本和多線程版本,看看多線程到底有多高效。
單線程版本
單線程累加一千萬個數,代碼比較簡單,直接給出
/**
* 單線程的方式累加
* @param arr 一千萬個隨機數
*/
public static int singleThreadSum(int[] arr) {
long start = System.currentTimeMillis();
int sum = 0;
int length = arr.length;
for (int i = 0; i < length; i++) {
sum += arr[i];
}
long end = System.currentTimeMillis();
log.info("單線程方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
return sum;
}
多線程版本
多線程的版本涉及到線程池(開啓多個線程)、CountDownLatch(主線程等待子線程執行完成)等工具的使用,所以稍微複雜一些。
// 每個task求和的規模
private static final int SIZE_PER_TASK = 200000;
// 線程池
private static ThreadPoolExecutor executor = null;
static {
// 核心線程數 CPU數量 + 1
int corePoolSize = Runtime.getRuntime().availableProcessors() + 1;
executor = new ThreadPoolExecutor(corePoolSize, corePoolSize, 3, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>());
}
/**
* 多線程的方式累加
*
* @param arr 一千萬個隨機數
* @throws InterruptedException
*/
public static int concurrencySum(int[] arr) throws InterruptedException {
long start = System.currentTimeMillis();
LongAdder sum = new LongAdder();
// 拆分任務
List<List<int[]>> taskList = Lists.partition(Arrays.asList(arr), SIZE_PER_TASK);
// 任務總數
final int taskSize = taskList.size();
final CountDownLatch latch = new CountDownLatch(taskSize);
for (int i = 0; i < taskSize; i++) {
int[] task = taskList.get(i).get(0);
executor.submit(() -> {
try {
for (int num : task) {
// 把每個task中的數字累加
sum.add(num);
}
} finally {
// task執行完成後,計數器減一
latch.countDown();
}
});
}
// 主線程等待所有子線程執行完成
latch.await();
long end = System.currentTimeMillis();
log.info("多線程方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
// 關閉線程池
executor.shutdown();
return sum.intValue();
}
由於代碼中有了詳細的註釋,所以不再贅述。
main方法
main方法也比較簡單,主要產生1千萬個隨機數,再調用兩個方法即可。
// 求和的個數
private static final int SUM_COUNT = 10000000;
public static void main(String[] args) throws InterruptedException {
Random random = new Random();
int[] arr = new int[SUM_COUNT];
for (int i = 0; i < SUM_COUNT; i++) {
arr[i] = random.nextInt(200);
}
// 多線程版本
concurrencySum(arr);
// 單線程版本
singleThreadSum(arr);
}
第8行代碼random.nextInt(200)
爲什麼是200?
因爲 1kw * 200 = 20 億 < Integer.MAX_VALUE,所以累加結果不會溢出
終於到了測試效率的時候了,是騾子是馬,拉出來溜溜。
信心滿滿的我,點擊了run,得到了如下結果
22:13:31.068 [main] INFO com.sicimike.concurrency.EfficientSum - 多線程方式計算結果:995523090, 耗時:0.133 秒
22:13:31.079 [main] INFO com.sicimike.concurrency.EfficientSum - 單線程方式計算結果:995523090, 耗時:0.006 秒
可能是我打開的方式不對…
但是
經過了多次運行,以及調整線程池參數之後的多次運行,總是得出不忍直視的運行結果。
多線程方式運行時間穩定在0.130秒左右,單線程運行方式穩定在0.006秒左右。
多線程改進
前文多線程的版本中使用了LongAdder
類,由於LongAdder
類在底層使用了大量的cas操作,線程競爭非常激烈時,效率會有不同程度的降低。所以在改進本例中多線程的版本時,不使用LongAdder
類,而是更適合當前場景的方式。
/**
* 多線程的方式累加(改進版)
*
* @param arr 一千萬個隨機數
* @throws InterruptedException
*/
public static int concurrencySum(int[] arr) throws InterruptedException {
long start = System.currentTimeMillis();
int sum = 0;
// 拆分任務
List<List<int[]>> taskList = Lists.partition(Arrays.asList(arr), SIZE_PER_TASK);
// 任務總數
final int taskSize = taskList.size();
final CountDownLatch latch = new CountDownLatch(taskSize);
// 相當於LongAdder中的Cell[]
int[] result = new int[taskSize];
for (int i = 0; i < taskSize; i++) {
int[] task = taskList.get(i).get(0);
final int index = i;
executor.submit(() -> {
try {
for (int num : task) {
// 各個子線程分別執行累加操作
// result每一個單元就是一個task的累加結果
result[index] += num;
}
} finally {
latch.countDown();
}
});
}
// 等待所有子線程執行完成
latch.await();
for (int i : result) {
// 把子線程執行的結果累加起來就是最終的結果
sum += i;
}
long end = System.currentTimeMillis();
log.info("多線程方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
// 關閉線程池
executor.shutdown();
return sum;
}
執行改進後的方法,得到如下結果:
22:46:05.085 [main] INFO com.sicimike.concurrency.EfficientSum - 多線程方式計算結果:994958790, 耗時:0.049 秒
22:46:05.094 [main] INFO com.sicimike.concurrency.EfficientSum - 單線程方式計算結果:994958790, 耗時:0.006 秒
多次運行,以及調整線程池參數之後的多次運行,結果也趨於穩定。
多線程方式運行時間穩定在0.049秒左右,單線程運行方式穩定在0.006秒左右
從0.133秒到0.049秒,效率大概提升了170%
思考
改進後的代碼不僅沒有解決單線程爲什麼比多線程快的問題,反而還多了一個問題:
爲什麼隨隨便便引入一個數組,竟然比Doug Lea寫的
LongAdder
還快?
因爲LongAdder
是一個通用的工具類,很好的平衡了時間和空間的關係,所以在各種場景下都能有較好的效率。而本例中的result數組,一千萬個數字被分成了多少個task,數組的長度就是多少,每個task的結果都存在獨立的數組項,不存在競爭,但是佔用了更多的空間,所以時間效率更高,也就是拿空間換時間的思想。
至於爲什麼單線程比多線程快,這其實並不難解釋。因爲單線程沒有上下文切換,加上累加場景比較簡單,每個task執行時間很短,所以單線程更快很正常。
stream方式
stream
是JDK1.8提供的語法糖,也是單線程的。關於stream
的用法,大家自行了解即可。主要用來和後文的parallel stream
進行對比。
public static int streamSum(List<Integer> list) {
long start = System.currentTimeMillis();
int sum = list.stream().mapToInt(num -> num).sum();
long end = System.currentTimeMillis();
log.info("stream方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
return sum;
}
parallelStream方式
parallelStream
見名知意,就是並行的stream
。
public static int parallelStreamSum(List<Integer> list) {
long start = System.currentTimeMillis();
int sum = list.parallelStream().mapToInt(num -> num).sum();
long end = System.currentTimeMillis();
log.info("parallel stream方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
return sum;
}
ForkJoin方式
ForkJoin
框架是JDK1.7提出的,用於拆分任務計算再合併計算結果的框架。
當我們需要執行大量的小任務時,有經驗的Java開發人員都會採用線程池來高效執行這些小任務。然而,有一種任務,例如,對超過1000萬個元素的數組進行排序,這種任務本身可以併發執行,但如何拆解成小任務需要在任務執行的過程中動態拆分。這樣,大任務可以拆成小任務,小任務還可以繼續拆成更小的任務,最後把任務的結果彙總合併,得到最終結果,這種模型就是Fork/Join模型。
ForkJoin
框架的使用大致分爲兩個部分:實現ForkJoin任務、執行任務
實現ForkJoin任務
自定義類繼承RecursiveTask
(有返回值)或者RecursiveAction
(無返回值),實現compute
方法
/**
* 靜態內部類的方式實現
* forkjoin任務
*/
static class SicForkJoinTask extends RecursiveTask<Integer> {
// 子任務計算區間開始
private Integer left;
// 子任務計算區間結束
private Integer right;
private int[] arr;
@Override
protected Integer compute() {
if (right - left < SIZE_PER_TASK) {
// 任務足夠小時,直接計算
int sum = 0;
for (int i = left; i < right; i++) {
sum += arr[i];
}
return sum;
}
// 繼續拆分任務
int middle = left + (right - left) / 2;
SicForkJoinTask leftTask = new SicForkJoinTask(arr, left, middle);
SicForkJoinTask rightTask = new SicForkJoinTask(arr, middle, right);
invokeAll(leftTask, rightTask);
Integer leftResult = leftTask.join();
Integer rightResult = rightTask.join();
return leftResult + rightResult;
}
public SicForkJoinTask(int[] arr, Integer left, Integer right) {
this.arr = arr;
this.left = left;
this.right = right;
}
}
執行任務
通過ForkJoinPool
的invoke
方法執行ForkJoin
任務
// ForkJoin線程池
private static final ForkJoinPool forkJoinPool = new ForkJoinPool();
public static int forkJoinSum(int[] arr) {
long start = System.currentTimeMillis();
// 執行ForkJoin任務
Integer sum = forkJoinPool.invoke(new SicForkJoinTask(arr, 0, SUM_COUNT));
long end = System.currentTimeMillis();
log.info("forkjoin方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
return sum;
}
main方法
public static void main(String[] args) throws InterruptedException {
Random random = new Random();
int[] arr = new int[SUM_COUNT];
List<Integer> list = new ArrayList<>(SUM_COUNT);
int currNum = 0;
for (int i = 0; i < SUM_COUNT; i++) {
currNum = random.nextInt(200);
arr[i] = currNum;
list.add(currNum);
}
// 單線程執行
singleThreadSum(arr);
// Executor線程池執行
concurrencySum(arr);
// stream執行
streamSum(list);
// 並行stream執行
parallelStreamSum(list);
// forkjoin線程池執行
forkJoinSum(arr);
}
執行結果
23:19:21.207 [main] INFO com.sicimike.concurrency.EfficientSum - 單線程方式計算結果:994917205, 耗時:0.006 秒
23:19:21.274 [main] INFO com.sicimike.concurrency.EfficientSum - 多線程方式計算結果:994917205, 耗時:0.062 秒
23:19:21.292 [main] INFO com.sicimike.concurrency.EfficientSum - stream方式計算結果:994917205, 耗時:0.018 秒
23:19:21.309 [main] INFO com.sicimike.concurrency.EfficientSum - parallel stream方式計算結果:994917205, 耗時:0.017 秒
23:19:21.321 [main] INFO com.sicimike.concurrency.EfficientSum - forkjoin方式計算結果:994917205, 耗時:0.012 秒
源代碼
代碼地址:EfficientSum.java
有興趣的同學可以自己下載源代碼後,調整各個參數運行,得到的結果不一定和我一樣。
總結
代碼寫了一大版,結果最初的問題還是沒解決。有人可能會說:博主你坑爹呢。
確實,我沒有想到更好的辦法,但是把文中的幾個問題想清楚,應該會比一道面試題更有價值。
如果哪位同學有更好的優化方式,還請不吝賜教。