[原創]利用Java多線程計算矩陣乘法
前言
前段時間在一本操作系統書籍上,看到了可以利用多線程來計算矩陣乘法的思想。例如下圖中,A矩陣和B矩陣相乘得到C矩陣,那麼A矩陣的每一行和B矩陣的每一列的相乘和加和,都可以交給一個線程來計算,最終得到cij這個元素。A矩陣維度是m*s,B矩陣是s*n,那麼這個計算需要m*n個線程的參與,它是否一定比串行計算快呢?本文使用Java多線程一探究竟。
串行計算
矩陣乘法的串行計算方法是不難想到的。三層循環遍歷計算即可。從外到內分別遍歷A矩陣的行、B矩陣的列、A矩陣的列(即爲B矩陣的行)即可。在計算的開始和結束時刻分別獲取系統當前時刻,最後可得計算時間。以下是簡單的代碼片段。
// 串行驗證
startTime = System.currentTimeMillis();
for (int i = 0; i < A.length; i++) {
for (int j = 0; j < B[0].length; j++) {
for (int k = 0; k < A[0].length; k++)
serial_result[i][j] += A[i][k] * B[k][j];
}
}
endTime = System.currentTimeMillis();
System.out.println("串行計算開始時刻:" + (startTime));
System.out.println("串行計算結束時刻:" + (endTime));
System.out.println("串行計算運行時間:" + (endTime - startTime));
並行計算
並行計算需要考慮的問題就複雜一些。假設結果保存在C矩陣裏,期間有m*n個線程參與計算,那麼首先要保證C矩陣的計算正確性----即C矩陣是全部的子線程計算完成後得到的結果,而不是子線程還沒結束,main線程已經繼續執行並打印出了錯誤的輸出結果。如果一個算法計算速度再快,結果是錯誤的,那就毫無意義了。
這裏我採用CountDownLatch作爲計數工具。
int threadNum = A.length*B[0].length;
CountDownLatch countDownLatch = new CountDownLatch(threadNum);
這樣就聲明瞭一個初始值爲m*n個線程的countDownLatch,每當有一個線程完成其計算任務後,可調用countDownLatch實例的countDown()方法,令總線程數減1。這個操作可放在子線程的run()方法中實現。
使用for循環啓動線程之後,在main()函數中調用countDownLatch的await()方法,這個操作的作用是,只要計數器的值不爲0,其他已先計算完成的子線程就會等待直到計數器值變爲0 。計數器變爲0後,main線程就不會被阻塞。所以,這時得到的結果就是必然正確的。
按照上述思路,我順利完成了代碼的編寫。但令人意外的是,對於我測試的所有維度的矩陣,其並行計算時間均慢於串行。對於維度爲300*300的A、B矩陣,運行結果如下。隨着矩陣維度的增大,兩者的時間差距甚至越來越大:
我確信代碼邏輯沒有問題,那麼一個最有可能的猜測是,啓動這m*n個線程耗費太多時間,故對代碼又做了改進,令一個子線程由負責C矩陣中一個元素的計算改爲負責多行元素的計算,這樣就大大減少了線程數量。使用10個線程,對於同樣300*300的A、B矩陣計算時間如下:
驗證了我的猜測。
另外,當聲明的線程數量小於for循環中啓動的線程總數時,會導致await()方法提前失效,main線程和子線程交替執行,那麼有可能會導致結果錯誤,這也是需要注意的一點。
最終整體代碼如下:
import java.util.concurrent.CountDownLatch;
public class CalculateTask extends Thread {
private int[][] A;
private int[][] B;
private int index;
private int gap;
private int[][] result;
private CountDownLatch countDownLatch;
public CalculateTask(int[][] A, int[][] B, int index, int gap, int[][] result, CountDownLatch countDownLatch) {
this.A = A;
this.B = B;
this.index = index;
this.gap = gap;
this.result = result;
this.countDownLatch = countDownLatch;
}
// 計算特定範圍內的結果
public void run() {
// TODO Auto-generated method stub
for (int i = index * gap; i < (index + 1) * gap; i++)
for (int j = 0; j < B[0].length; j++) {
for (int k = 0; k < B.length; k++)
result[i][j] += A[i][k] * B[k][j];
}
// 線程數減1
countDownLatch.countDown();
}
public static void main(String[] args) throws InterruptedException {
// 聲明和初始化
long startTime;
long endTime;
int row_A = 300;
int col_A = 300;
int col_B = 300;
int[][] A = new int[row_A][col_A];
int[][] B = new int[col_A][col_B];
//存放並行計算結果
int[][] parallel_result = new int[A.length][B[0].length];
//存放串行計算結果
int[][] serial_result = new int[A.length][B[0].length];
//子線程數量
int threadNum = 10;
//子線程的分片計算間隔
int gap = A.length / threadNum;
CountDownLatch countDownLatch = new CountDownLatch(threadNum);
// 爲A和B矩陣隨機賦值
for (int i = 0; i < row_A; i++)
for (int j = 0; j < col_A; j++) {
A[i][j] = (int) (Math.random() * 100);
}
for (int i = 0; i < col_A; i++)
for (int j = 0; j < col_B; j++) {
B[i][j] = (int) (Math.random() * 100);
}
// 並行計算
startTime = System.currentTimeMillis();
for (int i = 0; i < threadNum; i++) {
CalculateTask ct = new CalculateTask(A, B, i, gap, parallel_result, countDownLatch);
ct.start();
}
countDownLatch.await();
endTime = System.currentTimeMillis();
System.out.println("並行計算開始時刻:" + (startTime));
System.out.println("並行計算結束時刻:" + (endTime));
System.out.println("並行計算運行時間:" + (endTime - startTime));
// 串行計算
startTime = System.currentTimeMillis();
for (int i = 0; i < A.length; i++) {
for (int j = 0; j < B[0].length; j++) {
for (int k = 0; k < A[0].length; k++)
serial_result[i][j] += A[i][k] * B[k][j];
}
}
endTime = System.currentTimeMillis();
System.out.println("串行計算開始時刻:" + (startTime));
System.out.println("串行計算結束時刻:" + (endTime));
System.out.println("串行計算運行時間:" + (endTime - startTime));
}
}