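Sweet Snippet: Variance Calculation

A simple implementation of variance calculation

In probability and statistics, variance measures how spread out a set of data is. It is computed as follows (population variance):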

$$
\begin{aligned}
\mu &= \frac{1}{N}\sum_{i = 1}^{N}x_i \\
\sigma^2 &= \frac{1}{N}\sum_{i = 1}^{N}(x_i - \mu)^2
\end{aligned}
$$

Here $\mu$ is the mean of the data and $\sigma^2$ is the (population) variance.
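
As a quick illustration of the two formulas, consider the data set $\{1, 2, 3, 4\}$ (a made-up example, not taken from the original post):

$$
\begin{aligned}
\mu &= \frac{1 + 2 + 3 + 4}{4} = 2.5 \\
\sigma^2 &= \frac{(-1.5)^2 + (-0.5)^2 + 0.5^2 + 1.5^2}{4} = \frac{5}{4} = 1.25
\end{aligned}
$$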

The corresponding implementation is as follows:

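-- Lua
-- Arithmetic mean of the first `count` elements of `values`.
function average(values, count)
    local sum = 0

    for i = 1, count do
        sum = sum + values[i]
    end

    return sum / count
end

-- Population variance of the first `count` elements of `values`.
function variance(values, count)
    -- Named `mean` so the local does not shadow the global function `average`.
    local mean = average(values, count)
    local square_sum = 0

    for i = 1, count do
        local delta = values[i] - mean
        square_sum = square_sum + delta * delta
    end

    return square_sum / count
end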

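We usually need to update the variance whenever a new sample arrives. The simple approach is to recompute it from scratch with the formula above. We can simulate this process by computing the variance of each leading subset (prefix) of the data: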

-- Lua
-- ret[i] is the variance of values[1..i], recomputed from scratch for every i.
function variance_list(values)
    local ret = {}
    
    for i = 1, #values do
        ret[i] = variance(values, i)
    end
    
    return ret
end
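
To see why this is costly, one can count element reads. This is a rough estimate, assuming each call `variance(values, i)` makes two passes over `i` elements (one inside `average`, one to accumulate the squared deviations). For $n$ values the total is:

$$
\sum_{i = 1}^{n} 2i = n(n + 1) = O(n^2)
$$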

A better approach is to compute the subset variances by recurrence, which requires rearranging the variance formula:

$$
\begin{aligned}
\sigma^2 &= \frac{1}{N}\sum_{i = 1}^{N}(x_i - \mu)^2 \\
&= \frac{1}{N}\sum_{i = 1}^{N}\left(x_i^2 + \mu^2 - 2x_i\mu\right) \\
&= \frac{1}{N}\left(\sum_{i = 1}^{N}x_i^2 + \sum_{i = 1}^{N}\mu^2 - \sum_{i = 1}^{N}2x_i\mu\right) \\
&= \frac{1}{N}\left(\sum_{i = 1}^{N}x_i^2 + N\mu^2 - 2\mu\sum_{i = 1}^{N}x_i\right) \\
&= \frac{1}{N}\left(\sum_{i = 1}^{N}x_i^2 + N\mu^2 - 2N\mu^2\right) \\
&= \frac{1}{N}\left(\sum_{i = 1}^{N}x_i^2 - N\mu^2\right) \\
&= \frac{1}{N}\left(\sum_{i = 1}^{N}x_i^2 - N\left(\frac{\sum_{i = 1}^{N}x_i}{N}\right)^2\right) \\
&= \frac{1}{N}\left(\sum_{i = 1}^{N}x_i^2 - \frac{\left(\sum_{i = 1}^{N}x_i\right)^2}{N}\right)
\end{aligned}
$$
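
As a sanity check of the final form, take the illustrative data set $\{1, 2, 3, 4\}$ (chosen here for convenience, not from the original post). Computed directly from the definition, with $\mu = 2.5$, its population variance is $5 / 4 = 1.25$. The rearranged formula needs only the two running sums and gives the same value:

$$
\begin{aligned}
\sum_{i = 1}^{4}x_i &= 10, \qquad \sum_{i = 1}^{4}x_i^2 = 30 \\
\sigma^2 &= \frac{1}{4}\left(30 - \frac{10^2}{4}\right) = \frac{30 - 25}{4} = 1.25
\end{aligned}
$$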

With this form, we can compute the subset variances by recurrence, which reduces the overall time complexity from $O(n^2)$ to $O(n)$:

-- Lua
-- ret[i] is the variance of values[1..i], derived from running sums of x and x^2.
function variance_list_recurrence(values)
    local ret = {}
    
    local pre_square_sum = 0
    local pre_sum = 0
    
    for i = 1, #values do
        local val = values[i]
        
        pre_square_sum = pre_square_sum + val * val
        pre_sum = pre_sum + val
        
        ret[i] = (pre_square_sum - (pre_sum * pre_sum / i)) / i
    end
    
    return ret
end
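
The following is a minimal, self-contained sketch that checks the two approaches against each other. The helper names `prefix_variances_naive` and `prefix_variances_recurrence` and the sample data are hypothetical, introduced only for this comparison. They mirror the logic of `variance_list` and `variance_list_recurrence` above rather than reuse them.

```lua
-- Hypothetical helper: recompute each prefix variance from scratch, O(n^2).
local function prefix_variances_naive(values)
    local ret = {}
    for n = 1, #values do
        local sum = 0
        for i = 1, n do
            sum = sum + values[i]
        end
        local mean = sum / n

        local acc = 0
        for i = 1, n do
            local delta = values[i] - mean
            acc = acc + delta * delta
        end
        ret[n] = acc / n
    end
    return ret
end

-- Hypothetical helper: keep running sums of x and x^2, O(n).
local function prefix_variances_recurrence(values)
    local ret = {}
    local square_sum, sum = 0, 0
    for n = 1, #values do
        local val = values[n]
        square_sum = square_sum + val * val
        sum = sum + val
        ret[n] = (square_sum - sum * sum / n) / n
    end
    return ret
end

-- Illustrative sample data (an assumption, not from the original post).
local data = {1, 2, 3, 4}
local naive = prefix_variances_naive(data)
local fast = prefix_variances_recurrence(data)

for n = 1, #data do
    -- Compare with a tolerance: the two formulas round differently.
    assert(math.abs(naive[n] - fast[n]) < 1e-12)
    print(n, naive[n], fast[n])
end
```

The comparison uses a small tolerance instead of `==` because the two formulas perform different floating-point operations, so their results need not be bit-for-bit identical even when they agree mathematically.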