Treevalue(0x02)——函數樹化詳細解析(上篇)

本文將對 func_treelize 這一treevalue庫中的核心功能進行詳細的原理解析。

關於treevalue的概述,可以參考之前的文章:Treevalue(0x01)——功能概述

樹化函數基本原理

在treevalue庫中, func_treelize 是核心特性之一,可以將普通的函數快速作用於樹對象上。而這一“作用”的原理是什麼呢,我們來一起看看——首先準備一個普通的函數,並加上 func_treelize 裝飾器,就像這樣

from treevalue import func_treelize


@func_treelize()
def gcd(a, b):  # GCD calculation
    print('gcd', a, b)
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

函數的部分是一個最大公因數的計算,並且和之前文章(Treevalue(0x01)——功能概述)中的區別在於,添加了一行 print 輸出,用於體現函數內部在整個計算過程中是如何被調用的。基於這一函數,我們進行如下的調用,可以得到對應的輸出結果

from treevalue import FastTreeValue

gcd(9, 12)
# gcd 9 12
# 3

t1 = FastTreeValue({'a': 2, 'b': 30, 'x': {'c': 4, 'd': 9}})
t2 = FastTreeValue({'a': 4, 'b': 48, 'x': {'c': 6, 'd': 54}})
gcd(t1, t2)
# gcd 30 48
# gcd 9 54
# gcd 4 6
# gcd 2 4
# <TreeValue 0x7f12950e3be0>
# ├── a --> 2
# ├── b --> 6
# └── x --> <TreeValue 0x7f1296732310>
#     ├── c --> 2
#     └── d --> 9

根據輸出語句,不難發現——經過func_treelize裝飾後的函數,在被傳入TreeValue類型的時候,會自動基於其結構將內部的數值一一對應傳入原函數,並在執行計算後組裝成與原來相同的樹結構
基於以上基本特性,func_treelize這一過程也被稱爲函數的樹化,經過樹化後的函數將滿足以下基本特性:

  1. 當所有傳入參數均爲非樹對象時,函數行爲與返回值與原函數保持嚴格一致,即樹化後的函數依然可以像原函數一樣地使用
  2. 樹化的函數本身不會對傳入的樹對象內部結構有顯式的限制,在函數的樹化邏輯中將基於傳入樹參數的結構生成最終的返回值結構。
  3. 函數的樹化邏輯部分不會對樹對象內部的值進行任何的判定與檢測,只是作爲一箇中繼器將對應的值傳入原函數並獲取運算結果

樹化函數運行機制

通過開頭章節的簡單例子展示,相信各位已經對函數的樹化有了基本的概念和了解。在本章中,將對函數的樹化過程進行更加詳細的機制分析。

機制概述

在開頭章節的例子中,展現的只是兩種最爲理想化的情況:

  1. 傳入的參數均爲非樹對象
  2. 傳入的參數均爲結構完全一致的樹對象

然而實際上,基於對“樹”這一數據結構的基本瞭解,不難發現實際上需要作出處理的情況依然有很多,包括但不限於:

  • 鍵值缺少——參與計算的某個樹對象在對應的位置上缺少了對應的鍵值,這樣的情況如何處理?例如下圖中, t2.x.d 缺失,這樣的情況該如何處理?

  • 鍵值類型不匹配——參與計算的某幾個樹對象對應位置上,有些是葉子節點值,有些是非葉子節點子樹,形成“值-子樹”之間的直接運算,這樣的情況如何定義?例如下圖中, t1.b 爲子樹但是 t2.b 爲值,這樣的情況如何定義?

  • 計算模式多樣性——當參與計算的樹對象之間的結構存在較多較大差異性時,如何設計計算策略使之能支持更多樣化的計算?例如下列的場景,如何組織對如此結構各異的樹之間的運算?

  • 數據格式多樣性——當參與計算的葉子節點值格式存在不統一時,如何處理?例如下面的場景,如何對 t1t2 下顯然不同尺寸的 torch.Tensor 進行處理?

因此,基於這些很現實的問題,我們爲樹化函數定義瞭如下的選項:

  • 模式選項(mode)——決定樹化函數的整體運行機制。
  • 繼承選項(inherit)——對鍵值類型不匹配的情況進行了定義,並提供了處理機制。
  • 缺省選項(missing)——爲鍵值缺少的情況提供了缺省值補全機制。

模式選項(mode)

模式選項是樹化函數中最爲重要的選項,其將直接決定樹化函數的主體計算邏輯。目前定義了四種常用模式:

  • 嚴格模式(STRICT)
  • 內共同模式(INNER)
  • 外共有模式(OUTER)
  • 左優先模式(LEFT)

接下來的子章節中會結合例子進行逐一介紹。

嚴格模式(STRICT)

