一千萬個數高效求和

前言

今天看到了一道面試題

一千萬個數,如何高效求和?

看到這個題中的“高效求和”,第一反應想到了JDK1.8提供的LongAdder類的設計思想,就是分段求和再彙總。也就是開啓多個線程,每個線程負責計算一部分,所以線程都計算完成後再彙總。整個過程大致如下:
高效求和

思路已經有了,接下來就開始愉快的編碼吧

測試環境

  • win10系統
  • 4核4線程CPU
  • JDK1.8
  • com.google.guava.guava-25.1-jre.jar
  • lombok

實例

由於題目對一千萬個數沒有明確定義是什麼數,所以暫定爲int類型的隨機數。爲了對比效率,博主實現了單線程版本多線程版本,看看多線程到底有多高效。

單線程版本

單線程累加一千萬個數,代碼比較簡單,直接給出

/**
 * 單線程的方式累加
 * @param arr 一千萬個隨機數
 */
public static int singleThreadSum(int[] arr) {
    long start = System.currentTimeMillis();
    int sum = 0;
    int length = arr.length;
    for (int i = 0; i < length; i++) {
        sum += arr[i];
    }
    long end = System.currentTimeMillis();
    log.info("單線程方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
    return sum;
}

多線程版本

多線程的版本涉及到線程池(開啓多個線程)、CountDownLatch(主線程等待子線程執行完成)等工具的使用,所以稍微複雜一些。

// 每個task求和的規模
private static final int SIZE_PER_TASK = 200000;
// 線程池
private static ThreadPoolExecutor executor = null;

static {
    // 核心線程數 CPU數量 + 1
    int corePoolSize = Runtime.getRuntime().availableProcessors() + 1;
    executor = new ThreadPoolExecutor(corePoolSize, corePoolSize, 3, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>());
}

/**
 * 多線程的方式累加
 *
 * @param arr 一千萬個隨機數
 * @throws InterruptedException
 */
public static int concurrencySum(int[] arr) throws InterruptedException {
    long start = System.currentTimeMillis();
    LongAdder sum = new LongAdder();
    // 拆分任務
    List<List<int[]>> taskList = Lists.partition(Arrays.asList(arr), SIZE_PER_TASK);
    // 任務總數
    final int taskSize = taskList.size();
    final CountDownLatch latch = new CountDownLatch(taskSize);
    for (int i = 0; i < taskSize; i++) {
        int[] task = taskList.get(i).get(0);
        executor.submit(() -> {
            try {
                for (int num : task) {
                	// 把每個task中的數字累加
                    sum.add(num);
                }
            } finally {
            	// task執行完成後,計數器減一
                latch.countDown();
            }
        });
    }
    // 主線程等待所有子線程執行完成
    latch.await();
    long end = System.currentTimeMillis();
    log.info("多線程方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
    // 關閉線程池
    executor.shutdown();
    return sum.intValue();
}

由於代碼中有了詳細的註釋,所以不再贅述。

main方法

main方法也比較簡單,主要產生1千萬個隨機數,再調用兩個方法即可。

// 求和的個數
private static final int SUM_COUNT = 10000000;

public static void main(String[] args) throws InterruptedException {
    Random random = new Random();
    int[] arr = new int[SUM_COUNT];
    for (int i = 0; i < SUM_COUNT; i++) {
        arr[i] = random.nextInt(200);
    }

    // 多線程版本
    concurrencySum(arr);
    // 單線程版本
    singleThreadSum(arr);
}

第8行代碼random.nextInt(200)爲什麼是200?
因爲 1kw * 200 = 20 億 < Integer.MAX_VALUE,所以累加結果不會溢出

終於到了測試效率的時候了,是騾子是馬,拉出來溜溜。
信心滿滿的我,點擊了run,得到了如下結果

22:13:31.068 [main] INFO com.sicimike.concurrency.EfficientSum - 多線程方式計算結果:995523090, 耗時:0.133 秒
22:13:31.079 [main] INFO com.sicimike.concurrency.EfficientSum - 單線程方式計算結果:995523090, 耗時:0.006 秒

可能是我打開的方式不對…

但是

經過了多次運行,以及調整線程池參數之後的多次運行,總是得出不忍直視的運行結果。
多線程方式運行時間穩定在0.130秒左右,單線程運行方式穩定在0.006秒左右。

多線程改進

前文多線程的版本中使用了LongAdder類,由於LongAdder類在底層使用了大量的cas操作,線程競爭非常激烈時,效率會有不同程度的降低。所以在改進本例中多線程的版本時,不使用LongAdder類,而是更適合當前場景的方式。

/**
 * 多線程的方式累加(改進版)
 *
 * @param arr 一千萬個隨機數
 * @throws InterruptedException
 */
public static int concurrencySum(int[] arr) throws InterruptedException {
    long start = System.currentTimeMillis();
    int sum = 0;
    // 拆分任務
    List<List<int[]>> taskList = Lists.partition(Arrays.asList(arr), SIZE_PER_TASK);
    // 任務總數
    final int taskSize = taskList.size();
    final CountDownLatch latch = new CountDownLatch(taskSize);
    // 相當於LongAdder中的Cell[]
    int[] result = new int[taskSize];
    for (int i = 0; i < taskSize; i++) {
        int[] task = taskList.get(i).get(0);
        final int index = i;
        executor.submit(() -> {
            try {
                for (int num : task) {
                	// 各個子線程分別執行累加操作
                	// result每一個單元就是一個task的累加結果
                    result[index] += num;
                }
            } finally {
                latch.countDown();
            }
        });
    }
    // 等待所有子線程執行完成
    latch.await();
    for (int i : result) {
    	// 把子線程執行的結果累加起來就是最終的結果
        sum += i;
    }
    long end = System.currentTimeMillis();
    log.info("多線程方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
    // 關閉線程池
    executor.shutdown();
    return sum;
}

執行改進後的方法,得到如下結果:

22:46:05.085 [main] INFO com.sicimike.concurrency.EfficientSum - 多線程方式計算結果:994958790, 耗時:0.049 秒
22:46:05.094 [main] INFO com.sicimike.concurrency.EfficientSum - 單線程方式計算結果:994958790, 耗時:0.006 秒

多次運行,以及調整線程池參數之後的多次運行,結果也趨於穩定。
多線程方式運行時間穩定在0.049秒左右,單線程運行方式穩定在0.006秒左右

從0.133秒到0.049秒,效率大概提升了170%

思考

改進後的代碼不僅沒有解決單線程爲什麼比多線程快的問題,反而還多了一個問題:

爲什麼隨隨便便引入一個數組,竟然比Doug Lea寫的LongAdder還快?

因爲LongAdder是一個通用的工具類,很好的平衡了時間和空間的關係,所以在各種場景下都能有較好的效率。而本例中的result數組,一千萬個數字被分成了多少個task,數組的長度就是多少,每個task的結果都存在獨立的數組項,不存在競爭,但是佔用了更多的空間,所以時間效率更高,也就是拿空間換時間的思想。

至於爲什麼單線程比多線程快,這其實並不難解釋。因爲單線程沒有上下文切換,加上累加場景比較簡單,每個task執行時間很短,所以單線程更快很正常。

stream方式

stream是JDK1.8提供的語法糖,也是單線程的。關於stream的用法,大家自行了解即可。主要用來和後文的parallel stream進行對比。

public static int streamSum(List<Integer> list) {
    long start = System.currentTimeMillis();
    int sum = list.stream().mapToInt(num -> num).sum();
    long end = System.currentTimeMillis();
    log.info("stream方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
    return sum;
}

parallelStream方式

parallelStream見名知意,就是並行stream

public static int parallelStreamSum(List<Integer> list) {
    long start = System.currentTimeMillis();
    int sum = list.parallelStream().mapToInt(num -> num).sum();
    long end = System.currentTimeMillis();
    log.info("parallel stream方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
    return sum;
}

ForkJoin方式

ForkJoin框架是JDK1.7提出的,用於拆分任務計算再合併計算結果的框架。

當我們需要執行大量的小任務時,有經驗的Java開發人員都會採用線程池來高效執行這些小任務。然而,有一種任務,例如,對超過1000萬個元素的數組進行排序,這種任務本身可以併發執行,但如何拆解成小任務需要在任務執行的過程中動態拆分。這樣,大任務可以拆成小任務,小任務還可以繼續拆成更小的任務,最後把任務的結果彙總合併,得到最終結果,這種模型就是Fork/Join模型。

ForkJoin框架的使用大致分爲兩個部分:實現ForkJoin任務、執行任務

實現ForkJoin任務

自定義類繼承RecursiveTask(有返回值)或者RecursiveAction(無返回值),實現compute方法

/**
 * 靜態內部類的方式實現
 * forkjoin任務
 */
static class SicForkJoinTask extends RecursiveTask<Integer> {
    // 子任務計算區間開始
    private Integer left;
    // 子任務計算區間結束
    private Integer right;
    private int[] arr;

    @Override
    protected Integer compute() {
        if (right - left < SIZE_PER_TASK) {
        	// 任務足夠小時,直接計算
            int sum = 0;
            for (int i = left; i < right; i++) {
                sum += arr[i];
            }
            return sum;
        }
        // 繼續拆分任務
        int middle = left + (right - left) / 2;
        SicForkJoinTask leftTask = new SicForkJoinTask(arr, left, middle);
        SicForkJoinTask rightTask = new SicForkJoinTask(arr, middle, right);
        invokeAll(leftTask, rightTask);
        Integer leftResult = leftTask.join();
        Integer rightResult = rightTask.join();
        return leftResult + rightResult;
    }

    public SicForkJoinTask(int[] arr, Integer left, Integer right) {
        this.arr = arr;
        this.left = left;
        this.right = right;
    }
}

執行任務

通過ForkJoinPoolinvoke方法執行ForkJoin任務

// ForkJoin線程池
private static final ForkJoinPool forkJoinPool = new ForkJoinPool();

public static int forkJoinSum(int[] arr) {
    long start = System.currentTimeMillis();
    // 執行ForkJoin任務
    Integer sum = forkJoinPool.invoke(new SicForkJoinTask(arr, 0, SUM_COUNT));
    long end = System.currentTimeMillis();
    log.info("forkjoin方式計算結果:{}, 耗時:{} 秒", sum, (end - start) / 1000.0);
    return sum;
}

main方法

public static void main(String[] args) throws InterruptedException {
    Random random = new Random();
    int[] arr = new int[SUM_COUNT];
    List<Integer> list = new ArrayList<>(SUM_COUNT);
    int currNum = 0;
    for (int i = 0; i < SUM_COUNT; i++) {
        currNum = random.nextInt(200);
        arr[i] = currNum;
        list.add(currNum);
    }

    // 單線程執行
    singleThreadSum(arr);
    
    // Executor線程池執行
    concurrencySum(arr);
    
    // stream執行
    streamSum(list);
    
    // 並行stream執行
    parallelStreamSum(list);
    
    // forkjoin線程池執行
    forkJoinSum(arr);
}

執行結果

23:19:21.207 [main] INFO com.sicimike.concurrency.EfficientSum - 單線程方式計算結果:994917205, 耗時:0.006 秒
23:19:21.274 [main] INFO com.sicimike.concurrency.EfficientSum - 多線程方式計算結果:994917205, 耗時:0.062 秒
23:19:21.292 [main] INFO com.sicimike.concurrency.EfficientSum - stream方式計算結果:994917205, 耗時:0.018 秒
23:19:21.309 [main] INFO com.sicimike.concurrency.EfficientSum - parallel stream方式計算結果:994917205, 耗時:0.017 秒
23:19:21.321 [main] INFO com.sicimike.concurrency.EfficientSum - forkjoin方式計算結果:994917205, 耗時:0.012 秒

源代碼

代碼地址:EfficientSum.java
有興趣的同學可以自己下載源代碼後,調整各個參數運行,得到的結果不一定和我一樣。

總結

代碼寫了一大版,結果最初的問題還是沒解決。有人可能會說:博主你坑爹呢。
確實,我沒有想到更好的辦法,但是把文中的幾個問題想清楚,應該會比一道面試題更有價值。

如果哪位同學有更好的優化方式,還請不吝賜教。

參考

Java的Fork/Join任務,你寫對了嗎?

發佈了55 篇原創文章 · 獲贊 107 · 訪問量 1萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章