【源碼閱讀系列】一:GraphSAGE代碼閱讀(1)

0.前言

昨天發了一篇關於GraphSAGE論文的大致講解,今天對源碼進行部分解析,源碼鏈接。作者最原始的訓練代碼是Tensorflow版本的,這是一個PyTorch版本的,恰好最近學習PyTorch,同時也有一段時間不用Tensorflow了,所以就對PyTorch版本的進行解析(其實主要是PyTorch的源碼簡單還少)。代碼可能一次性看不完,畢竟能力有限~~,本文只放置部分關鍵代碼。分析鏈接爲:

https://github.com/TwT520Ly/Code-Reading

1.數據集分析

Cora數據集
代碼只用了Cora數據集的一部分,Cora數據集中樣本是機器學習論文,論文被分爲7類:

  • Case_Based
  • Genetic_Algorithms
  • Neural_Networks
  • Probabilistic_Methods
  • Reinforcement_Learning
  • Rule_Learning
  • Theory

數據集共有2708篇論文,分爲兩個文件:

  • .content
  • .cites

第一個文件形式爲:

<paper_id> <word_attributes>+ <class_label>

分別表示論文的唯一ID,文檔詞的0-1編碼向量,類別標籤;文檔詞中0表示不存在,1表示存在。
第二個文件形式爲:

<ID of cited paper> <ID of citing paper>

分別表示被引用論文和引用論文,即後者引用前者,paper2->paper1。

2.代碼分析

2.1 aggregators.py

實現聚合類,對鄰居信息進行AGGREGATE。

# 如果num_sample設置了具體數字
if not num_sample is None:
     _sample = random.sample
     # 首先對每一個節點的鄰居集合neigh進行遍歷,判斷一下已有鄰居數和採樣數大小,多於採樣數進行抽樣
     samp_neighs = [_set(_sample(to_neigh, 
                     num_sample,
                     )) if len(to_neigh) >= num_sample else to_neigh for to_neigh in to_neighs]
 else:
     samp_neighs = to_neighs

這裏是對一個batch中的每一個節點的鄰接點set進行sample,主要計算量在random.sample,簡單分析一下random.sample,該函數如果指定採樣數爲K,內部會進行K次循環,分別獲取K個元素。

if n <= setsize:
     # An n-length list is smaller than a k-length set
     pool = list(population)
     for i in range(k):         # invariant:  non-selected at [0,n-i)
         j = randbelow(n-i)
         result[i] = pool[j]
         pool[j] = pool[n-i-1]   # move non-selected item into vacancy
 else:
     selected = set()
     selected_add = selected.add
     for i in range(k):
         j = randbelow(n)
         while j in selected:
             j = randbelow(n)
         selected_add(j)
         result[i] = population[j]
 return result

此處通過調用randbelow函數實現,簡單的考慮,如果我要抽取K個元素,那麼是不是隻要從原序列中生成K次隨機下標就可以了?時間複雜度爲O(K)?事實上沒有這麼簡單,如果sample出來的序列需要維持原有的次序,就需要每次randbelow的下標有序插入到已經sample的序列中,搜索代價大致爲O(logN),那麼時間複雜度就是O(NlogN),如果是這樣子的話,那SAGE的sample時間複雜度就會提升到O(MNlogN)。不過上面的代碼中有明顯的一個if-else結構,所以實現方式應該沒有這麼簡單。首先看到判斷條件爲setsize,此變量來源如下:

setsize = 21        # size of a small set minus size of an empty list
if k > 5:
     setsize += 4 ** _ceil(_log(k * 3, 4)) # table size for big sets

這一堆看着就奇怪,莫名其妙的公式(暫時不管,其實和set的內存設定有關係,此處不做詳細說明)~~。反正就是利用K值計算出一個setsize,然後判斷和輸入序列大小n的大小關係,如果n相對較小,就好像是10箇中抽樣9個,採用無放回抽樣算法,那麼每次抽樣後原始序列縮小一個單位,爲了不改變原始輸入序列在內存中數值,將其拷貝至pool列表,並通過尾元素填充被選元素+縮小隨機範圍的方式從邏輯上壓縮pool列表:

pool[j] = pool[n-i-1] 

那麼如果n較大,就會執行else部分代碼,比如1千萬數組中抽取3個元素,採用上述策略效率太低,所以採用放回抽樣+多次重試的策略,如果隨機到的下標已經在之前select到了,就通過while循環進行多次嘗試:

while j in selected:
      j = randbelow(n)

綜上所述,採用混合實現的方式,random.sample的時間複雜度會穩定在O(K)上。
說了這麼多,繼續回到SAGE的代碼,那麼如果當前節點設置的抽樣數爲num_sample,則時間複雜度爲O(num_sample * batch_size)

# *拆解列表後,轉爲爲多個獨立的元素作爲參數給union,union函數進行去重合並
unique_nodes_list = list(set.union(*samp_neighs))
# 節點標號不一定都是從0開始的,創建一個字典,key爲節點ID,value爲節點序號
unique_nodes = {n:i for i,n in enumerate(unique_nodes_list)}
# print(len(nodes), len(unique_nodes), len(samp_neighs))
# nodes表示batch內的節點,unique_nodes表示batch內的節點用到的所有鄰居節點,unique_nodes > nodes
# 創建一個nodes * unique_nodes大小的矩陣
mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes)))
# 遍歷每一個鄰居集合的每一個元素,並且通過ID(key)獲取到節點對應的序號--列切片
column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
# 行切片,比如samp_neighs = [{3,5,9}, {2,8}, {2}],行切片爲[0,0,0,1,1,2]
row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
# 利用切片創建鄰接矩陣
mask[row_indices, column_indices] = 1

這一堆代碼是爲了構造鄰接矩陣。

# 統計每一個節點的鄰居數量
num_neigh = mask.sum(1, keepdim=True)
# 分比例
mask = mask.div(num_neigh)
# embed_matrix: [n, m]
# n: unique_nodes
# m: dim
if self.cuda:
    embed_matrix = self.features(torch.LongTensor(unique_nodes_list).cuda())
else:
    embed_matrix = self.features(torch.LongTensor(unique_nodes_list))
# mean操作
to_feats = mask.mm(embed_matrix)

這裏就實現了mean方式的AGGREGATE。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章