## API
I developed a PyTorch-based capsule network toolkit; its API is listed below:
### Capsulation Layer
```python
class Capsulation2D(nn.Module):
    # input.shape = (batch_size, channels, height, width)
    # output.shape = (batch_size, out_channels, out_dim_capsule, height, width)
    ...
```
### De-capsulation Layer
```python
class DeCapsulation2D(nn.Module):
    # input.shape = (batch_size, channels, dim_capsule, height, width)
    # output.shape = (batch_size, out_channels, height, width)
    ...
```
### Flatten Layer
```python
class CapFlatten(nn.Module):
    # input.shape = (batch_size, channels, dim_capsule, height, width)
    # output.shape = (batch_size, channels * height * width, dim_capsule),
    # i.e. (batch_size, num_capsules, dim_capsule)
    ...
```
### De-flatten Layer
```python
class DeCapFlatten(nn.Module):
    # input.shape = (batch_size, channels * height * width, dim_capsule),
    # i.e. (batch_size, num_capsules, dim_capsule)
    # output.shape = (batch_size, channels, dim_capsule, height, width)
    ...
```
### Scalarization Layer
```python
class CapScalarization(nn.Module):
    # input.shape = (batch_size, num_capsules, dim_capsule)
    # output.shape = (batch_size, num_capsules)
    ...
```
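In CapsNet, the length of a capsule's vector is read as the probability that the entity it represents is present. This sketch assumes `CapScalarization` reduces each capsule to that length (its L2 norm). The function name is hypothetical.

```python
import torch


def cap_scalarization(v):
    """Hypothetical sketch: (B, N, D) -> (B, N), assuming each capsule -> its L2 length."""
    return torch.linalg.vector_norm(v, dim=-1)


v = torch.tensor([[[3.0, 4.0], [0.0, 0.0]]])  # batch of 1, two 2-D capsules
print(cap_scalarization(v))  # tensor([[5., 0.]])
```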
### Capsule 2D Convolution Layer (V1)
```python
class CapConv2dV1(nn.Module):
    # input.shape = (batch_size, channels, dim_capsule, height, width)
    # output.shape = (batch_size, out_channels, out_dim_capsule, out_height, out_width)
    ...
```
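The sketch below reproduces only the `CapConv2dV1` shape contract. It is not necessarily how the toolkit works internally. Everything here is assumed: the input capsules are folded into ordinary channels, a standard `Conv2d` is applied, and the result is split into output capsules and squashed along the capsule axis. `CapConv2dSketch` and its parameter names are hypothetical.

```python
import torch
import torch.nn as nn


def squash(s, dim=-1, eps=1e-8):
    # v = |s|^2 / (1 + |s|^2) * s / |s|
    sq = (s ** 2).sum(dim=dim, keepdim=True)
    return sq / (1.0 + sq) * s / torch.sqrt(sq + eps)


class CapConv2dSketch(nn.Module):
    """Hypothetical sketch: (B, C, D, H, W) -> (B, C', D', H', W')."""

    def __init__(self, in_channels, in_dim, out_channels, out_dim,
                 kernel_size, stride=1, padding=0):
        super().__init__()
        self.out_channels, self.out_dim = out_channels, out_dim
        self.conv = nn.Conv2d(in_channels * in_dim, out_channels * out_dim,
                              kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        b, c, d, h, w = x.shape
        y = self.conv(x.reshape(b, c * d, h, w))  # ordinary conv over folded capsules
        _, _, oh, ow = y.shape
        y = y.view(b, self.out_channels, self.out_dim, oh, ow)
        return squash(y, dim=2)  # squash each capsule vector (axis 2)


layer = CapConv2dSketch(in_channels=32, in_dim=8, out_channels=16, out_dim=8,
                        kernel_size=3, stride=2)
out = layer(torch.randn(2, 32, 8, 14, 14))
print(out.shape)  # torch.Size([2, 16, 8, 6, 6])
```

The spatial size follows the usual convolution rule: floor((14 - 3) / 2) + 1 = 6.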
### Digit Capsules (Routing Output Layer)
```python
class CapsuleLayer(nn.Module):
    # Dynamic Routing version
    # input.shape = [batch, input_num_capsule, input_dim_capsule]
    # output.shape = [batch, num_capsule, 1, dim_capsule]
    ...
```
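For reference, here is a minimal sketch of routing-by-agreement as described by Sabour, Frosst and Hinton (2017). It produces the same `[batch, num_capsule, 1, dim_capsule]` layout as `CapsuleLayer`. The class name, the `routings=3` default, and the weight initialisation are illustrative assumptions, not the toolkit's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def squash(s, dim=-1, eps=1e-8):
    sq = (s ** 2).sum(dim=dim, keepdim=True)
    return sq / (1.0 + sq) * s / torch.sqrt(sq + eps)


class DynamicRoutingSketch(nn.Module):
    """Hypothetical sketch: [B, N_in, D_in] -> [B, N_out, 1, D_out]."""

    def __init__(self, input_num_capsule, input_dim_capsule,
                 num_capsule, dim_capsule, routings=3):
        super().__init__()
        self.routings = routings
        # one transformation matrix W_ij per (input capsule i, output capsule j)
        self.W = nn.Parameter(0.01 * torch.randn(
            1, input_num_capsule, num_capsule, dim_capsule, input_dim_capsule))

    def forward(self, u):
        # predictions u_hat_{j|i} = W_ij u_i : [B, N_in, N_out, D_out]
        u_hat = (self.W @ u[:, :, None, :, None]).squeeze(-1)
        b = torch.zeros(u_hat.shape[:3], device=u.device)  # routing logits [B, N_in, N_out]
        for i in range(self.routings):
            c = F.softmax(b, dim=2)                   # coupling coefficients over output capsules
            s = (c.unsqueeze(-1) * u_hat).sum(dim=1)  # weighted sum: [B, N_out, D_out]
            v = squash(s, dim=-1)
            if i < self.routings - 1:
                # agreement = dot product between each prediction and the output
                b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)
        return v.unsqueeze(2)                         # [B, N_out, 1, D_out]


layer = DynamicRoutingSketch(input_num_capsule=1152, input_dim_capsule=8,
                             num_capsule=10, dim_capsule=16)
out = layer(torch.randn(2, 1152, 8))
print(out.shape)  # torch.Size([2, 10, 1, 16])
```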
### Mask Layer
```python
class CapReconMask(nn.Module):
    # input.shape = (batch, num_classes, dim_capsules) or (batch, num_capsules, dim_capsules)
    # masked.shape = (batch, dim_capsules)
    ...
```
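The mask keeps a single capsule per sample, and its vector is what the reconstruction decoder sees. This sketch assumes the capsule is chosen by the true label when one is given, and by the longest capsule otherwise. That selection rule and the name `recon_mask` are hypothetical.

```python
import torch


def recon_mask(v, target=None):
    """Hypothetical sketch: (B, N, D) -> (B, D).

    Assumption: keep the capsule of `target` if provided, else the longest capsule.
    """
    if target is None:
        target = torch.linalg.vector_norm(v, dim=-1).argmax(dim=1)  # [B]
    return v[torch.arange(v.shape[0]), target]                      # [B, D]


v = torch.tensor([[[0.1, 0.0], [0.6, 0.8]],
                  [[0.9, 0.0], [0.0, 0.2]]])  # batch of 2, two 2-D capsules
masked = recon_mask(v)  # keeps capsule 1 of sample 0 and capsule 0 of sample 1
```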
## Toolkit (`CapTool`)
```python
class CapTool():
    def one_hot(self, y, num_dim=10):
        """
        One-hot encoding, similar to `torch.eye(num_dim).index_select(dim=0, index=y)`
        :param y: N-dim tensor of class indices
        :param num_dim: one-hot encode labels from `0` to `num_dim-1`
        :return: shape = (batch_size, num_dim)
        """

    def margin_loss(self, input, target, num_classes=10, m_plus=None, m_minus=None, m_lambda=0.5):
        """
        Margin loss used to train CapsNet on the lengths of the class capsules.
        input.shape = (batch_size, num_classes)
        target.shape = (batch_size, ), type `LongTensor`, ground-truth class labels
        :param input: predicted probability of each class
        :param target: ground-truth class labels
        :param num_classes: number of classes (10)
        :param m_plus: upper margin m+ (0.9)
        :param m_minus: lower margin m- (0.1)
        :param m_lambda: down-weighting factor for absent classes (0.5)
        :return: shape = (1, )
        """

    def squash(self, s, dim=-1, constant=1, epsilon=1e-8):
        """
        Non-linear activation used in capsules: drives the length of a long vector
        to near 1 and the length of a short vector to near 0
        :param s: N-dim tensor
        :param dim: the dimension to squash along
        :param constant: in (0, 1]
        :return: same shape as `s`
        """

    def acc_eval(self, model, test_loader, loss_fn, y_pred_dim=0): ...

    def model_summary(self, model, show_layer_detail=True): ...
```
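The three `CapTool` helpers documented above (`squash`, `one_hot`, `margin_loss`) can be sketched as standalone functions. This is a minimal sketch, not the toolkit's code, and it makes three assumptions. First, `constant` is taken to be the additive term in the squash denominator (1 in the CapsNet paper). Second, the margins default to 0.9 and 0.1 directly. Third, the loss is summed over classes and averaged over the batch, so it returns a 0-dim tensor rather than shape `(1,)`.

```python
import torch


def squash(s, dim=-1, constant=1.0, epsilon=1e-8):
    # v = |s|^2 / (constant + |s|^2) * s / |s|   (role of `constant` is an assumption)
    sq = (s ** 2).sum(dim=dim, keepdim=True)
    return sq / (constant + sq) * s / torch.sqrt(sq + epsilon)


def one_hot(y, num_dim=10):
    # same result as torch.eye(num_dim).index_select(dim=0, index=y); y is a 1-D LongTensor
    return torch.eye(num_dim, device=y.device).index_select(dim=0, index=y)


def margin_loss(lengths, target, num_classes=10, m_plus=0.9, m_minus=0.1, m_lambda=0.5):
    # L_k = T_k * max(0, m+ - |v_k|)^2 + lambda * (1 - T_k) * max(0, |v_k| - m-)^2
    t = one_hot(target, num_classes)
    present = t * torch.clamp(m_plus - lengths, min=0) ** 2
    absent = m_lambda * (1 - t) * torch.clamp(lengths - m_minus, min=0) ** 2
    return (present + absent).sum(dim=1).mean()  # sum over classes, mean over batch


lengths = torch.tensor([[0.95, 0.05, 0.05]])  # capsule lengths for 3 classes
print(margin_loss(lengths, torch.tensor([0]), num_classes=3))  # tensor(0.)
```

A confident, correct prediction (length above m+ for the true class, below m- for the others) incurs zero loss. Only lengths on the wrong side of a margin are penalised.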
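## Toolkit Download
The toolkit is available as an `ipynb` notebook here; export or copy it to a `py` file and it can be imported directly: 胶囊网络工具包/capsnet_tool.ipynb@gist

The end of the notebook also contains a test that uses this toolkit to implement the Dynamic-Routing version of CapsNet proposed by Hinton et al. Remember to delete that part when exporting the notebook to a `py` file.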