einops方法

einops方法

該方法可以快速實現矩陣的快速變化。

import torch
import torch.nn as nn
from einops import rearrange  # 快速矩陣變化


class TestAttentionQKV:
    def __init__(self, dim=64, heads=8, dim_head=64):
        inner_dim = dim_head * heads
        self.dim = dim
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

    def forward(self, x):
        print(self.to_qkv(x).shape)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        print(qkv[0].shape)
        print(qkv[1].shape)
        print(qkv[2].shape)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        print(q.shape)
        print(k.shape)
        print(v.shape)
        print(self.scale)


if __name__ == '__main__':
    input_tensor = torch.rand((100, 20, 64))
    torch.Tensor([1.0, 2.0])
    test_attention = TestAttentionQKV()
    # print(input_tensor)
    test_attention.forward(input_tensor)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章