嚴格模式是最常用的模式選項,意味着當且僅當所有樹參數在當前子樹位置上的鍵一一對應時,會將其鍵值進行一一對應地代入計算,否則拋出異常。代碼實現如下,與開頭的例子等價,模式選項的默認值即爲嚴格模式

from treevalue import func_treelize


@func_treelize(mode='strict')
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

在上述的樹化gcd函數中,完整的計算機制如下圖1所示, tr 爲樹化gcd的運算結果


(圖1,t1、t2內的鍵值可以形成一一對應)

但是當出現如下所示的參數時,則應拋出異常,因爲部分鍵存在缺失,無法形成一一對應。


(圖2,t1.b與t1.x.c缺失,無法形成一一對應)

嚴格模式是一種最爲常見的計算邏輯,適用於大部分常見情況,也是在業務邏輯上最爲順理成章的一種模式。但是對非規則結構下的計算則不能兼容,因此另外三種模式選項分別針對不同的情況來支持非規則結構下的計算。

內共同模式(INNER)

內共同模式下,僅會對全部樹參數當前子樹位置上均存在此鍵時,纔會對將其鍵值進行一一對應地代入計算,而當此鍵值在某一樹參數當前子樹位置上存在缺失情況是,則會直接忽略該組鍵值。代碼實現如下,將 mode 設置爲 inner 即可

from treevalue import func_treelize


@func_treelize(mode='inner')
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

例如對圖2所示的例子,在內共同模式下可以正常計算,如圖3所示


(圖3,t1.x.c和t2.b因爲t2.x.c和t1.b的缺失而被忽略)

內共同模式會忽略無法形成對應的多餘值,可以確保在幾乎所有情況下均能得出計算結果而不會產生錯誤。但是會不可避免地造成部分信息丟失,而在一部分情況下這是不可接受的,因此請根據實際需求進行選擇。

外共有模式(OUTER)

外共有模式下,只要在任意一個樹參數的當前子樹位置上存在此鍵值,則會將其進行代入計算。而對於缺失的值,則會使用缺省選項中設置的值或生成器進行獲取並代入。代碼實現如下,將 mode 設置爲 outer 即可,並將缺省選項設置爲值 1

from treevalue import func_treelize


@func_treelize(mode='outer', missing=1)
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

例如對圖2所示的例子,在外共有模式下可以正常計算,如圖4所示


(圖4,t1.b和t1.x.c缺失,將使用缺省選項指定的默認值1)

外共有模式將會讓所有的數值參與運算,但是在絕大部分情況下均依賴缺省選項的設置,因此在使用前請確保缺省選項的正確配置,以及業務邏輯上的自洽。

左優先模式(LEFT)

左優先模式下,參與運算的鍵值將以全部樹參數中最左的一項爲參考。其中最左的一項定義爲,在python函數調用的位置參數(postional argument)中,如果存在樹參數,則取最左的一項;如果不存在,則在函數調用的鍵值參數(key-word argument)紅,取字典序最小的一項。代碼實現如下,將 mode 設置爲 left 即可,並將缺省選項設置爲值 1

from treevalue import func_treelize


@func_treelize(mode='left', missing=1)
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

例如對於圖2所示的 gcd(t1, t2) 例子中,在左優先模式下計算結果如下,如圖5所示


(圖5,t2.b因t1.b的缺失而被忽略,而t2.x.c取缺省值1)

而在 gcd(t2, t1) 例子中,左優先計算結果如下,如圖6所示


(圖6,t1.x.c因t2.x.c的缺失而被忽略,而t1.b取缺省值1)

左優先模式會按照最左樹參數的結構來進行計算,生成的計算結果也將和最左的參數保持一致。但是與外共有模式類似,左優先模式在絕大部分情況下依賴缺省選項的配置,需要確保配置準確無誤且自洽。此外,對於原本滿足交換律的運算,經過左優先模式的樹化後將會失去原有的交換律性質,這一點請務必留意。

繼承選項(inherit)

繼承選項可以通過普通值的繼承機制,讓樹化函數在實際應用中使用起來更加簡潔,也讓樹參數可以和普通參數在樹化後的函數中被混用。在默認情況下,繼承選項是處於開啓狀態的,即等價於如下的代碼

from treevalue import func_treelize


@func_treelize(inherit=True)
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

因此,有如下的例子 gcd(t1, t2) ,其計算結果如圖7所示


(圖7,t2.x.c和t2.x.d繼承t2.x的值6)

此外顯而易見的是,也可以直接將非樹值直接傳入,和樹參數混用,例如下面的例子 gcd(100, t1) ,其計算結果如圖x所示


(圖8,值100被完全繼承並作爲第一棵樹的全部值)

而當繼承選項被關閉時,則上述兩個例子均會拋出異常,因爲存在值和子樹混用的情況。

