矩陣的簡單操作代碼實現
轉置定義(transpose):
轉置是矩陣的重要操作之一。矩陣的轉置是以對角線爲軸的鏡像,這條從左上角到右下角的對角線稱爲主對角線(main diagonal)。下圖顯示這個操作。將矩陣A轉置表示爲A^T,定義如下:
(參考 《深度學習》–[美]伊恩·古德費洛,[加]約書亞·本吉奧,[加]亞倫·庫維爾 第2章)
向量可以看作只有一列的矩陣。對應地,向量的轉置可以看作只有一行的矩陣。有時,我們通過將向量元素作爲行矩陣寫在文本行中,然後使用轉置操作將其變爲標準的列向量,來定義一個向量 (說明矩陣轉置不主要要求爲方陣),下圖演示了轉置的過程:
代碼實現:
用java的數組進行簡單實現,其實就是數組的操作:
package cn.qulei.matrix;
/**
* 實現矩陣轉置
*
* @author QuLei
*/
public class MatrixTranspose {
public static void main(String[] args) {
int[][] matrix = new int[3][4];
int k = 1;
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix[i].length; j++) {
matrix[i][j] = k++;
}
}
int[][] transposedMatrix = transpose(matrix);
System.out.println("輸出原矩陣:");
print(matrix);
System.out.println("-----------------");
System.out.println("轉置後的矩陣:");
print(transposedMatrix);
}
/**
* 轉置矩陣操作
*
* @param matrix 待轉置矩陣
* @return 轉置後的矩陣
*/
private static int[][] transpose(int[][] matrix) {
int[][] transposedMatrix = new int[matrix[0].length][matrix.length];
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix[i].length; j++) {
transposedMatrix[j][i] = matrix[i][j];
}
}
return transposedMatrix;
}
/**
*封裝打印數組方法
*
* @param matrix 待打印數組
*/
private static void print(int[][] matrix) {
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix[i].length; j++) {
System.out.print(matrix[i][j]);
System.out.print("\t");
}
System.out.println();
}
}
}
測試用例結果:
輸出原矩陣:
1 2 3 4
5 6 7 8
9 10 11 12
-----------------
轉置後的矩陣:
1 5 9
2 6 10
3 7 11
4 8 12
Process finished with exit code 0
矩陣的乘積
簡單定義:
參考:《深度學習的數學》–[日]湧井良幸,[日]湧井貞美 第2-5節
代碼實現
import org.junit.Test;
/**
* 測試矩陣相乘
*
* @author QuLei
*/
public class TestMatrix {
@Test
public void test() {
int[][] arr1 = {{2, 7}, {1, 8}};
int[][] arr2 = {{2, 8}, {1, 3}};
int[][] milt = milt(arr1, arr2);
for (int i = 0; i < milt.length; i++) {
for (int j = 0; j < milt[0].length; j++) {
System.out.print(milt[i][j] + "\t");
}
System.out.println();
}
}
/**
* 計算兩矩陣相乘
*
* @param arr1
* @param arr2
* @return
*/
int[][] milt(int[][] arr1, int[][] arr2) {
//這裏爲了實現數學上的問題,簡化了對矩陣合法性的判斷
if (arr1[0].length != arr2.length) {
throw new RuntimeException("不滿足矩陣乘法基本要求!!!");
}
int row = arr1.length;
int col = arr2[0].length;
int[][] result = new int[row][col];
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
for (int k = 0; k < arr2.length; k++) {
result[i][j] += arr1[i][k] * arr2[k][j];
}
}
}
return result;
}
}
測試結果:
11 37
10 32