I implemented ViT (Vision Transformer) by hand in pure NumPy, so it can run on a Raspberry Pi. The model is fine-tuned with Hugging Face, the trained parameters are saved in NumPy format, and inference is done entirely in NumPy.

First, some background material adapted from Zhihu.

(Figure: ViT block flow diagram)

Following the flow diagram above, a ViT block can be broken down into the following steps:

(1) Patch embedding: take a 224x224 input image and split it into fixed-size 16x16 patches. Each image yields (224x224)/(16x16) = 196 patches, so the input sequence length is 196, and each flattened patch has dimension 16x16x3 = 768. The linear projection layer has shape 768xN (here N = 768), so after the projection the shape is still 196x768: 196 tokens, each of dimension 768. A special cls token is then prepended, giving a final shape of 197x768. At this point, patch embedding has turned a vision problem into a seq2seq problem.
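A minimal NumPy sketch of this shape arithmetic, using random data and hypothetical variable names (not the code from this post):

import numpy as np

image = np.random.rand(1, 3, 224, 224)      # (N, C, H, W)
patch_size = 16
num_patches = (224 // patch_size) ** 2      # 14 * 14 = 196
patch_dim = patch_size * patch_size * 3     # 16 * 16 * 3 = 768

# Cut the image into non-overlapping 16x16 patches and flatten each one
patches = image.reshape(1, 3, 14, 16, 14, 16).transpose(0, 2, 4, 1, 3, 5).reshape(1, num_patches, patch_dim)
print(patches.shape)                        # (1, 196, 768)

# Linear projection 768 -> 768, then prepend the cls token
W_proj = np.random.rand(patch_dim, 768)
tokens = patches @ W_proj                   # (1, 196, 768)
cls_token = np.random.rand(1, 1, 768)
tokens = np.concatenate([cls_token, tokens], axis=1)
print(tokens.shape)                         # (1, 197, 768)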

(2) Positional encoding (standard learnable 1D position embeddings): ViT also adds position information. The position encoding can be viewed as a table with N rows, where N equals the input sequence length; each row is a vector whose dimension matches the token embedding dimension (768). Note that the position encoding is summed with the token embeddings, not concatenated. After adding it, the shape is still 197x768.
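A tiny sketch of the sum, assuming the shapes described above (random data, hypothetical names):

import numpy as np

tokens = np.random.rand(1, 197, 768)        # patch embeddings with the cls token prepended
pos_embed = np.random.rand(1, 197, 768)     # learnable 1D position table: one 768-dim row per position
tokens = tokens + pos_embed                 # element-wise sum, not concat: shape stays (1, 197, 768)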

(3) LN / multi-head attention / LN: the LayerNorm output is still 197x768. For multi-head self-attention, the input is first projected to q, k and v. With a single head, q, k and v each have shape 197x768; with 12 heads (768/12 = 64), each of q, k and v has shape 197x64 and there are 12 such q/k/v groups. The outputs of the 12 heads are concatenated back to 197x768, and the following LayerNorm keeps the shape at 197x768.
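A minimal sketch of how the 768-dim q projection is split into 12 heads of 64 dims each (random data, hypothetical names; k and v are handled the same way):

import numpy as np

x = np.random.rand(197, 768)                # tokens after the first LayerNorm
num_heads, head_dim = 12, 768 // 12         # 12 heads, 64 dims each

W_q = np.random.rand(768, 768)
q = x @ W_q                                 # (197, 768)
q_heads = q.reshape(197, num_heads, head_dim).transpose(1, 0, 2)
print(q_heads.shape)                        # (12, 197, 64)
# each head runs scaled dot-product attention on its own (197, 64) slice,
# and the 12 head outputs are concatenated back to (197, 768)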

(4) MLP: expand the dimension and then shrink it back: 197x768 is expanded to 197x3072 and then projected back to 197x768.
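A minimal sketch of the expand-and-shrink MLP, using the tanh approximation of GELU that also appears in the code below (random weights, hypothetical names):

import numpy as np

x = np.random.rand(197, 768)
W1, b1 = np.random.rand(768, 3072), np.zeros(3072)
W2, b2 = np.random.rand(3072, 768), np.zeros(768)

h = x @ W1 + b1                                                            # expand: (197, 3072)
h = 0.5 * h * (1 + np.tanh(np.sqrt(2 / np.pi) * (h + 0.044715 * h ** 3)))  # GELU (tanh approximation)
y = h @ W2 + b2                                                            # shrink back: (197, 768)
print(y.shape)                                                             # (197, 768)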

After one block the shape is unchanged at 197x768, so multiple blocks can be stacked. Finally, the output Z0 corresponding to the special cls token is taken as the encoder's final output, i.e. the image representation (an alternative is to drop the cls token and average the outputs of all tokens), as in equation (4) of the paper; an MLP head is then attached for image classification.
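A short sketch of both pooling choices (random data, hypothetical names):

import numpy as np

encoder_out = np.random.rand(1, 197, 768)    # output of the last block
cls_repr = encoder_out[:, 0]                 # use the cls token as the image representation: (1, 768)
mean_repr = encoder_out[:, 1:].mean(axis=1)  # alternative: average over all patch tokens: (1, 768)
print(cls_repr.shape, mean_repr.shape)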

The NumPy implementation of ViT below makes the details of each component easy to follow. It differs from BERT in a few ways: besides the embedding layer, the model structure is also slightly different, mainly in that layer normalization is applied before the attention and feed-forward layers (pre-LN), whereas BERT applies it after them.

import numpy as np
import os
from PIL import Image

# Load the saved model parameters and print their names and shapes
model_data = np.load('vit_model_params.npz')
for i in model_data:
    print(i, model_data[i].shape)

patch_embedding_weight = model_data["vit.embeddings.patch_embeddings.projection.weight"]
patch_embedding_bias = model_data["vit.embeddings.patch_embeddings.projection.bias"]
position_embeddings = model_data["vit.embeddings.position_embeddings"]
cls_token_embeddings = model_data["vit.embeddings.cls_token"]

def patch_embedding(images):
    # The patch size is used as both the convolution kernel size and the stride,
    # so the convolution cuts the image into non-overlapping 16x16 patches
    kernel_size = 16
    return conv2d(images, patch_embedding_weight, patch_embedding_bias, stride=kernel_size)

def position_embedding():
    return position_embeddings

def model_input(images):
    # (1, 768, 14, 14) -> (1, 768, 196) -> (1, 196, 768)
    patch_embedded = np.transpose(patch_embedding(images).reshape([1, 768, -1]), (0, 2, 1))

    # Prepend the cls token: (1, 197, 768)
    patch_embedded = np.concatenate([cls_token_embeddings, patch_embedded], axis=1)

    # Position embedding table of shape (1, max_position, embedding_size) = (1, 197, 768)
    position_embedded = position_embedding()

    # Positions are added, not concatenated
    embedding_output = patch_embedded + position_embedded

    return embedding_output

def softmax(x, axis=None):
    # Subtract the max for numerical stability
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    sum_ex = np.sum(e_x, axis=axis, keepdims=True).astype(np.float32)
    return e_x / sum_ex

def conv2d(images, weight, bias, stride=1, padding=0):
    # Naive convolution; with stride equal to the kernel size it acts as the patch projection
    N, C, H, W = images.shape
    F, _, HH, WW = weight.shape
    # Output spatial size
    H_out = (H - HH + 2 * padding) // stride + 1
    W_out = (W - WW + 2 * padding) // stride + 1
    # Initialize the output
    out = np.zeros((N, F, H_out, W_out))
    # Slide the window over the image
    for i in range(H_out):
        for j in range(W_out):
            # Current window: (N, C, HH, WW)
            window = images[:, :, i * stride:i * stride + HH, j * stride:j * stride + WW]
            # Contract over channel and spatial dims against all F filters: (N, F)
            out[:, :, i, j] = np.tensordot(window, weight, axes=([1, 2, 3], [1, 2, 3])) + bias
    return out

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, np.full_like(scores, -np.inf))
    attention_weights = softmax(scores, axis=-1)
    output = np.matmul(attention_weights, V)
    return output, attention_weights

def multihead_attention(input, num_heads, W_Q, B_Q, W_K, B_K, W_V, B_V, W_O, B_O):

    # Hugging Face linear weights are stored as (out_features, in_features), hence the transpose
    q = np.matmul(input, W_Q.T) + B_Q
    k = np.matmul(input, W_K.T) + B_K
    v = np.matmul(input, W_V.T) + B_V

    # Split the 768-dim projections into num_heads heads
    q = np.split(q, num_heads, axis=-1)
    k = np.split(k, num_heads, axis=-1)
    v = np.split(v, num_heads, axis=-1)

    # Attend per head, then concatenate and apply the output projection
    outputs = []
    for q_, k_, v_ in zip(q, k, v):
        output, attention_weights = scaled_dot_product_attention(q_, k_, v_)
        outputs.append(output)
    outputs = np.concatenate(outputs, axis=-1)
    outputs = np.matmul(outputs, W_O.T) + B_O
    return outputs

def layer_normalization(x, weight, bias, eps=1e-12):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    std = np.sqrt(variance + eps)
    normalized_x = (x - mean) / std
    output = weight * normalized_x + bias
    return output

def feed_forward_layer(inputs, weight, bias, activation='relu'):
    linear_output = np.matmul(inputs, weight) + bias

    if activation == 'relu':
        activated_output = np.maximum(0, linear_output)  # ReLU
    elif activation == 'gelu':
        # tanh approximation of GELU
        activated_output = 0.5 * linear_output * (1 + np.tanh(np.sqrt(2 / np.pi) * (linear_output + 0.044715 * np.power(linear_output, 3))))
    elif activation == "tanh":
        activated_output = np.tanh(linear_output)
    else:
        activated_output = linear_output  # no activation

    return activated_output


def residual_connection(inputs, residual):
    # Residual connection
    residual_output = inputs + residual
    return residual_output

def vit(input, num_heads=12):

    for i in range(12):
        # Load the weights of encoder layer i
        W_Q = model_data['vit.encoder.layer.{}.attention.attention.query.weight'.format(i)]
        B_Q = model_data['vit.encoder.layer.{}.attention.attention.query.bias'.format(i)]
        W_K = model_data['vit.encoder.layer.{}.attention.attention.key.weight'.format(i)]
        B_K = model_data['vit.encoder.layer.{}.attention.attention.key.bias'.format(i)]
        W_V = model_data['vit.encoder.layer.{}.attention.attention.value.weight'.format(i)]
        B_V = model_data['vit.encoder.layer.{}.attention.attention.value.bias'.format(i)]
        W_O = model_data['vit.encoder.layer.{}.attention.output.dense.weight'.format(i)]
        B_O = model_data['vit.encoder.layer.{}.attention.output.dense.bias'.format(i)]
        intermediate_weight = model_data['vit.encoder.layer.{}.intermediate.dense.weight'.format(i)]
        intermediate_bias = model_data['vit.encoder.layer.{}.intermediate.dense.bias'.format(i)]
        dense_weight = model_data['vit.encoder.layer.{}.output.dense.weight'.format(i)]
        dense_bias = model_data['vit.encoder.layer.{}.output.dense.bias'.format(i)]
        LayerNorm_before_weight = model_data['vit.encoder.layer.{}.layernorm_before.weight'.format(i)]
        LayerNorm_before_bias = model_data['vit.encoder.layer.{}.layernorm_before.bias'.format(i)]
        LayerNorm_after_weight = model_data['vit.encoder.layer.{}.layernorm_after.weight'.format(i)]
        LayerNorm_after_bias = model_data['vit.encoder.layer.{}.layernorm_after.bias'.format(i)]

        # Pre-LN: LayerNorm -> multi-head attention -> residual (matches the reference model output here)
        output = layer_normalization(input, LayerNorm_before_weight, LayerNorm_before_bias)
        output = multihead_attention(output, num_heads, W_Q, B_Q, W_K, B_K, W_V, B_V, W_O, B_O)
        output1 = residual_connection(input, output)

        # Pre-LN: LayerNorm -> MLP (768 -> 3072 -> 768) -> residual (matches the reference model output here)
        output = layer_normalization(output1, LayerNorm_after_weight, LayerNorm_after_bias)
        output = feed_forward_layer(output, intermediate_weight.T, intermediate_bias, activation='gelu')
        output = feed_forward_layer(output, dense_weight.T, dense_bias, activation='')
        output2 = residual_connection(output1, output)

        input = output2

    # Final LayerNorm on the cls token, then the classification head (matches the reference model output)
    final_layernorm_weight = model_data['vit.layernorm.weight']
    final_layernorm_bias = model_data['vit.layernorm.bias']
    output = layer_normalization(output2[:, 0], final_layernorm_weight, final_layernorm_bias)
    classifier_weight = model_data['classifier.weight']
    classifier_bias = model_data['classifier.bias']
    output = feed_forward_layer(output, classifier_weight.T, classifier_bias, activation="")
    output = np.argmax(output, axis=-1)
    return output

folder_path = './cifar10'  # Replace with the folder containing your images
def infer_images_in_folder(folder_path):
    for file_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file_name)
        if os.path.isfile(file_path) and file_name.endswith(('.jpg', '.jpeg', '.png')):
            image = Image.open(file_path)
            image = image.resize((224, 224))
            # File names look like 8619_5.jpg, where the part after "_" is the label
            label = file_name.split(".")[0].split("_")[1]
            image = np.array(image) / 255.0             # scale to [0, 1], same as ToTensor() during training
            image = np.transpose(image, (2, 0, 1))      # HWC -> CHW
            image = np.expand_dims(image, axis=0)       # add the batch dimension
            print("file_path:", file_path, "img size:", image.shape, "label:", label)
            input = model_input(image)
            predicted_class = vit(input)
            print('Predicted class:', predicted_class)


if __name__ == "__main__":

    infer_images_in_folder(folder_path)

Results:

file_path: ./cifar10/8619_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/6042_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/6801_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7946_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/6925_2.jpg img size: (1, 3, 224, 224) label: 2
Predicted class: [2]
file_path: ./cifar10/6007_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7903_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/7064_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/2713_8.jpg img size: (1, 3, 224, 224) label: 8
Predicted class: [8]
file_path: ./cifar10/8575_9.jpg img size: (1, 3, 224, 224) label: 9
Predicted class: [9]
file_path: ./cifar10/1985_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/5312_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/593_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/8093_7.jpg img size: (1, 3, 224, 224) label: 7
Predicted class: [7]
file_path: ./cifar10/6862_5.jpg img size: (1, 3, 224, 224) label: 5

 

Model parameters:

vit.embeddings.cls_token (1, 1, 768)
vit.embeddings.position_embeddings (1, 197, 768)
vit.embeddings.patch_embeddings.projection.weight (768, 3, 16, 16)
vit.embeddings.patch_embeddings.projection.bias (768,)
vit.encoder.layer.0.attention.attention.query.weight (768, 768)
vit.encoder.layer.0.attention.attention.query.bias (768,)
vit.encoder.layer.0.attention.attention.key.weight (768, 768)
vit.encoder.layer.0.attention.attention.key.bias (768,)
vit.encoder.layer.0.attention.attention.value.weight (768, 768)
vit.encoder.layer.0.attention.attention.value.bias (768,)
vit.encoder.layer.0.attention.output.dense.weight (768, 768)
vit.encoder.layer.0.attention.output.dense.bias (768,)
vit.encoder.layer.0.intermediate.dense.weight (3072, 768)
vit.encoder.layer.0.intermediate.dense.bias (3072,)
vit.encoder.layer.0.output.dense.weight (768, 3072)
vit.encoder.layer.0.output.dense.bias (768,)
vit.encoder.layer.0.layernorm_before.weight (768,)
vit.encoder.layer.0.layernorm_before.bias (768,)
vit.encoder.layer.0.layernorm_after.weight (768,)
vit.encoder.layer.0.layernorm_after.bias (768,)
vit.encoder.layer.1.attention.attention.query.weight (768, 768)
vit.encoder.layer.1.attention.attention.query.bias (768,)
vit.encoder.layer.1.attention.attention.key.weight (768, 768)
vit.encoder.layer.1.attention.attention.key.bias (768,)
vit.encoder.layer.1.attention.attention.value.weight (768, 768)
vit.encoder.layer.1.attention.attention.value.bias (768,)
vit.encoder.layer.1.attention.output.dense.weight (768, 768)
vit.encoder.layer.1.attention.output.dense.bias (768,)
vit.encoder.layer.1.intermediate.dense.weight (3072, 768)
vit.encoder.layer.1.intermediate.dense.bias (3072,)
vit.encoder.layer.1.output.dense.weight (768, 3072)
vit.encoder.layer.1.output.dense.bias (768,)
vit.encoder.layer.1.layernorm_before.weight (768,)
vit.encoder.layer.1.layernorm_before.bias (768,)
vit.encoder.layer.1.layernorm_after.weight (768,)
vit.encoder.layer.1.layernorm_after.bias (768,)
vit.encoder.layer.2.attention.attention.query.weight (768, 768)
vit.encoder.layer.2.attention.attention.query.bias (768,)
vit.encoder.layer.2.attention.attention.key.weight (768, 768)
vit.encoder.layer.2.attention.attention.key.bias (768,)
vit.encoder.layer.2.attention.attention.value.weight (768, 768)
vit.encoder.layer.2.attention.attention.value.bias (768,)
vit.encoder.layer.2.attention.output.dense.weight (768, 768)
vit.encoder.layer.2.attention.output.dense.bias (768,)
vit.encoder.layer.2.intermediate.dense.weight (3072, 768)
vit.encoder.layer.2.intermediate.dense.bias (3072,)
vit.encoder.layer.2.output.dense.weight (768, 3072)
vit.encoder.layer.2.output.dense.bias (768,)
vit.encoder.layer.2.layernorm_before.weight (768,)
vit.encoder.layer.2.layernorm_before.bias (768,)
vit.encoder.layer.2.layernorm_after.weight (768,)
vit.encoder.layer.2.layernorm_after.bias (768,)
vit.encoder.layer.3.attention.attention.query.weight (768, 768)
vit.encoder.layer.3.attention.attention.query.bias (768,)
vit.encoder.layer.3.attention.attention.key.weight (768, 768)
vit.encoder.layer.3.attention.attention.key.bias (768,)
vit.encoder.layer.3.attention.attention.value.weight (768, 768)
vit.encoder.layer.3.attention.attention.value.bias (768,)
vit.encoder.layer.3.attention.output.dense.weight (768, 768)
vit.encoder.layer.3.attention.output.dense.bias (768,)
vit.encoder.layer.3.intermediate.dense.weight (3072, 768)
vit.encoder.layer.3.intermediate.dense.bias (3072,)
vit.encoder.layer.3.output.dense.weight (768, 3072)
vit.encoder.layer.3.output.dense.bias (768,)
vit.encoder.layer.3.layernorm_before.weight (768,)
vit.encoder.layer.3.layernorm_before.bias (768,)
vit.encoder.layer.3.layernorm_after.weight (768,)
vit.encoder.layer.3.layernorm_after.bias (768,)
vit.encoder.layer.4.attention.attention.query.weight (768, 768)
vit.encoder.layer.4.attention.attention.query.bias (768,)
vit.encoder.layer.4.attention.attention.key.weight (768, 768)
vit.encoder.layer.4.attention.attention.key.bias (768,)
vit.encoder.layer.4.attention.attention.value.weight (768, 768)
vit.encoder.layer.4.attention.attention.value.bias (768,)
vit.encoder.layer.4.attention.output.dense.weight (768, 768)
vit.encoder.layer.4.attention.output.dense.bias (768,)
vit.encoder.layer.4.intermediate.dense.weight (3072, 768)
vit.encoder.layer.4.intermediate.dense.bias (3072,)
vit.encoder.layer.4.output.dense.weight (768, 3072)
vit.encoder.layer.4.output.dense.bias (768,)
vit.encoder.layer.4.layernorm_before.weight (768,)
vit.encoder.layer.4.layernorm_before.bias (768,)
vit.encoder.layer.4.layernorm_after.weight (768,)
vit.encoder.layer.4.layernorm_after.bias (768,)
vit.encoder.layer.5.attention.attention.query.weight (768, 768)
vit.encoder.layer.5.attention.attention.query.bias (768,)
vit.encoder.layer.5.attention.attention.key.weight (768, 768)
vit.encoder.layer.5.attention.attention.key.bias (768,)
vit.encoder.layer.5.attention.attention.value.weight (768, 768)
vit.encoder.layer.5.attention.attention.value.bias (768,)
vit.encoder.layer.5.attention.output.dense.weight (768, 768)
vit.encoder.layer.5.attention.output.dense.bias (768,)
vit.encoder.layer.5.intermediate.dense.weight (3072, 768)
vit.encoder.layer.5.intermediate.dense.bias (3072,)
vit.encoder.layer.5.output.dense.weight (768, 3072)
vit.encoder.layer.5.output.dense.bias (768,)
vit.encoder.layer.5.layernorm_before.weight (768,)
vit.encoder.layer.5.layernorm_before.bias (768,)
vit.encoder.layer.5.layernorm_after.weight (768,)
vit.encoder.layer.5.layernorm_after.bias (768,)
vit.encoder.layer.6.attention.attention.query.weight (768, 768)
vit.encoder.layer.6.attention.attention.query.bias (768,)
vit.encoder.layer.6.attention.attention.key.weight (768, 768)
vit.encoder.layer.6.attention.attention.key.bias (768,)
vit.encoder.layer.6.attention.attention.value.weight (768, 768)
vit.encoder.layer.6.attention.attention.value.bias (768,)
vit.encoder.layer.6.attention.output.dense.weight (768, 768)
vit.encoder.layer.6.attention.output.dense.bias (768,)
vit.encoder.layer.6.intermediate.dense.weight (3072, 768)
vit.encoder.layer.6.intermediate.dense.bias (3072,)
vit.encoder.layer.6.output.dense.weight (768, 3072)
vit.encoder.layer.6.output.dense.bias (768,)
vit.encoder.layer.6.layernorm_before.weight (768,)
vit.encoder.layer.6.layernorm_before.bias (768,)
vit.encoder.layer.6.layernorm_after.weight (768,)
vit.encoder.layer.6.layernorm_after.bias (768,)
vit.encoder.layer.7.attention.attention.query.weight (768, 768)
vit.encoder.layer.7.attention.attention.query.bias (768,)
vit.encoder.layer.7.attention.attention.key.weight (768, 768)
vit.encoder.layer.7.attention.attention.key.bias (768,)
vit.encoder.layer.7.attention.attention.value.weight (768, 768)
vit.encoder.layer.7.attention.attention.value.bias (768,)
vit.encoder.layer.7.attention.output.dense.weight (768, 768)
vit.encoder.layer.7.attention.output.dense.bias (768,)
vit.encoder.layer.7.intermediate.dense.weight (3072, 768)
vit.encoder.layer.7.intermediate.dense.bias (3072,)
vit.encoder.layer.7.output.dense.weight (768, 3072)
vit.encoder.layer.7.output.dense.bias (768,)
vit.encoder.layer.7.layernorm_before.weight (768,)
vit.encoder.layer.7.layernorm_before.bias (768,)
vit.encoder.layer.7.layernorm_after.weight (768,)
vit.encoder.layer.7.layernorm_after.bias (768,)
vit.encoder.layer.8.attention.attention.query.weight (768, 768)
vit.encoder.layer.8.attention.attention.query.bias (768,)
vit.encoder.layer.8.attention.attention.key.weight (768, 768)
vit.encoder.layer.8.attention.attention.key.bias (768,)
vit.encoder.layer.8.attention.attention.value.weight (768, 768)
vit.encoder.layer.8.attention.attention.value.bias (768,)
vit.encoder.layer.8.attention.output.dense.weight (768, 768)
vit.encoder.layer.8.attention.output.dense.bias (768,)
vit.encoder.layer.8.intermediate.dense.weight (3072, 768)
vit.encoder.layer.8.intermediate.dense.bias (3072,)
vit.encoder.layer.8.output.dense.weight (768, 3072)
vit.encoder.layer.8.output.dense.bias (768,)
vit.encoder.layer.8.layernorm_before.weight (768,)
vit.encoder.layer.8.layernorm_before.bias (768,)
vit.encoder.layer.8.layernorm_after.weight (768,)
vit.encoder.layer.8.layernorm_after.bias (768,)
vit.encoder.layer.9.attention.attention.query.weight (768, 768)
vit.encoder.layer.9.attention.attention.query.bias (768,)
vit.encoder.layer.9.attention.attention.key.weight (768, 768)
vit.encoder.layer.9.attention.attention.key.bias (768,)
vit.encoder.layer.9.attention.attention.value.weight (768, 768)
vit.encoder.layer.9.attention.attention.value.bias (768,)
vit.encoder.layer.9.attention.output.dense.weight (768, 768)
vit.encoder.layer.9.attention.output.dense.bias (768,)
vit.encoder.layer.9.intermediate.dense.weight (3072, 768)
vit.encoder.layer.9.intermediate.dense.bias (3072,)
vit.encoder.layer.9.output.dense.weight (768, 3072)
vit.encoder.layer.9.output.dense.bias (768,)
vit.encoder.layer.9.layernorm_before.weight (768,)
vit.encoder.layer.9.layernorm_before.bias (768,)
vit.encoder.layer.9.layernorm_after.weight (768,)
vit.encoder.layer.9.layernorm_after.bias (768,)
vit.encoder.layer.10.attention.attention.query.weight (768, 768)
vit.encoder.layer.10.attention.attention.query.bias (768,)
vit.encoder.layer.10.attention.attention.key.weight (768, 768)
vit.encoder.layer.10.attention.attention.key.bias (768,)
vit.encoder.layer.10.attention.attention.value.weight (768, 768)
vit.encoder.layer.10.attention.attention.value.bias (768,)
vit.encoder.layer.10.attention.output.dense.weight (768, 768)
vit.encoder.layer.10.attention.output.dense.bias (768,)
vit.encoder.layer.10.intermediate.dense.weight (3072, 768)
vit.encoder.layer.10.intermediate.dense.bias (3072,)
vit.encoder.layer.10.output.dense.weight (768, 3072)
vit.encoder.layer.10.output.dense.bias (768,)
vit.encoder.layer.10.layernorm_before.weight (768,)
vit.encoder.layer.10.layernorm_before.bias (768,)
vit.encoder.layer.10.layernorm_after.weight (768,)
vit.encoder.layer.10.layernorm_after.bias (768,)
vit.encoder.layer.11.attention.attention.query.weight (768, 768)
vit.encoder.layer.11.attention.attention.query.bias (768,)
vit.encoder.layer.11.attention.attention.key.weight (768, 768)
vit.encoder.layer.11.attention.attention.key.bias (768,)
vit.encoder.layer.11.attention.attention.value.weight (768, 768)
vit.encoder.layer.11.attention.attention.value.bias (768,)
vit.encoder.layer.11.attention.output.dense.weight (768, 768)
vit.encoder.layer.11.attention.output.dense.bias (768,)
vit.encoder.layer.11.intermediate.dense.weight (3072, 768)
vit.encoder.layer.11.intermediate.dense.bias (3072,)
vit.encoder.layer.11.output.dense.weight (768, 3072)
vit.encoder.layer.11.output.dense.bias (768,)
vit.encoder.layer.11.layernorm_before.weight (768,)
vit.encoder.layer.11.layernorm_before.bias (768,)
vit.encoder.layer.11.layernorm_after.weight (768,)
vit.encoder.layer.11.layernorm_after.bias (768,)
vit.layernorm.weight (768,)
vit.layernorm.bias (768,)
classifier.weight (10, 768)
classifier.bias (10,)

 

Hugging Face training code: fine-tune ViT on CIFAR-10 and save the model parameters in NumPy format so that the pure-NumPy model above can load them:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTModel, ViTForImageClassification
from tqdm import tqdm
import numpy as np

# Set the random seed
torch.manual_seed(42)

# Hyperparameters
batch_size = 64
num_epochs = 1
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load the CIFAR-10 dataset
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform)

# Create the data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Load the pretrained ViT model
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)

# Replace the classification head with a 10-class linear layer
num_classes = 10
# vit_model.config.classifier = 'mlp'
# vit_model.config.num_labels = num_classes
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device)


# parameters = list(vit_model.parameters())
# for x in parameters[:-1]:
#     x.requires_grad = False


# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)

# Fine-tune the ViT model
for epoch in range(num_epochs):
    print("epoch:",epoch)
    vit_model.train()
    train_loss = 0.0
    train_correct = 0

    bar = tqdm(train_loader,total=len(train_loader))
    for images, labels in bar:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = vit_model(images)
        loss = criterion(outputs.logits, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.logits, 1)
        train_correct += (predicted == labels).sum().item()

    # Accuracy on the training set
    train_accuracy = 100.0 * train_correct / len(train_dataset)

    # Evaluate on the test set
    vit_model.eval()
    test_loss = 0.0
    test_correct = 0

    with torch.no_grad():
        bar = tqdm(test_loader,total=len(test_loader))
        for images, labels in bar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = vit_model(images)
            loss = criterion(outputs.logits, labels)

            test_loss += loss.item()
            _, predicted = torch.max(outputs.logits, 1)
            test_correct += (predicted == labels).sum().item()

    # Accuracy on the test set
    test_accuracy = 100.0 * test_correct / len(test_dataset)

    # Print this epoch's training loss, training accuracy, and test accuracy
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%')


torch.save(vit_model.state_dict(), 'vit_model_parameters.pth')

# Print the ViT model's weight names and shapes
for name, param in vit_model.named_parameters():
    print(name, param.data.shape)

# Save the model parameters in NumPy format so the pure-NumPy model can load them
model_params = {name: param.data.cpu().numpy() for name, param in vit_model.named_parameters()}
np.savez('vit_model_params.npz', **model_params)

 

Epoch [1/1], Train Loss: 97.7498, Train Accuracy: 96.21%, Test Accuracy: 96.86%
