[CapsNet] A PyTorch-Based Capsule Network Toolkit

API

Paper: Dynamic Routing Between Capsules

I developed a capsule network toolkit based on PyTorch. Its API is listed below:

### Capsulation layer
class Capsulation2D(nn.Module)
    # input.shape = (batch_size, channels, height, width)
    # output.shape = (batch_size, out_channels, out_dim_capsule, height, width)
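The shape contract above can be illustrated with a small sketch. This is not the toolkit's implementation: it assumes the layer is a single convolution whose output channels are split into capsules, and `CapsulationSketch` along with its constructor arguments are hypothetical names.

```python
import torch
import torch.nn as nn


class CapsulationSketch(nn.Module):
    """Hypothetical stand-in for Capsulation2D; not the toolkit's code."""

    def __init__(self, in_channels, out_channels, out_dim_capsule, kernel_size=1):
        super().__init__()
        self.out_channels = out_channels
        self.out_dim_capsule = out_dim_capsule
        # One convolution emits every capsule component as a separate channel.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * out_dim_capsule,
            kernel_size=kernel_size,
            padding=kernel_size // 2,  # keeps height and width for odd kernels
        )

    def forward(self, x):
        batch_size, _, height, width = x.shape
        y = self.conv(x)
        # Split the channel axis into (out_channels, out_dim_capsule).
        return y.view(batch_size, self.out_channels, self.out_dim_capsule, height, width)


x = torch.randn(2, 3, 8, 8)
caps = CapsulationSketch(in_channels=3, out_channels=4, out_dim_capsule=8)(x)
print(caps.shape)  # torch.Size([2, 4, 8, 8, 8])
```

Height and width are unchanged here, which matches the documented output shape.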
    
### De-capsulation layer
class DeCapsulation2D(nn.Module)
    # input.shape = (batch_size, channels, dim_capsule, height, width)
    # output.shape = (batch_size, out_channels, height, width)
    
### Flatten layer
class CapFlatten(nn.Module)
    # input.shape = (batch_size, channels, dim_capsule, height, width)
    # output.shape = (batch_size, channels * height * width, dim_capsule),
    #     which is (batch_size, num_capsules, dim_capsule)


### Un-flatten layer
class DeCapFlatten(nn.Module)
    # input.shape = (batch_size, channels * height * width, dim_capsule),
    #     which is (batch_size, num_capsules, dim_capsule)
    # output.shape = (batch_size, channels, dim_capsule, height, width)
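The flatten and un-flatten shapes are inverses of each other, which can be checked with plain tensor operations. The sketch below assumes a permute-then-reshape ordering, which the toolkit may or may not use; `cap_flatten` and `de_cap_flatten` are hypothetical helper names.

```python
import torch


def cap_flatten(x):
    # (batch, channels, dim_capsule, height, width)
    #   -> (batch, channels * height * width, dim_capsule)
    batch_size, channels, dim_capsule, height, width = x.shape
    x = x.permute(0, 1, 3, 4, 2)  # move dim_capsule to the last axis
    return x.reshape(batch_size, channels * height * width, dim_capsule)


def de_cap_flatten(x, channels, height, width):
    # (batch, channels * height * width, dim_capsule)
    #   -> (batch, channels, dim_capsule, height, width)
    batch_size, _, dim_capsule = x.shape
    x = x.reshape(batch_size, channels, height, width, dim_capsule)
    return x.permute(0, 1, 4, 2, 3).contiguous()


x = torch.randn(2, 4, 8, 5, 6)
flat = cap_flatten(x)
restored = de_cap_flatten(flat, channels=4, height=5, width=6)
print(flat.shape)      # torch.Size([2, 120, 8])
print(restored.shape)  # torch.Size([2, 4, 8, 5, 6])
```

Because the flattened tensor no longer carries `channels`, `height` and `width`, the inverse needs them supplied explicitly in this sketch.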


### Scalarization layer
class CapScalarization(nn.Module)
    # input.shape = (batch_size, num_capsules, dim_capsule)
    # output.shape = (batch_size, num_capsules)
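In the paper, the length of a capsule vector stands for the probability that an entity is present, so one plausible reading of this layer is a norm over the last axis. The snippet below is a sketch under that assumption, not the toolkit's code.

```python
import torch

# Three capsules of dimension 2 for a single sample: shape (1, 3, 2).
caps = torch.tensor([[[3.0, 4.0], [0.0, 0.0], [0.6, 0.8]]])

# Collapse each capsule vector to its Euclidean length.
lengths = caps.norm(dim=-1)
print(lengths.shape)  # torch.Size([1, 3])
```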


### Capsule 2D convolution layer (V1)
class CapConv2dV1(nn.Module)
    # input.shape = (batch_size, channels, dim_capsule, height, width)
    # output.shape = (batch_size, out_channels, out_dim_capsule, out_height, out_width)


### Digit capsules (routing output layer)
class CapsuleLayer(nn.Module)
    # Dynamic Routing Version 
    # input.shape = [batch, input_num_capsule, input_dim_capsule]  
    # output.shape = [batch, num_capsule, 1, dim_capsule]
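For readers who want to see what routing by agreement looks like with these shapes, here is a minimal sketch of the procedure from the paper. It is not `CapsuleLayer` itself: the weight tensor `W`, the helper names and the three iterations are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F


def squash(s, dim=-1, epsilon=1e-8):
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return squared_norm / (1.0 + squared_norm) * s / torch.sqrt(squared_norm + epsilon)


def dynamic_routing(u_hat, num_iterations=3):
    # u_hat: prediction vectors, (batch, num_capsule, input_num_capsule, dim_capsule)
    b = torch.zeros_like(u_hat[..., :1])  # routing logits, one per capsule pair
    for _ in range(num_iterations):
        c = F.softmax(b, dim=1)  # each input capsule spreads over output capsules
        s = (c * u_hat).sum(dim=2, keepdim=True)  # weighted sum over input capsules
        v = squash(s)  # (batch, num_capsule, 1, dim_capsule)
        b = b + (u_hat * v).sum(dim=-1, keepdim=True)  # agreement update
    return v


torch.manual_seed(0)
x = torch.randn(2, 6, 8)                 # [batch, input_num_capsule, input_dim_capsule]
W = 0.1 * torch.randn(10, 6, 16, 8)      # hypothetical transformation weights
u_hat = torch.einsum('jidk,bik->bjid', W, x)
v = dynamic_routing(u_hat)
print(v.shape)  # torch.Size([2, 10, 1, 16])
```

The singleton axis in the output comes from summing over the input capsules with `keepdim=True`, which lines up with the documented `[batch, num_capsule, 1, dim_capsule]`.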


### Mask layer
class CapReconMask(nn.Module)
    # input.shape = (batch, num_classes, dim_capsules) | (batch, num_capsules, dim_capsules)
    # masked.shape = (batch, dim_capsules)
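The shapes suggest that the mask keeps one capsule per sample. A minimal sketch, assuming the capsule is chosen by the true label when one is given and by the longest capsule otherwise; `recon_mask` and its `target` argument are hypothetical and may differ from what `CapReconMask` does.

```python
import torch


def recon_mask(caps, target=None):
    # caps: (batch, num_classes, dim_capsules)
    if target is None:
        # No label available: fall back to the longest capsule per sample.
        target = caps.norm(dim=-1).argmax(dim=1)
    batch_index = torch.arange(caps.size(0))
    return caps[batch_index, target]  # (batch, dim_capsules)


caps = torch.zeros(2, 3, 4)
caps[0, 1] = 1.0  # sample 0: class 1 has the longest capsule
caps[1, 2] = 2.0  # sample 1: class 2 has the longest capsule
masked = recon_mask(caps)
print(masked.shape)  # torch.Size([2, 4])
```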
    

## Toolkit
class CapTool():
    def one_hot(self, y, num_dim=10):
        """
        One Hot Encoding, similar to `torch.eye(num_dim).index_select(dim=0, index=y)`
        :param y: N-dim tensor
        :param num_dim: do one-hot labeling from `0` to `num_dim-1`
        :return: shape = (batch_size, num_dim)
        """

    def margin_loss(self, input, target, num_classes=10, m_plus=None, m_minus=None, m_lambda=0.5):
        """
        Margin loss for capsule classification.
        It pushes the score of the true class above `m_plus` and the scores of the other classes below `m_minus`.

        input.shape = (batch_size, num_classes)
        target.shape = (batch_size, ), type of `LongTensor`, true class labels

        :param input: predicted class scores, e.g. capsule lengths
        :param target: true class labels
        :param num_classes: 10
        :param m_plus: 0.9
        :param m_minus: 0.1
        :param m_lambda: 0.5
        :return: shape = (1, )
        """
    
    def squash(self, s, dim=-1, constant=1, epsilon=1e-8):
        """
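        The non-linear activation used in capsules.
        It drives the length of a long vector to near 1 and of a short vector to near 0.
        :param s: N-dim tensor
        :param dim: the dimension to squash
        :param constant: value in (0, 1]
        :param epsilon: small value for numerical stability
        :return: same shape as `s`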
        """
    
    def acc_eval(self, model, test_loader, loss_fn, y_pred_dim=0):

    def model_summary(self, model, show_layer_detail=True):
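To make the `squash` and `margin_loss` signatures above more concrete, here is a standalone sketch written from the formulas in the paper. It is not the toolkit's code: it assumes `constant` replaces the 1 in the squash denominator, uses 0.9 and 0.1 as explicit margin defaults, and returns a 0-dim tensor rather than shape `(1, )`.

```python
import torch


def squash(s, dim=-1, constant=1.0, epsilon=1e-8):
    """Shrink short vectors toward 0 and long vectors toward unit length."""
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (constant + squared_norm)
    # epsilon keeps the division finite for all-zero vectors.
    return scale * s / torch.sqrt(squared_norm + epsilon)


def margin_loss(input, target, num_classes=10, m_plus=0.9, m_minus=0.1, m_lambda=0.5):
    """Margin loss on class scores; summed over classes, averaged over the batch."""
    # One-hot labels built the way the `one_hot` docstring describes.
    t = torch.eye(num_classes, dtype=input.dtype).index_select(dim=0, index=target)
    present = t * torch.clamp(m_plus - input, min=0.0) ** 2
    absent = m_lambda * (1.0 - t) * torch.clamp(input - m_minus, min=0.0) ** 2
    return (present + absent).sum(dim=1).mean()


# A long vector ends up with length close to 1, a short one close to 0.
s = torch.tensor([[30.0, 40.0], [0.03, 0.04]])
lengths = squash(s).norm(dim=-1)

# Scores that favor class 0: no loss for label 0, a positive loss for label 1.
scores = torch.tensor([[0.95, 0.05, 0.05]])
good = margin_loss(scores, torch.tensor([0]), num_classes=3)
bad = margin_loss(scores, torch.tensor([1]), num_classes=3)
print(lengths, good.item(), bad.item())
```

Scores that already sit beyond both margins contribute nothing, so the loss only penalizes predictions that are not confident enough.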

Toolkit download

The ipynb version is available here; export or copy it to a .py file and it can be imported directly: Capsule network toolkit/capsnet_tool.ipynb@gist

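The end of the ipynb file also contains a test that uses this toolkit to implement the Dynamic-Routing version of CapsNet proposed by Hinton. Remember to delete that part when exporting to a .py file.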
