摘要: 本文首先回顧了導數的基本概念,然後初步書寫了計算函數導數的程序函數,並根據計算機特點對函數進行了改進以達到工程實現。
關鍵詞: 導數、工程實現
本文默認你對導數有一定了解,所介紹的函數默認是可導的。
前言
在人工智能領域,深度學習相關研究一直在如火如荼地進行着。基本上所有的深度學習算法的都使用了反向傳播(Backpropagation, BP)算法。在反向傳播中更新參數的過程中少不了的一步就是計算梯度值,計算梯度值少不了對函數進行求導計算某點的導數。用了那麼多高大上算法框架,如何用Python書寫計算任一函數任一點的導數?可能就沒有多少人知道了,如果讓你去解決這個問題你該怎麼辦?
1 導數回顧
簡單地說導數就是某個瞬間的變化量。例如汽車的加速度就是速度對時間的導數。即:
如果這個你都忘了,就去搜一搜吧。導數的數學定義式如下:
即函數f(x)關於x的導數。
2 程序初實現
有了上面的公式,你會感覺用Python實現就很簡單了,我們先使用以下方式實現,這種方式稱爲數值微分(numerical differentiation近似求解函數的導數)
def numerical_diff_or(f, x):
"""
對函數f在點x處求導
f:一個函數
x:函數定義域內一個點
"""
h = 1e-50
return (f(x + h) - f(x))/h
看起來是不是很簡單。參數說明:第一個參數f即:需要求導的函數,第二個參數x爲:需要x點上計算導數。但是上面的書寫存在幾個需要改進的地方。
-
在實際工程實現中,計算機保留浮點數是有精度限制的,並且存在舍入誤差,如下:
print(np.float32(1e-50)) print(np.float64(1e-50))
輸出的結果分別爲:0.0,1e-50,當選擇使用np.float64存儲數據的話會佔用較大的內存,而使用np.float32保存數據時最終會使用0.0計算,顯然不是我們想要的。其實我們選用一個相對比較小的數字即可然後使用另一種方式計算也會得到一個近似的結果,這裏我們先選擇1e-10
-
如下圖,因爲計算不能保存無限接近於0的數字,也必然存在誤差。爲了減少這個誤差,一種改進的計算方式是計算f在(x + h)和(x-h)之間的差分。由於這種方法計算式以x爲中心,計算它左右兩邊的差分,所以也稱爲中心差分。
改進後的代碼如下:
def numerical_diff(f, x):
"""
對函數f在點x處求導
f:一個函數
x:函數定義域內一個點
"""
h = 1e-10
return (f(x + h) - f(x-h))/(2*h)
3 一個數值微分案例
現在我們創建一個函數如下:
本實驗所需Python包如下:
import numpy as np
import matplotlib.pyplot as plt
對應的Python代碼如下:
def function_1(x):
"""
一個測試函數
x:自變量
"""
return 0.1*x + 0.01*x**2 + 0.001*x**3
來看看原函數是什麼樣子的:(已導入相關matplotlib.pyplot 和 numpy包)
x = np.arange(0, 20, 0.1)
y = function_1(x)
def plot_func1():
"""
繪製函數function_1()圖形
"""
plt.xlabel('x')
plt.ylabel('$f(X)$')
plt.plot(x, y)
plt.show()
plot_func1()
現在我們分別計算x=5, x=10處的導數如下:
print(numerical_diff(function_1, 5))
print(numerical_diff(function_1, 10))
結果分別爲:0.275000022753602、0.6000000496442226
f(x)的導數如下嚴格來說是這樣的:
我們再將x=5,x=10帶入上式計算
def func1_diff(x):
"""
func1函數嚴格導數
"""
print(0.1+0.02*x+0.003*x**2)
func1_diff(5)
func1_diff(10)
結果分別爲: 0.275、0.6000000000000001。雖然嚴格來說還是有一定不同,但是兩者相比誤差已經非常非常小了,基本上可以認爲是相等的了。在對應點繪製的切線如下圖:
對應源碼如下:
def plot_tangent(ax, x, y, x0):
"""
使用numerical_diff()計算得到的導數值繪製切線
"""
# 繪製原數據
ax.set_xlabel('x')
ax.set_ylabel('$f(X)$')
ax.plot(x, y)
y0 = function_1(x0) # 計算切點座標
k = numerical_diff(function_1, x0) # 得到斜率,即導數值
# 設置座標軸範圍
ax.set_ylim(0, function_1(12))
ax.set_xlim(0, 12)
# 繪製切線所需的點
y_diff = k*(x-x0) + y0
# 繪製圖形
ax.plot(x, y_diff)
# 繪製切點垂線
ax.vlines(x0, 0, y0, linestyle="--")
ax.hlines(y0, 0, x0, linestyle="--")
# 設置圖標題
ax.set_title("the tangent of $x_0$={}".format(x0))
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for x0, ax in zip([5, 10], axes):
plot_tangent(ax, x, y, x0)
plt.show()
4 總結
綜上感覺不難,但是需要注意細節。