API
I developed a capsule network toolkit based on PyTorch. Its API is listed below:
### Capsulation layer
class Capsulation2D(nn.Module)
# input.shape = (batch_size, channels, height, width)
# output.shape = (batch_size, out_channels, out_dim_capsule, height, width)
### De-capsulation layer
class DeCapsulation2D(nn.Module)
# input.shape = (batch_size, channels, dim_capsule, height, width)
# output.shape = (batch_size, out_channels, height, width)
### Flatten layer
class CapFlatten(nn.Module)
# input.shape = (batch_size, channels, dim_capsule, height, width)
# output.shape = (batch_size, channels * height * width, dim_capsule), which is (batch_size, num_capsules, dim_capsule)
### De-flatten layer
class DeCapFlatten(nn.Module)
# input.shape = (batch_size, channels * height * width, dim_capsule),
# which is (batch_size, num_capsules, dim_capsule)
# output.shape = (batch_size, channels, dim_capsule, height, width)
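To make the two shape contracts above concrete, here is a minimal sketch of the flatten and de-flatten reshapes in plain PyTorch. The function names `cap_flatten` and `cap_unflatten` are hypothetical, and the capsule ordering (channel, then row, then column) is an assumption; the toolkit's own layers may order capsules differently.

```python
import torch


def cap_flatten(x: torch.Tensor) -> torch.Tensor:
    """(B, C, D, H, W) -> (B, C*H*W, D). Hypothetical stand-in for CapFlatten."""
    b, c, d, h, w = x.shape
    # Move dim_capsule last so every (channel, row, col) position becomes one capsule.
    return x.permute(0, 1, 3, 4, 2).reshape(b, c * h * w, d)


def cap_unflatten(x: torch.Tensor, channels: int, height: int, width: int) -> torch.Tensor:
    """(B, C*H*W, D) -> (B, C, D, H, W). Hypothetical stand-in for DeCapFlatten."""
    b, _, d = x.shape
    # Undo the reshape, then put dim_capsule back right after the channel axis.
    return x.reshape(b, channels, height, width, d).permute(0, 1, 4, 2, 3)


x = torch.randn(2, 32, 8, 6, 6)
flat = cap_flatten(x)
restored = cap_unflatten(flat, channels=32, height=6, width=6)
print(flat.shape)      # torch.Size([2, 1152, 8])
print(restored.shape)  # torch.Size([2, 32, 8, 6, 6])
```

Because the flattened tensor no longer carries `channels`, `height`, and `width`, the inverse needs them supplied explicitly, which is presumably why de-flattening is a separate layer.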
### Scalarization layer
class CapScalarization(nn.Module)
# input.shape = (batch_size, num_capsules, dim_capsule)
# output.shape = (batch_size, num_capsules)
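Assuming the scalar of a capsule is its length (the L2 norm along `dim_capsule`, as in the CapsNet paper), the shape change above can be sketched as follows. `cap_scalarize` is a hypothetical name, not the toolkit's implementation.

```python
import torch


def cap_scalarize(x: torch.Tensor) -> torch.Tensor:
    """(B, num_capsules, dim_capsule) -> (B, num_capsules) via the L2 norm."""
    return x.norm(dim=-1)


# One sample with three 2-D capsules of lengths 5, 0 and 10.
caps = torch.tensor([[[3.0, 4.0], [0.0, 0.0], [6.0, 8.0]]])
lengths = cap_scalarize(caps)
print(lengths.shape)  # torch.Size([1, 3])
```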
### Capsule 2D convolution layer (V1)
class CapConv2dV1(nn.Module)
# input.shape = (batch_size, channels, dim_capsule, height, width)
# output.shape = (batch_size, out_channels, out_dim_capsule, out_height, out_width)
### Digit capsules (routing output layer)
class CapsuleLayer(nn.Module)
# Dynamic Routing Version
# input.shape = [batch, input_num_capsule, input_dim_capsule]
# output.shape = [batch, num_capsule, 1, dim_capsule]
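For readers unfamiliar with dynamic routing, the following is a minimal sketch of routing-by-agreement that honours the input and output shapes above. `RoutingSketch`, its constructor arguments, and the weight initialization are all hypothetical; this is not the toolkit's `CapsuleLayer`, only one plausible way to write such a layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    """Shrink vectors along `dim` so their length lies in [0, 1)."""
    sq = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq / (1.0 + sq)) * s / torch.sqrt(sq + eps)


class RoutingSketch(nn.Module):
    """Hypothetical stand-in for CapsuleLayer (dynamic routing version)."""

    def __init__(self, in_num, in_dim, num_capsule, dim_capsule, routings=3):
        super().__init__()
        self.routings = routings
        # One (dim_capsule x in_dim) transform per (output capsule, input capsule) pair.
        self.weight = nn.Parameter(
            0.01 * torch.randn(num_capsule, in_num, dim_capsule, in_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, in_num, in_dim) -> prediction vectors (B, num_capsule, in_num, dim_capsule)
        u_hat = torch.einsum("jidk,bik->bjid", self.weight, x)
        # Routing logits start at zero, one per (output, input) pair.
        logits = torch.zeros(u_hat.shape[:3], device=x.device, dtype=x.dtype)
        for _ in range(self.routings):
            # Each input capsule distributes its vote over the output capsules.
            c = F.softmax(logits, dim=1)
            v = squash((c.unsqueeze(-1) * u_hat).sum(dim=2))  # (B, num_capsule, dim_capsule)
            # Raise the logit where a prediction agrees with the current output.
            logits = logits + (u_hat * v.unsqueeze(2)).sum(dim=-1)
        return v.unsqueeze(2)  # (B, num_capsule, 1, dim_capsule)


layer = RoutingSketch(in_num=1152, in_dim=8, num_capsule=10, dim_capsule=16)
out = layer(torch.randn(2, 1152, 8))
print(out.shape)  # torch.Size([2, 10, 1, 16])
```

The extra singleton axis in the return value only mirrors the documented output shape; squeezing it gives the usual `(batch, num_capsule, dim_capsule)` layout.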
### Mask layer
class CapReconMask(nn.Module)
# input.shape = (batch, num_classes, dim_capsules) | (batch, num_capsules, dim_capsules)
# masked.shape = (batch, dim_capsules)
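Given the shapes above, the mask appears to keep a single capsule per sample for the reconstruction decoder. A minimal sketch, assuming the kept capsule is the one at the true label when a label is given and the longest capsule otherwise (`recon_mask` and its `target` argument are hypothetical):

```python
import torch


def recon_mask(caps: torch.Tensor, target: torch.Tensor = None) -> torch.Tensor:
    """(B, num_capsules, dim_capsules) -> (B, dim_capsules): keep one capsule per sample."""
    if target is None:
        # No label given: fall back to the longest capsule, i.e. the predicted class.
        target = caps.norm(dim=-1).argmax(dim=1)
    batch_index = torch.arange(caps.size(0), device=caps.device)
    return caps[batch_index, target]


caps = torch.zeros(2, 3, 4)
caps[0, 1] = 1.0  # sample 0: capsule 1 is the longest
caps[1, 2] = 2.0  # sample 1: capsule 2 is the longest
masked = recon_mask(caps)
print(masked.shape)  # torch.Size([2, 4])
```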
## Utilities
class CapTool():
def one_hot(self, y, num_dim=10):
"""
One Hot Encoding, similar to `torch.eye(num_dim).index_select(dim=0, index=y)`
:param y: N-dim tensor
:param num_dim: do one-hot labeling from `0` to `num_dim-1`
:return: shape = (batch_size, num_dim)
"""
def margin_loss(self, input, target, num_classes=10, m_plus=None, m_minus=None, m_lambda=0.5):
"""
Margin loss for capsule classification.
It penalizes the true class when its score falls below `m_plus` and every other class when its score exceeds `m_minus`; the second term is weighted by `m_lambda`.
input.shape = (batch_size, num_classes)
target.shape = (batch_size, ), type of `LongTensor`, true class labels
:param input: predicted class scores (capsule lengths)
:param target: true class labels
:param num_classes: number of classes, 10
:param m_plus: upper margin, 0.9 (the signature default is `None`)
:param m_minus: lower margin, 0.1 (the signature default is `None`)
:param m_lambda: weight of the absent-class term, 0.5
:return: shape = (1, )
"""
def squash(self, s, dim=-1, constant=1, epsilon=1e-8):
"""
It drives the length of a large vector to near 1 and small vector to 0
:param s: N-dim tensor
:param dim: the dimension to squash
:param constant: value in (0, 1]
:param epsilon: small constant, presumably to avoid division by zero
:return: tensor with the same shape as `s`
"""
def acc_eval(self, model, test_loader, loss_fn, y_pred_dim=0):
def model_summary(self, model, show_layer_detail=True):
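The docstrings above describe `one_hot`, `margin_loss`, and `squash` without showing their bodies. The sketch below writes them as free functions following the formulas in the CapsNet paper. It is not the toolkit's code: the role of `constant` as the denominator term, the concrete margin defaults, and the mean over the batch are assumptions, and this version returns a 0-dim scalar rather than shape `(1, )`.

```python
import torch
import torch.nn.functional as F


def one_hot(y: torch.Tensor, num_dim: int = 10) -> torch.Tensor:
    # Same result as torch.eye(num_dim).index_select(dim=0, index=y) for 1-D y.
    return F.one_hot(y, num_classes=num_dim).float()


def squash(s: torch.Tensor, dim: int = -1, constant: float = 1.0,
           epsilon: float = 1e-8) -> torch.Tensor:
    # v = |s|^2 / (constant + |s|^2) * s / |s|  (assumed use of `constant`)
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return sq_norm / (constant + sq_norm) * s / torch.sqrt(sq_norm + epsilon)


def margin_loss(input: torch.Tensor, target: torch.Tensor, num_classes: int = 10,
                m_plus: float = 0.9, m_minus: float = 0.1,
                m_lambda: float = 0.5) -> torch.Tensor:
    t = one_hot(target, num_classes)
    # True class: penalize scores below m_plus.
    present = t * F.relu(m_plus - input) ** 2
    # Other classes: penalize scores above m_minus, down-weighted by m_lambda.
    absent = m_lambda * (1.0 - t) * F.relu(input - m_minus) ** 2
    return (present + absent).sum(dim=1).mean()


labels = torch.tensor([1, 0, 3])
encoded = one_hot(labels, num_dim=4)            # shape (3, 4)
squashed = squash(torch.tensor([[3.0, 4.0]]))   # length 25/26, direction unchanged
loss = margin_loss(torch.tensor([[0.0, 1.0]]), torch.tensor([0]), num_classes=2)
print(encoded.shape, squashed.shape, loss.shape)
```

In the last call the true class scores 0 and the wrong class scores 1, so both terms are active: 0.9² + 0.5 · 0.9² = 1.215.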
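Toolkit download
The ipynb version is available here; export or copy it to a py file and it can be imported directly: Capsule network toolkit/capsnet_tool.ipynb@gist
The end of that ipynb file also contains a test that uses this toolkit to implement the Dynamic-Routing version of CapsNet proposed by Hinton. Remember to delete that part when exporting to a py file.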