基於切比雪夫多項式的簡單GCN網絡

利用論文《SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS》中的原理進行簡單的GCN測試,具體原理可看這篇論文。

import torch
import torch.nn as nn
import torch.nn.functional as F

import networkx as nx


def normalize(A, symmetric=True):
	# A = A+I
	A = A + torch.eye(A.size(0))
	d = A.sum(1)
	if symmetric:
		#D = D^-1/2
		D = torch.diag(torch.pow(d , -0.5))
		return D.mm(A).mm(D)
	else:
		# D=D^-1
		D =torch.diag(torch.pow(d,-1))
		return D.mm(A)


class GCN(nn.Module):
	'''
	Z = AXW
	'''
	def __init__(self , A, dim_in , dim_out):
		super(GCN,self).__init__()
		self.A = A
		self.fc1 = nn.Linear(dim_in ,dim_in,bias=False)
		self.fc2 = nn.Linear(dim_in,dim_in//2,bias=False)
		self.fc3 = nn.Linear(dim_in//2,dim_out,bias=False)

	def forward(self, X):
		'''
		計算三層gcn
		'''
		X = F.relu(self.fc1(self.A.mm(X)))
		X = F.relu(self.fc2(self.A.mm(X)))
		X = self.fc3(self.A.mm(X))
		return F.softmax(X, dim=1)


#獲得空手道俱樂部數據
G = nx.karate_club_graph()
A = nx.adjacency_matrix(G).todense()
A_normed = normalize(torch.FloatTensor(A.astype(int)),True)

N = len(A)
X_dim = N

# node features
X = torch.eye(N, X_dim)
# 分類結果,0 or 1
Y = torch.zeros(N, 1).long()
Y_mask = torch.zeros(N, 1, dtype=torch.uint8)

# 總共有2類,分別給第一類的第一個樣本和第二類的最後一個樣本設置ground truth label
# 利用這兩個帶標籤的樣本進行半監督學習
Y[0][0] = 0
Y[N-1][0] = 1

# 指示哪些樣本已經提前設置了標籤
Y_mask[0][0] = 1
Y_mask[N-1][0] = 1

# ground truth
Real = torch.zeros(34, dtype=torch.long)
for i in [1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 17, 18, 20, 22]:
	Real[i-1] = 0
for i in [9, 10, 15, 16, 19, 21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]:
	Real[i-1] = 1

gcn = GCN(A_normed, X_dim, 2)
gd = torch.optim.Adam(gcn.parameters())

for i in range(300):
	y_pred = gcn(X)
	#下面兩行計算cross entropy
	loss = (-y_pred.log().gather(1, Y.view(-1, 1)))
	loss = loss.masked_select(Y_mask).mean()   # 返回mask標記爲1的樣本對應的損失
	gd.zero_grad()
	loss.backward()
	gd.step()

	if i % 20 == 0:
		_, mi = y_pred.max(1)
		print("loss: ", loss.item())
		print(mi)
		print("acc: ", (mi == Real).float().mean().item())
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章