CupCnn是一個用Java編寫的卷積神經網絡。
支持L1、L2正則化
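正則化的理論非常複雜,推導過程也比較繁瑣,但實現卻異常簡單,主要體現在權重的衰減上。通俗地講,就是每次更新權重w時,讓它比不加正則化時應得的值再往0收縮一點。以L2正則化爲例,更新公式爲 $w \leftarrow (1-\eta\lambda)\,w - \eta\,\frac{\partial L}{\partial w}$;L1正則化則爲 $w \leftarrow w - \eta\lambda\,\mathrm{sign}(w) - \eta\,\frac{\partial L}{\partial w}$(代碼中 $w=0$ 時按正數處理)。其中 $\eta$ 是學習速率lr,$\lambda$ 是衰減因子lamda。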
// Excerpt of the SGD weight update: updates w in place from its gradient and,
// depending on mode, applies L2 or L1 weight decay. lr, lamda and mode are
// presumably fields inherited from Optimizer (not shown in this excerpt).
float[] wData = w.getData();
float[] gradData = gradient.getData();
if(mode==GMode.L2) {
for(int j=0;j<w.getSize();j++){
//L2 decay: scale w by (1 - lr*lamda), then take the gradient step
wData[j] = (1.0f-lr*lamda)*wData[j] - lr*gradData[j];
}
}else if(mode==GMode.L1){
for(int j=0;j<w.getSize();j++){
//L1 decay: shift w by lr*lamda against its sign (w == 0 is treated as positive),
//then take the gradient step
if(wData[j]>=0) {
wData[j] = wData[j] - lr*lamda - lr*gradData[j];
}else {
wData[j] = wData[j] + lr*lamda - lr*gradData[j];
}
}
}else {
//No regularization (mode is neither L2 nor L1): plain gradient step
for(int j=0;j<w.getSize();j++){
wData[j] -= lr*gradData[j];
}
}
這裏的lamda(衰減因子)是一個很小的數,可以根據實際情況自行設置。它在創建Optimizer時指定,例如通過SGDOptimizer的構造函數:
/**
 * Creates an SGD optimizer with weight decay.
 *
 * @param lr    learning rate
 * @param mode  regularization mode (GMode.L1 or GMode.L2)
 * @param lamda decay factor; normally a small number
 */
public SGDOptimizer(float lr,Optimizer.GMode mode,float lamda){
super(lr,mode,lamda);
}
參數解釋如下:
- lr 學習速率
- mode 正則化模式,GMode.L1或者GMode.L2
- lamda 衰減因子
也可以不使用正則化,直接用如下構造函數即可:
/**
 * Creates an SGD optimizer without regularization.
 *
 * @param lr learning rate
 */
public SGDOptimizer(float lr){
super(lr);
}
該構造函數只需傳入學習速率。
實現標準卷積
之前實現的卷積其實不能稱爲標準卷積,它更像是深度可分離卷積。因此,我重新實現了標準卷積,並將原來的卷積改爲深度可分離卷積,分別命名爲Conv2dLayer和DeepWiseConv2dLayer。
重構了各個層的輸入參數
以前使用BlobParams作爲傳入參數,很不靈活:
// Old API (before the refactor): every layer's shapes are passed as BlobParams.
// NOTE(review): BlobParams arguments look like (number, channels, height, width), and the
// second BlobParams of a conv/pool layer appears to describe the kernel — confirm.
InputLayer layer1 = new InputLayer(network,new BlobParams(network.getBatch(),1,28,28));
network.addLayer(layer1);
ConvolutionLayer conv1 = new ConvolutionLayer(network,new BlobParams(network.getBatch(),6,28,28),new BlobParams(1,6,3,3));
conv1.setActivationFunc(new ReluActivationFunc());
network.addLayer(conv1);
PoolMaxLayer pool1 = new PoolMaxLayer(network,new BlobParams(network.getBatch(),6,14,14),new BlobParams(1,6,2,2),2,2);
network.addLayer(pool1);
ConvolutionLayer conv2 = new ConvolutionLayer(network,new BlobParams(network.getBatch(),12,14,14),new BlobParams(1,12,3,3));
conv2.setActivationFunc(new ReluActivationFunc());
network.addLayer(conv2);
重構後:
// New API: shapes are passed as plain ints. Layers are added to the network in order.
// NOTE(review): judging from the values below, conv layers take
// (inputW, inputH, inChannels, outChannels, kernelSize, stride) and pooling layers take
// (inputW, inputH, channels, kernelSize, stride) — confirm the width/height order.
InputLayer layer1 = new InputLayer(network,28,28,1);
network.addLayer(layer1);
Conv2dLayer conv1 = new Conv2dLayer(network,28,28,1,6,3,1);
conv1.setActivationFunc(new ReluActivationFunc());
network.addLayer(conv1);
PoolMaxLayer pool1 = new PoolMaxLayer(network,28,28,6,2,2);
network.addLayer(pool1);
//Try a depthwise separable convolution here: the depthwise conv dwconv1 followed by
//conv2, which (kernel size 1, presumably) acts as the pointwise conv
DeepWiseConv2dLayer dwconv1 = new DeepWiseConv2dLayer(network,14,14,6,6,3,1);
network.addLayer(dwconv1);
Conv2dLayer conv2 = new Conv2dLayer(network,14,14,6,6,1,1);
conv2.setActivationFunc(new ReluActivationFunc());
network.addLayer(conv2);
PoolMeanLayer pool2 = new PoolMeanLayer(network,14,14,6,2,2);
network.addLayer(pool2);
// 7*7*6 matches pool2's flattened output, assuming pooling halves 14x14 to 7x7
FullConnectionLayer fc1 = new FullConnectionLayer(network,7*7*6,256);
fc1.setActivationFunc(new ReluActivationFunc());
network.addLayer(fc1);
FullConnectionLayer fc2 = new FullConnectionLayer(network,256,10);
fc2.setActivationFunc(new ReluActivationFunc());
network.addLayer(fc2);
SoftMaxLayer sflayer = new SoftMaxLayer(network,10);
network.addLayer(sflayer);
如上即可構建一個使用深度可分離卷積的網絡,參數更加簡單易懂。
新增cifar10的例子
cifar10的例子驗證了CupCnn在處理彩色圖像的時候也表現良好。
新增MSELoss
MSELoss是最簡單的二次誤差(均方誤差)損失函數。在我的實驗中它比LogLikeHoodLoss表現更好,所以新增了它。
/**
 * Quadratic (mean squared error) loss.
 *
 * <p>{@code loss} sums {@code (label - output)^2} over every element of the blobs and divides the
 * sum by {@code label.getHeight()} (presumably the batch dimension — TODO confirm).
 * {@code diff} writes the element-wise gradient {@code 2 * (output - label)} into {@code diff}.
 *
 * <p>NOTE(review): the gradient is not divided by {@code label.getHeight()} even though the loss
 * is, so {@code diff} is {@code getHeight()} times the exact derivative of {@code loss}. It is
 * left unscaled on purpose, because changing it would rescale the effective learning rate.
 * Confirm whether the caller averages over the batch before changing it.
 */
public class MSELoss extends Loss{

    /**
     * Computes the summed squared error divided by {@code label.getHeight()}.
     *
     * @param label  expected values; must have at least as many elements as {@code label.getSize()}
     * @param output network output; same layout as {@code label}
     * @return the loss value
     */
    @Override
    public float loss(Blob label, Blob output) {
        float[] labelData = label.getData();
        float[] outputData = output.getData();
        int size = label.getSize();
        float loss = 0.0f;
        for (int i = 0; i < size; ++i) {
            float err = labelData[i] - outputData[i];
            loss += err * err;
        }
        return loss / label.getHeight();
    }

    /**
     * Writes the element-wise gradient {@code 2 * (output - label)} into {@code diff}.
     *
     * <p>Iterates over the same flat range as {@link #loss} ({@code label.getSize()}), so every
     * element that contributes to the loss also receives a gradient. The previous
     * {@code height * width} loop skipped elements whenever the blob had more than one number or
     * channel.
     *
     * @param label  expected values
     * @param output network output; same layout as {@code label}
     * @param diff   destination for the gradient; same layout as {@code label}
     */
    @Override
    public void diff(Blob label, Blob output, Blob diff) {
        float[] labelData = label.getData();
        float[] outputData = output.getData();
        float[] diffData = diff.getData();
        final float factor = 2.0f;
        // Clear the whole buffer first so no stale gradient survives in entries beyond label's size.
        diff.fillValue(0.0f);
        int size = label.getSize();
        for (int i = 0; i < size; ++i) {
            diffData[i] = factor * (outputData[i] - labelData[i]);
        }
    }
}
新增線程池併發加速
沒有實現多線程之前,訓練確實很慢,所以給CupCnn增加了線程池併發加速。使用方式極其簡單,只需在創建Network後設置線程池的線程數即可:
// Create the network, then set the number of worker threads in its pool (6 here).
network = new Network();
network.setThreadNum(6);
新增SGDMOptimizer
SGDMOptimizer就是使用了動量思想的SGDOptimizer。所謂動量,就是這次參數的改變量不僅與當前的梯度有關,還與上一次的改變量有關。相關程度由momentum參數決定:momentum越接近1,相關性越大;越接近0,相關性越小。
package cupcnn.optimizer;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import cupcnn.data.Blob;
import cupcnn.optimizer.Optimizer.GMode;
/**
 * SGD with momentum.
 *
 * <p>For every parameter blob this optimizer keeps a velocity buffer {@code v}, created lazily on
 * the first update of that blob. Each update computes {@code v = momentum * v - lr * step} and
 * then {@code param += v}. Here {@code step} is the gradient, plus (for weights only) the L1 or
 * L2 decay term selected by {@code mode}. The closer {@code momentum} is to 1, the more the
 * previous update influences the current one; at 0 this reduces to plain SGD.
 */
public class SGDMOptimizer extends Optimizer {
    /** Weight applied to the previous velocity on every update. */
    private final float momentum;

    /**
     * Velocity buffer per parameter blob. Keyed by object identity because each parameter blob is
     * mutated on every update: its identity, not its contents, must select its velocity.
     */
    private final IdentityHashMap<Blob, Blob> privMap = new IdentityHashMap<>();

    /**
     * Creates a momentum optimizer without weight decay.
     *
     * @param lr       learning rate
     * @param momentum momentum coefficient
     */
    public SGDMOptimizer(float lr, float momentum) {
        super(lr);
        this.momentum = momentum;
    }

    /**
     * Creates a momentum optimizer with weight decay.
     *
     * @param lr       learning rate
     * @param mode     regularization mode (GMode.L1 or GMode.L2)
     * @param lamda    decay factor; normally a small number
     * @param momentum momentum coefficient
     */
    public SGDMOptimizer(float lr, Optimizer.GMode mode, float lamda, float momentum) {
        super(lr, mode, lamda);
        this.momentum = momentum;
    }

    /** Updates bias blob {@code b} in place; biases never receive weight decay. */
    @Override
    public void updateB(Blob b, Blob gradient) {
        float[] velocity = velocityFor(b);
        float[] bData = b.getData();
        float[] gradData = gradient.getData();
        int size = b.getSize();
        for (int j = 0; j < size; j++) {
            float v = momentum * velocity[j] - lr * gradData[j];
            bData[j] += v;
            velocity[j] = v;
        }
    }

    /** Updates weight blob {@code w} in place, applying the decay selected by {@code mode}. */
    @Override
    public void updateW(Blob w, Blob gradient) {
        float[] velocity = velocityFor(w);
        float[] wData = w.getData();
        float[] gradData = gradient.getData();
        int size = w.getSize();
        // if/else rather than switch: mode may be unset (e.g. via the lr-only constructor), and a
        // switch on a null enum would throw.
        if (mode == GMode.L2) {
            for (int j = 0; j < size; j++) {
                // L2 decay adds -lr*lamda*w to the step.
                float v = momentum * velocity[j] - lr * lamda * wData[j] - lr * gradData[j];
                wData[j] += v;
                velocity[j] = v;
            }
        } else if (mode == GMode.L1) {
            for (int j = 0; j < size; j++) {
                // L1 decay shifts the step by lr*lamda against the sign of w (w == 0 counts as positive).
                float v;
                if (wData[j] >= 0) {
                    v = momentum * velocity[j] - lr * lamda - lr * gradData[j];
                } else {
                    v = momentum * velocity[j] + lr * lamda - lr * gradData[j];
                }
                wData[j] += v;
                velocity[j] = v;
            }
        } else {
            for (int j = 0; j < size; j++) {
                float v = momentum * velocity[j] - lr * gradData[j];
                wData[j] += v;
                velocity[j] = v;
            }
        }
    }

    /** Returns the velocity buffer for {@code param}, creating it on first use. */
    private float[] velocityFor(Blob param) {
        // NOTE(review): assumes new Blob(param, false) allocates a blob of param's shape without
        // copying its data (i.e. a zero velocity) — confirm in Blob.
        return privMap.computeIfAbsent(param, p -> new Blob(p, false)).getData();
    }
}
所有的double改爲了float
double改爲float應該能讓神經網絡運行速度更快一點。
修改了訓練時測試的邏輯
之前是每隔一段時間,把剛纔用於訓練的那一批數據(一個batch)拿來測試。這樣測試速度很快,而且準確率往往很高,但訓練完後在測試集上一測,發現準確率一般。所以乾脆改爲每訓練一個epoch,就在測試集上進行一次完整的測試。這樣能更好地觀察訓練情況,也更容易發現過擬合。
最後
CupCnn在GitHub倉庫中上傳了MNIST和CIFAR-10數據集,因此clone時會比較慢,但好處是clone後即可直接訓練與測試。
CupCnn在MNIST上能輕鬆達到98%的準確率,可能只需要4~5個epoch。如果你有耐心,達到99%的準確率也沒問題。例如,以下是一個達到99%的模型:
/**
 * Appends a small convolutional network for 28x28 single-channel input to {@code network}:
 * conv(10) + ReLU, max-pool, conv(10) + ReLU, mean-pool, FC 256 + ReLU, FC 10 + ReLU, softmax.
 * Layers are added in the order they are constructed here.
 */
private void buildConvNetwork(){
InputLayer layer1 = new InputLayer(network,28,28,1);
network.addLayer(layer1);
Conv2dLayer conv1 = new Conv2dLayer(network,28,28,1,10,3,1);
conv1.setActivationFunc(new ReluActivationFunc());
network.addLayer(conv1);
PoolMaxLayer pool1 = new PoolMaxLayer(network,28,28,10,2,2);
network.addLayer(pool1);
Conv2dLayer conv2 = new Conv2dLayer(network,14,14,10,10,3,1);
conv2.setActivationFunc(new ReluActivationFunc());
network.addLayer(conv2);
PoolMeanLayer pool2 = new PoolMeanLayer(network,14,14,10,2,2);
network.addLayer(pool2);
// 7*7*10 matches pool2's flattened output, assuming pooling halves 14x14 to 7x7
FullConnectionLayer fc1 = new FullConnectionLayer(network,7*7*10,256);
fc1.setActivationFunc(new ReluActivationFunc());
network.addLayer(fc1);
FullConnectionLayer fc2 = new FullConnectionLayer(network,256,10);
// NOTE(review): a ReLU on the logits feeding SoftMax presumably clamps negative logits to 0;
// consider leaving fc2 linear — verify accuracy before changing.
fc2.setActivationFunc(new ReluActivationFunc());
network.addLayer(fc2);
SoftMaxLayer sflayer = new SoftMaxLayer(network,10);
network.addLayer(sflayer);
}
/**
 * Creates and configures {@code network}, builds the convolutional model and prepares it.
 * Configuration: 8 threads, batch size 20, learning-rate attenuation 0.9 (judging by the
 * training log, lr is multiplied by 0.9 after each epoch), MSE loss, and plain SGD with lr 0.1.
 *
 * @param numOfTrainData number of training samples; currently not used by this method
 */
public void buildNetwork(int numOfTrainData){
//First create the network object and set its parameters
network = new Network();
network.setThreadNum(8);
network.setBatch(20);
network.setLrAttenuation(0.9f);
//network.setLoss(new LogLikeHoodLoss());
//network.setLoss(new CrossEntropyLoss());
network.setLoss(new MSELoss());
optimizer = new SGDOptimizer(0.1f);
network.setOptimizer(optimizer);
//buildFcNetwork();
buildConvNetwork();
network.prepare();
}
訓練過程中神經網絡的輸出如下:
training...... please wait for a moment!
............................................................
training...... epoe: 0 lossValue: 0.15509754 lr: 0.1 cost 117179
testing...... please wait for a moment!
test accuracy is 0.8785 correctCount 8785 allCount 10000
............................................................
training...... epoe: 1 lossValue: 0.08268655 lr: 0.089999996 cost 115667
testing...... please wait for a moment!
test accuracy is 0.9761 correctCount 9761 allCount 10000
............................................................
training...... epoe: 2 lossValue: 0.06630586 lr: 0.08099999 cost 114282
testing...... please wait for a moment!
test accuracy is 0.9818 correctCount 9818 allCount 10000
............................................................
training...... epoe: 3 lossValue: 0.015823135 lr: 0.07289999 cost 114677
testing...... please wait for a moment!
test accuracy is 0.986 correctCount 9860 allCount 10000
............................................................
training...... epoe: 4 lossValue: 0.006677883 lr: 0.06560999 cost 115048
testing...... please wait for a moment!
test accuracy is 0.9843 correctCount 9843 allCount 10000
............................................................
training...... epoe: 5 lossValue: 7.9412933E-4 lr: 0.05904899 cost 114657
testing...... please wait for a moment!
test accuracy is 0.9858 correctCount 9858 allCount 10000
............................................................
training...... epoe: 6 lossValue: 0.010126321 lr: 0.05314409 cost 114791
testing...... please wait for a moment!
test accuracy is 0.9863 correctCount 9863 allCount 10000
............................................................
training...... epoe: 7 lossValue: 0.11289799 lr: 0.04782968 cost 116797
testing...... please wait for a moment!
test accuracy is 0.9873 correctCount 9873 allCount 10000
............................................................
training...... epoe: 8 lossValue: 0.007667846 lr: 0.04304671 cost 116323
testing...... please wait for a moment!
test accuracy is 0.9889 correctCount 9889 allCount 10000
............................................................
training...... epoe: 9 lossValue: 0.0069320253 lr: 0.038742036 cost 114736
testing...... please wait for a moment!
test accuracy is 0.9886 correctCount 9886 allCount 10000
............................................................
training...... epoe: 10 lossValue: 0.011851874 lr: 0.03486783 cost 116032
testing...... please wait for a moment!
test accuracy is 0.9893 correctCount 9893 allCount 10000
............................................................
training...... epoe: 11 lossValue: 0.0129155135 lr: 0.03138105 cost 115604
testing...... please wait for a moment!
test accuracy is 0.9892 correctCount 9892 allCount 10000
............................................................
training...... epoe: 12 lossValue: 0.0018233052 lr: 0.028242942 cost 119216
testing...... please wait for a moment!
test accuracy is 0.99 correctCount 9900 allCount 10000
............................................................
training...... epoe: 13 lossValue: 0.0027558927 lr: 0.025418647 cost 119067
testing...... please wait for a moment!
test accuracy is 0.9891 correctCount 9891 allCount 10000
............................................................
training...... epoe: 14 lossValue: 0.0024366616 lr: 0.02287678 cost 115927
testing...... please wait for a moment!
test accuracy is 0.9893 correctCount 9893 allCount 10000
............................................................
training...... epoe: 15 lossValue: 0.014906692 lr: 0.020589102 cost 118666
testing...... please wait for a moment!
test accuracy is 0.9903 correctCount 9903 allCount 10000
............................................................
training...... epoe: 16 lossValue: 0.012857122 lr: 0.018530192 cost 116381
testing...... please wait for a moment!
test accuracy is 0.9892 correctCount 9892 allCount 10000
............................................................
training...... epoe: 17 lossValue: 0.0077277897 lr: 0.016677173 cost 115545
testing...... please wait for a moment!
test accuracy is 0.9894 correctCount 9894 allCount 10000
............................................................
training...... epoe: 18 lossValue: 0.008005977 lr: 0.015009455 cost 115614
testing...... please wait for a moment!
test accuracy is 0.9895 correctCount 9895 allCount 10000
............................................................
training...... epoe: 19 lossValue: 0.005667642 lr: 0.01350851 cost 116124
testing...... please wait for a moment!
test accuracy is 0.9903 correctCount 9903 allCount 10000
............................................................
training...... epoe: 20 lossValue: 0.0015369358 lr: 0.012157658 cost 115027
testing...... please wait for a moment!
test accuracy is 0.9903 correctCount 9903 allCount 10000
............................................................
training...... epoe: 21 lossValue: 0.0027716388 lr: 0.010941892 cost 114647
testing...... please wait for a moment!
test accuracy is 0.9903 correctCount 9903 allCount 10000
............................................................
training...... epoe: 22 lossValue: 3.0824103E-4 lr: 0.009847702 cost 114195
testing...... please wait for a moment!
test accuracy is 0.99 correctCount 9900 allCount 10000
............................................................
training...... epoe: 23 lossValue: 0.013186147 lr: 0.008862932 cost 114099
testing...... please wait for a moment!
test accuracy is 0.9905 correctCount 9905 allCount 10000
............................................................
training...... epoe: 24 lossValue: 3.2178203E-5 lr: 0.007976639 cost 114108
testing...... please wait for a moment!
test accuracy is 0.9903 correctCount 9903 allCount 10000
............................................................
training...... epoe: 25 lossValue: 0.0058679665 lr: 0.007178975 cost 116286
testing...... please wait for a moment!
test accuracy is 0.9902 correctCount 9902 allCount 10000
............................................................
training...... epoe: 26 lossValue: 0.0012587834 lr: 0.0064610774 cost 115560
testing...... please wait for a moment!
test accuracy is 0.9904 correctCount 9904 allCount 10000
............................................................
training...... epoe: 27 lossValue: 0.011831607 lr: 0.0058149695 cost 115372
testing...... please wait for a moment!
test accuracy is 0.9904 correctCount 9904 allCount 10000
............................................................
training...... epoe: 28 lossValue: 1.4093271E-4 lr: 0.0052334727 cost 116815
testing...... please wait for a moment!
test accuracy is 0.9904 correctCount 9904 allCount 10000
............................................................
training...... epoe: 29 lossValue: 8.6647685E-4 lr: 0.0047101253 cost 116404
testing...... please wait for a moment!
test accuracy is 0.9904 correctCount 9904 allCount 10000
begin save model
save model finished
begin load model
load model finished
testing...... please wait for a moment!
test accuracy is 0.9904 correctCount 9904 allCount 10000
交流
機器學習 QQ交流羣:704153141