【Matlab學習手記】梯度下降法

用一個實例來理解兩種梯度下降方法。

clear; clc;
%% 一元函數梯度下降法
% 示例:f(x) = min{(x - 1)^2}
% 梯度:g(x) =  2 * (x - 1)
yita = 0.25;   % 學習率,一般設置小一點,否則容易在最小值附近震盪或者不收斂
x1 = -5 : 0.1 : 5;
y1 = (x1 - 1).^2;
iteMax = 1000;
xInit = 4;
yInit = (xInit - 1)^2;
err = 1e-6;
figure(1)
plot(x1, y1, 'b', 'LineWidth', 2)
xlim([-5, 5])
ylim([-1, 25])
hold on
plot(xInit, yInit, 'or', 'MarkerFaceColor', 'r')
for i = 1 : iteMax
    % x = x + yita * grad;
    xNew = xInit - yita * 2 * (xInit - 1);
    yNew = (xNew - 1)^2;
    % 如果增量很小,或者說梯度很小,則退出
    if abs(xNew - xInit) < err
        break;
    else
        PlotLineArrow(gca, [xInit, xNew], [yInit, yNew], 'r', 'r')
        xInit = xNew;
        yInit = yNew;
        disp(['第', num2str(i), '次迭代結果:', num2str(xInit)]);
        plot(xNew, yNew, 'or', 'MarkerFaceColor', 'r')        
    end
end
hold off
%% 多元函數梯度下降法
% 示例:f(x) = min{x1^2 + x2^2}
% 梯度:g(x) = [2 * x1; 2 * x2]
[x, y] = meshgrid(-4:0.5:4, -4:0.5:4);
z = x.^2 + y.^2;
initX = 4;
initY = 3;
initZ = initX^2 + initY^2;
initValue = [initX; initY];
figure(2)
mesh(x, y, z);
shading interp
hold on
grad = zeros(1, 2);
e = 0.1;
yita = 5;   % Adagrad 更快收斂
for i = 1 : iteMax
    % 標準的梯度法  x = x + yita * grad;
%     newValue = initValue - yita * [2 * initX; 2 * initY];
    % Adagrad 法    x = x + yita * inv(G) * grad;
    grad = grad + [(2 * initX)^2, (2 * initY)^2];
    newValue = initValue - yita * diag(1 ./ sqrt(grad + e)) * [2 * initX; 2 * initY];
    % 如果增量很小,或者說梯度很小,則退出
    if norm(newValue - initValue) < err
        break;
    else
        newX = newValue(1);
        newY = newValue(2);
        newZ = newX^2 + newY^2;
        plot3([initX, newX], [initY, newY], [initZ, newZ], '-or', 'MarkerFaceColor', 'r')
        initValue = newValue;
        initX = newX;
        initY = newY;
        initZ = newZ;
        disp(['第', num2str(i), '次迭代結果:', num2str(newValue')]);    
    end
end
hold off

輸出結果

第1次迭代結果:2.5
第2次迭代結果:1.75
第3次迭代結果:1.375
第4次迭代結果:1.1875
第5次迭代結果:1.0938
第6次迭代結果:1.0469
第7次迭代結果:1.0234
第8次迭代結果:1.0117
第9次迭代結果:1.0059
第10次迭代結果:1.0029
第11次迭代結果:1.0015
第12次迭代結果:1.0007
第13次迭代結果:1.0004
第14次迭代結果:1.0002
第15次迭代結果:1.0001
第16次迭代結果:1
第17次迭代結果:1
第18次迭代結果:1
第19次迭代結果:1
第20次迭代結果:1
第21次迭代結果:1
第1次迭代結果:-0.9961     -1.9931
第2次迭代結果:0.21124      0.7711
第3次迭代結果:-0.044461    -0.27468
第4次迭代結果:0.009355    0.096817
第5次迭代結果:-0.0019683    -0.03408
第6次迭代結果:0.00041415    0.011994
第7次迭代結果:-8.7139e-05  -0.0042213
第8次迭代結果:1.8335e-05   0.0014857
第9次迭代結果:-3.8577e-06 -0.00052286
第10次迭代結果:8.1168e-07  0.00018402
第11次迭代結果:-1.7078e-07 -6.4762e-05
第12次迭代結果:3.5933e-08  2.2793e-05
第13次迭代結果:-7.5606e-09 -8.0216e-06
第14次迭代結果:1.5908e-09  2.8231e-06
第15次迭代結果:-3.3471e-10 -9.9357e-07
第16次迭代結果:7.0425e-11  3.4968e-07

說明:繪製帶箭頭的直線函數

https://blog.csdn.net/u012366767/article/details/99568619

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章