本文示例源碼地址:https://github.com/tianlanlandelan/DL4JTest/blob/master/src/test/java/com/dl4j/demo/Nd4jTest.java
maven安裝DL4J
pom文件引入:
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>1.0.0-beta3</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>1.0.0-beta3</version>
</dependency>
創建矩陣
//生成一個全0二維矩陣
INDArray tensorA = Nd4j.zeros(4,5);
println("全0二維矩陣",tensorA);
//生成一個全1二維矩陣
INDArray tensorB = Nd4j.ones(4,5);
println("全1二維矩陣",tensorB);
//生成一個全1二維矩陣
INDArray tensorC = Nd4j.rand(4,5);
println("隨機二維矩陣",tensorC);
運行結果:
====全0二維矩陣===
[[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0]]
====全1二維矩陣===
[[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]
====隨機二維矩陣===
[[ 0.5017, 0.9461, 0.3255, 0.2155, 0.9273],
[ 0.0239, 0.5130, 0.8028, 0.5011, 0.3680],
[ 0.3644, 0.0864, 0.0342, 0.4126, 0.5553],
[ 0.2027, 0.7989, 0.6696, 0.0402, 0.7059]]
矩陣運算–拼接
println("水平拼接若干矩陣,矩陣必須有相同的行數", Nd4j.hstack(tensorA,tensorB));
println("垂直拼接若干矩陣,矩陣必須有相同的列數", Nd4j.vstack(tensorA,tensorB));
運行結果:
[[ 0, 0, 0, 0, 0, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
[ 0, 0, 0, 0, 0, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
[ 0, 0, 0, 0, 0, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
[ 0, 0, 0, 0, 0, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]
====垂直拼接若干矩陣,矩陣必須有相同的列數===
[[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0],
[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]
矩陣運算-加減
注意,每個運算函數都有一個加i的函數,如 add和addi,加i的函數運算後會覆蓋掉原矩陣
println("矩陣元素加上一個標量",tensorA.add(10));
println("矩陣相加",tensorA.add(tensorB));
println("矩陣元素加上標量後覆蓋原矩陣tensorA",tensorA.addi(10));
println("矩陣相減",tensorA.sub(tensorB));
運行結果:
====矩陣元素加上一個標量===
[[ 10.0000, 10.0000, 10.0000],
[ 10.0000, 10.0000, 10.0000]]
====矩陣相加===
[[ 0.2202, 0.1473, 0.1217],
[ 0.8428, 0.6761, 0.8127]]
====矩陣元素加上標量後覆蓋原矩陣tensorA===
[[ 10.0000, 10.0000, 10.0000],
[ 10.0000, 10.0000, 10.0000]]
====矩陣相減===
[[ 9.7798, 9.8527, 9.8783],
[ 9.1572, 9.3239, 9.1873]]
矩陣運算-乘除
println("矩陣對應元素相乘",tensorA.mul(tensorB));
println("矩陣元素除以一個標量",tensorA.div(2));
println("矩陣對應元素相除",tensorA.div(tensorB));
/*
矩陣A*B=C
需要注意:
1、當矩陣A的列數等於矩陣B的行數時,A與B可以相乘。
2、矩陣C的行數等於矩陣A的行數,C的列數等於B的列數。( A:2,3; B:3,4; C:2,4 )
3、乘積C的第m行第n列的元素等於矩陣A的第m行的元素與矩陣B的第n列對應元素乘積之和。
*/
println("矩陣相乘",tensorA.mmul(tensorB));
運算結果:
====矩陣對應元素相乘===
[[ 2.2015, 1.4728, 1.2173],
[ 8.4281, 6.7608, 8.1272]]
====矩陣元素除以一個標量===
[[ 5.0000, 5.0000, 5.0000],
[ 5.0000, 5.0000, 5.0000]]
====矩陣對應元素相除===
[[ 45.4231, 67.8989, 82.1506],
[ 11.8650, 14.7911, 12.3043]]
====矩陣相乘===
[[ 4.8916, 23.3161],
[ 4.8916, 23.3161]]
矩陣運算-翻轉
println("矩陣轉置",tensorB.transpose());
println("矩陣轉置後替換原矩陣tensorB",tensorB.transposei());
運算結果:
====矩陣轉置===
[[ 0.2202, 0.8428],
[ 0.1473, 0.6761],
[ 0.1217, 0.8127]]
====矩陣轉置後替換原矩陣tensorB===
[[ 0.2202, 0.8428],
[ 0.1473, 0.6761],
[ 0.1217, 0.8127]]
三維矩陣
三維矩陣和二維矩陣操作一樣:
//創建一個三維矩陣 2*3*4
INDArray tensor3d_1 = Nd4j.create(new int[]{2,3,4});
println("創建空的三維矩陣",tensor3d_1);
//創建一個隨機的三維矩陣 2*3*4
INDArray tensor3d_2 = Nd4j.rand(new int[]{2,3,4});
println("創建隨機三維矩陣",tensor3d_2);
//矩陣的每個元素減去一個標量後覆蓋原矩陣
println("矩陣元素減去一個標量",tensor3d_1.subi(-5));
//矩陣相減
println("三維矩陣相減",tensor3d_1.sub(tensor3d_2));
運算結果:
====創建空的三維矩陣===
[[[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0]],
[[ 0, 0, 0, 0],
[ 0, 0, 0, 0],
[ 0, 0, 0, 0]]]
====創建隨機三維矩陣===
[[[ 0.7030, 0.0575, 0.3288, 0.8928],
[ 0.7067, 0.4539, 0.6318, 0.8632],
[ 0.2914, 0.7980, 0.3350, 0.8783]],
[[ 0.8559, 0.7396, 0.6039, 0.1946],
[ 0.5336, 0.9253, 0.4747, 0.2658],
[ 0.9690, 0.3269, 0.0520, 0.1754]]]
====矩陣元素減去一個標量===
[[[ 5.0000, 5.0000, 5.0000, 5.0000],
[ 5.0000, 5.0000, 5.0000, 5.0000],
[ 5.0000, 5.0000, 5.0000, 5.0000]],
[[ 5.0000, 5.0000, 5.0000, 5.0000],
[ 5.0000, 5.0000, 5.0000, 5.0000],
[ 5.0000, 5.0000, 5.0000, 5.0000]]]
====三維矩陣相減===
[[[ 4.2970, 4.9425, 4.6712, 4.1072],
[ 4.2933, 4.5461, 4.3682, 4.1368],
[ 4.7086, 4.2020, 4.6650, 4.1217]],
[[ 4.1441, 4.2604, 4.3961, 4.8054],
[ 4.4664, 4.0747, 4.5253, 4.7342],
[ 4.0310, 4.6731, 4.9480, 4.8246]]]