從業務邏輯的角度來看,繼承選項可以良好地適應大部分真實存在的值複用情況,且值和子樹混用在大多數業務邏輯上也是有明確意義的。但是當混用在業務邏輯角度上意義不明且需要被顯式地檢測時,則建議關閉繼承選項

缺省選項(missing)

缺省選項可以爲部分鍵值存在缺失的情況提供一個值的補充,主要作用於外共有模式和左優先模式。我們可以通過 missing 參數直接提供值,如下所示

from treevalue import func_treelize, FastTreeValue

@func_treelize(mode='outer', missing=0)
def total(*args):
    return sum(args)

上述的加法函數計算例子如下, total(t1, t2, t3) 計算結果如下圖9所示


(圖9,缺省值0被全面用於填補空缺,並最終計算出了有效的總和)

此外考慮到有些情況下,直接使用值作爲缺省值可能會存在公用同一個對象導致錯誤的情況,因此我們提供了通過傳入生成函數來產生默認值的用法。可以通過 missing 參數傳入值生成器,如下所示

from treevalue import func_treelize, FastTreeValue

@func_treelize(mode='outer', missing=lambda: [])
def append(arr: list, *args):
    for item in args:
        if item:
            arr.append(item)
    return arr

上述的列表追加值計算例子如下, append(t0, t1, t2, t3) 運算結果如下圖10所示


(圖10,每次缺省均會生成新的空列表)

通過缺省選項的有效配置,結合外共有模式和左優先模式,可以有效擴展樹化函數對值缺省情況的處理能力。不過值得注意的是,缺省選項在嚴格模式下無法生效,因爲當檢測到鍵缺失時將會直接拋出異常;以及缺省模式在內共同模式下永遠無法實質上生效,因此樹化函數會針對這一情況拋出一個警告信息。

上升、下沉選項

除了上述的基本機制選項之外,樹化函數還提供了上升(rise)和下沉(subside)選項,以簡化對結構化數據的處理。兩者的功能分別爲:

  • 下沉(subside)——嘗試將參數中頂層結構非樹的對象,提取結構後將結構下沉至樹內,使原函數在運行過程中可以接收到。關於下沉函數的具體細節可以參考之前文章
  • 上升(rise)——嘗試從返回結果樹的葉子節點值中提取共同結構,向上升至樹外,使返回值的邏輯結構可以被外部直接訪問。關於上升函數的具體細節可以參考之前文章

因此我們可以在需要的時候打開這兩個選項,代碼如下,實現的效果是從列表 arr 中查找首個滿足條件值的位置( position ),並統計共有多少個滿足條件的值( cnt

from treevalue import func_treelize, FastTreeValue


@func_treelize(subside=True, rise=True)
def check(arr: list, target):
    position = None
    cnt = 0
    for i, item in enumerate(arr):
        if target(item):
            if position is None:
                position = i
            cnt += 1

    return position, cnt


t1 = FastTreeValue({'a': 2, 'b': 4, 'x': {'c': 7, 'd': 9}})
t2 = FastTreeValue({'a': 4, 'b': 48, 'x': {'c': 2, 'd': 53}})
t3 = FastTreeValue({'a': 9, 'b': -12, 'x': {'c': 3, 'd': 7}})

tr1, tr2 = check([t1, t2, t3], lambda x: x % 2 == 0)

代碼中可以看到三棵樹 t1t2t3 可以直接用列表裝載,在原函數 check 中可以接收到對應位置上的值列表。並且由於 rise 選項的開啓,位置和數量所構成的二元組也會被提取出來,形成兩棵樹,即 tr1tr2 ,如下圖11所示


(圖11,[t1, t2, t3]作爲列表參數,tr1, tr2作爲返回值樹)

此外,上升和下沉選項一個更加有效的使用例子是對 torch.splittorch.stack 函數進行裝飾,代碼如下所示

import torch

from treevalue import func_treelize, TreeValue

stack = func_treelize(subside=True)(torch.stack)
split = func_treelize(rise=True)(torch.split)

trees = [TreeValue({
    'a': torch.randn(2, 4),
    'b': torch.randn(3, 4),
    'x': {'c': torch.randn(2, 1, 3)}
}) for _ in range(10)]

st = stack(trees)  # stack all the trees together
splitted = split(st, [1] * 10)  # split back to trees

# splitted should be equal to trees

其中 st 即爲合併後的樹,而 splitted 爲再次拆分後的樹, splittedtrees 等價。

後續預告

本文主要針對treevalue的核心特性——樹化函數,基於其自身進行了詳細的原理解析,受限於篇幅,本次只着重講述了原生樹化函數本身的原理、特性以及例子。在下一篇中將會針對更多衍生場景進行分析與展示,敬請期待。

同時歡迎瞭解其他OpenDILab的開源項目:https://github.com/opendilab

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章