【Matlab學習手記】基於最小二乘的非線性擬合

用一個實例來理解基於最小二乘的非線性擬合問題。

原理部分

代碼部分

clear; clc;
M = 2000;
t = 0.3 * (1 : M)';
rng('default');
ratio = 10;
noise = ratio * randn(M, 1);
Et = 1000 * exp(-t / 50) + 10 + noise;
p1 = LSFittingT2Free(t, Et);
p2 = LSFittingT2Penalty(t, Et);
fit1 = p1(1) * exp(-t / p1(2)) + p1(3);
fit2 = p2(1) * exp(-t / p2(2)) + p2(3);
disp([p1; p2]);
plot(t, Et, '.k', 'LineWidth', 1.5)
hold on
plot(t, fit1, 'LineWidth', 1.5)
plot(t, fit2, 'LineWidth', 1.5)
hold off
legend('Model', 'Fitting1', 'Fitting2')
function result = LSFittingT2Free(x, y)
% Levenberg_Marquardt LS Fitting
% y = a^2 * exp(-x / b^2) + c^2
%{
clear; clc;
M = 2000;
t = 0.3 * (1 : M)';
rng('default');
ratio = 10;
noise = ratio * randn(M, 1);
Et = 1000 * exp(-t / 50) + 10 + noise;
p = LSFittingT2Free(t, Et);
fit = p(1) * exp(-t / p(2)) + p(3);
plot(t, Et, '.k', 'LineWidth', 1.5)
hold on
plot(t, fit, 'LineWidth', 1.5)
hold off
legend('Model', 'Fitting')
%}
% 按列排
x = x(:);
y = y(:);
% 歸一化
maxValue = max(y);
y = y / maxValue;
M = length(x);
% Least Squares Minimization 
rng('default');
% 初始值
a = sqrt(rand);
b = sqrt(rand);
c = sqrt(rand);
% 變量個數
nParam = 3;
% 殘差
r = y - (a^2 * exp(-x / b^2) + c^2);
% 目標函數
f = r'*r;   
% 正則化因子初值
lambda = 1;     
% 迭代次數
it = 0;         
% 更新標記
updateFlag = true;                  
 % 最大迭代次數
maxIter = 100;    
% 迭代計算
while it < maxIter
    it = it + 1;
    if updateFlag
        Ja = 2 * a * exp(-x / b^2);
        Jb = 2 * a^2 * x .* exp(-x / b^2) / b^3;
        Jc = 2 * c * ones(M, 1);
        J = [Ja, Jb, Jc];
        g = -2 * J' * r;
        H = 2 * (J'*J);
    end
    Hess = H + lambda * eye(nParam);
    s = -Hess \ g;
    a1 = a + s(1);
    b1 = b + s(2);
    c1 = c + s(3);
    r1 = y - (a1^2 * exp(-x / b1^2) + c1^2);
    f1 = r1'*r1;
    fdr = (f - f1) / f;
    if fdr > 0
        a = a1;
        b = b1;
        c = c1;
        f = f1;
        r = r1;
        lambda = 0.1 * lambda;
        updateFlag = true;
    else
        lambda = 10 * lambda;
        updateFlag = false;
    end
    if max(abs(s)) < 1e-6
        disp(['反演收斂,迭代次數:' num2str(it)]);
        break;
   elseif it == maxIter
        disp('反演可能不收斂;');
    end
end
result = [a^2 * maxValue, b^2, c^2 * maxValue];
function result = LSFittingT2Penalty(x, y)
% Levenberg_Marquardt LS Fitting
% y = a^2 * exp(-x / b^2) + c^2
%{
clear; clc;
M = 2000;
t = 0.3 * (1 : M)';
rng('default');
ratio = 10;
noise = ratio * randn(M, 1);
Et = 1000 * exp(-t / 50) + 10 + noise;
p = LSFittingT2Free(t, Et);
fit = p(1) * exp(-t / p(2)) + p(3);
plot(t, Et, '.k', 'LineWidth', 1.5)
hold on
plot(t, fit, 'LineWidth', 1.5)
hold off
legend('Model', 'Fitting')
%}
% 按列排
x = x(:);
y = y(:);
M = length(x);
% 初始值
a0 = rand;
b0 = rand;
c0 = rand;
% 變量個數
nParam = 3;
% 殘差
r0 = y - (a0 * exp(-x / b0) + c0);
% 障礙罰函數的收縮係數個數
numLoop = 5;
% 最大迭代次數
maxIter = 100;
% 障礙罰函數收縮係數
mu = 10;
k = 0;
while k < numLoop
    k = k + 1;
    mu = mu * 0.1;
    % Least Squares Minimization
    a = a0;
    b = b0;
    c = c0;
    r = r0;
    % 目標函數值
    f = r'*r - mu * (log(a) + log(b) + log(c));
    % 正則化因子初值
    lambda = 1;
    % 迭代次數
    it = 0;
    % 更新標記
    updateFlag = true;
    % 迭代計算
    while it < maxIter
        it = it + 1;
        if updateFlag
            Ja = exp(-x / b);
            Jb = a * x .* exp(-x / b) / b^2;
            J = [Ja, Jb, ones(M, 1)];
            g = -2 * J' * r - mu * [1/a; 1/b; 1/c];
            H = 2 * (J'*J) - mu * diag([-1/a^2, -1/b^2, -1/c^2]);
        end
        s = -(H + lambda * eye(nParam)) \ g;
        while a + s(1) < 0 || b + s(2) < 0 || c + s(3) < 0
            s = 0.5 * s;
        end
        a1 = a + s(1);
        b1 = b + s(2);
        c1 = c + s(3);
        r1 = y - (a1 * exp(-x / b1) + c1);
        f1 = (r1'*r1) - mu * (log(a1) + log(b1) + log(c1));
        fdr = (f - f1) / f;
        if fdr > 0
            a = a1;
            b = b1;
            c = c1;
            f = f1;
            r = r1;
            lambda = 0.1 * lambda;
            updateFlag = true;
        else
            lambda = 10 * lambda;
            updateFlag = false;
        end
        if max(abs(s)) < 1e-6
            disp(['反演收斂,迭代次數:' num2str(it)]);
            break;
        elseif it == maxIter
            disp('反演可能不收斂;');
        end
    end
    history.a(k) = a;
    history.b(k) = b;
    history.c(k) = c;
    history.ssr(k) = r'*r;
end
[~, minIndex] = min(history.ssr);
a = history.a(minIndex);
b = history.b(minIndex);
c = history.c(minIndex);
result = [a, b, c];

結果部分

    程序輸出結果如下,和Matlab擬合工具箱結果對比,完全一致。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章