Spektral:使用TF2實現經典GNN的開源庫

簡介

Spektral工具還發表了論文:
《Graph Neural Networks in TensorFlow and Keras with Spektral》
https://arxiv.org/abs/2006.12138

github地址:https://github.com/danielegrattarola/spektral/

在本文中,我們介紹了 Spektral,這是一個開源 Python 庫,用於使用 TensorFlow 和 Keras 應用程序編程接口構建圖神經網絡。Spektral 實現了大量的圖深度學習方法,包括消息傳遞和池化運算符,以及用於處理圖和加載流行基準數據集的實用程序。這個庫的目的是爲創建圖神經網絡提供基本的構建塊,重點是 Keras 所基於的用戶友好性和快速原型設計的指導原則。因此,Spektral 適合絕對的初學者和專業的深度學習從業者。

主要網絡

Spektral 實現了一些主流的圖深度學習層,包括:

安裝

pip安裝:

pip install spektral

源碼安裝:

git clone https://github.com/danielegrattarola/spektral.git
cd spektral
python setup.py install  # Or 'pip install .'

Spektral實現GCN

對於TF愛好者很友好:

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.optimizers import Adam

from spektral.data.loaders import SingleLoader
from spektral.datasets.citation import Citation
from spektral.layers import GCNConv
from spektral.models.gcn import GCN
from spektral.transforms import AdjToSpTensor, LayerPreprocess

learning_rate = 1e-2
seed = 0
epochs = 200
patience = 10
data = "cora"

tf.random.set_seed(seed=seed)  # make weight initialization reproducible

# Load data
dataset = Citation(
    data, normalize_x=True, transforms=[LayerPreprocess(GCNConv), AdjToSpTensor()]
)


# We convert the binary masks to sample weights so that we can compute the
# average loss over the nodes (following original implementation by
# Kipf & Welling)
def mask_to_weights(mask):
    return mask.astype(np.float32) / np.count_nonzero(mask)


weights_tr, weights_va, weights_te = (
    mask_to_weights(mask)
    for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)
)

model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)
model.compile(
    optimizer=Adam(learning_rate),
    loss=CategoricalCrossentropy(reduction="sum"),
    weighted_metrics=["acc"],
)

# Train model
loader_tr = SingleLoader(dataset, sample_weights=weights_tr)
loader_va = SingleLoader(dataset, sample_weights=weights_va)
model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    validation_data=loader_va.load(),
    validation_steps=loader_va.steps_per_epoch,
    epochs=epochs,
    callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],
)

# Evaluate model
print("Evaluating model.")
loader_te = SingleLoader(dataset, sample_weights=weights_te)
eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)
print("Done.\n" "Test loss: {}\n" "Test accuracy: {}".format(*eval_results))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章