先複製一點知乎上的內容
按照上面的流程圖,一個ViT block可以分爲以下幾個步驟
(1) patch embedding:例如輸入圖片大小爲224x224,將圖片分爲固定大小的patch,patch大小爲16x16,則每張圖像會生成224x224/16x16=196個patch,即輸入序列長度爲196,每個patch維度16x16x3=768,線性投射層的維度爲768xN (N=768),因此輸入通過線性投射層之後的維度依然爲196x768,即一共有196個token,每個token的維度是768。這裏還需要加上一個特殊字符cls,因此最終的維度是197x768。到目前爲止,已經通過patch embedding將一個視覺問題轉化爲了一個seq2seq問題
(2) positional encoding(standard learnable 1D position embeddings):ViT同樣需要加入位置編碼,位置編碼可以理解爲一張表,表一共有N行,N的大小和輸入序列長度相同,每一行代表一個向量,向量的維度和輸入序列embedding的維度相同(768)。注意位置編碼的操作是sum,而不是concat。加入位置編碼信息之後,維度依然是197x768
(3) LN/multi-head attention/LN:LN輸出維度依然是197x768。多頭自注意力時,先將輸入映射到q,k,v,如果只有一個頭,qkv的維度都是197x768,如果有12個頭(768/12=64),則qkv的維度是197x64,一共有12組qkv,最後再將12組qkv的輸出拼接起來,輸出維度是197x768,然後再過一層LN,維度依然是197x768
(4) MLP:將維度放大再縮小回去,197x768放大爲197x3072,再縮小變爲197x768
一個block之後維度依然和輸入相同,都是197x768,因此可以堆疊多個block。最後會將特殊字符cls對應的輸出 Z0 作爲encoder的最終輸出 ,代表最終的image representation(另一種做法是不加cls字符,對所有的tokens的輸出做一個平均),如下圖公式(4),後面接一個MLP進行圖片分類
vit 的 numpy 實現代碼,可以直接看懂各個部分的細節實現 ,和bert有一些不一樣,除了embedding層不一樣之外,模型結構也有些不同,主要是layer_normalization放在了attention層和feed_forward層之前,bert都是放在之後
"""Pure-NumPy inference for a fine-tuned ViT (vit-base-patch16-224, 10 classes).

Loads parameters exported to .npz by the accompanying training script and
re-implements patch embedding, the 12 pre-norm encoder blocks, and the
classifier head. Unlike BERT, ViT applies LayerNorm *before* attention and
the feed-forward sub-block (pre-norm), with residual connections around each.
"""
import os

import numpy as np
from PIL import Image

# Load the fine-tuned ViT parameters exported from Hugging Face.
model_data = np.load('vit_model_params.npz')
for name in model_data:
    print(name, model_data[name].shape)

patch_embedding_weight = model_data["vit.embeddings.patch_embeddings.projection.weight"]
patch_embedding_bias = model_data["vit.embeddings.patch_embeddings.projection.bias"]
position_embeddings = model_data["vit.embeddings.position_embeddings"]
cls_token_embeddings = model_data["vit.embeddings.cls_token"]


def patch_embedding(images):
    """Split the image into 16x16 patches and linearly project each one.

    Implemented as a stride-16 convolution with the exported projection
    weights: (N, 3, 224, 224) -> (N, 768, 14, 14).
    """
    kernel_size = 16  # patch size; stride == kernel size => non-overlapping patches
    return conv2d(images, patch_embedding_weight, patch_embedding_bias,
                  stride=kernel_size)


def position_embedding():
    """Return the learned 1D position embeddings, shape (1, 197, 768)."""
    return position_embeddings


def model_input(images):
    """Build the (1, 197, 768) input sequence: [CLS] + 196 patch tokens + positions."""
    # (1, 768, 14, 14) -> (1, 768, 196) -> (1, 196, 768)
    patch_embedded = np.transpose(patch_embedding(images).reshape([1, 768, -1]),
                                  (0, 2, 1))
    # Prepend the learnable CLS token -> (1, 197, 768)
    patch_embedded = np.concatenate([cls_token_embeddings, patch_embedded], axis=1)
    # Position information is summed, not concatenated; shape is unchanged.
    return patch_embedded + position_embedding()


def softmax(x, axis=None):
    """Numerically stable softmax along `axis`.

    Fix: the original left both `e_x = ...` lines commented out, so the
    function raised NameError. The max is subtracted so exp() cannot overflow.
    """
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)


def conv2d(images, weight, bias, stride=1, padding=0):
    """Naive 2D convolution of (N, C, H, W) images with (F, C, HH, WW) filters.

    Returns (N, F, H_out, W_out). Fixes vs. the original: `padding` is now
    actually applied to the input (it was only used in the size formula), and
    the window/filter product is broadcast explicitly so batches with N > 1
    work (the original reduction was only correct for N == 1).
    """
    if padding > 0:
        images = np.pad(images,
                        ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    N, C, H, W = images.shape
    F, _, HH, WW = weight.shape
    # H/W already include the padding here, so no "+ 2 * padding" term.
    H_out = (H - HH) // stride + 1
    W_out = (W - WW) // stride + 1
    out = np.zeros((N, F, H_out, W_out))
    for i in range(H_out):
        for j in range(W_out):
            # Current receptive field: (N, C, HH, WW)
            window = images[:, :, i * stride:i * stride + HH,
                            j * stride:j * stride + WW]
            # (N, 1, C, HH, WW) * (1, F, C, HH, WW) reduced over C/HH/WW -> (N, F)
            out[:, :, i, j] = (np.sum(window[:, None] * weight[None],
                                      axis=(2, 3, 4)) + bias)
    return out


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V for one attention head.

    `mask` (optional, boolean): positions where it is False are set to -inf
    before the softmax, i.e. masked out. Returns (output, attention_weights).
    """
    d_k = Q.shape[-1]
    scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, np.full_like(scores, -np.inf))
    attention_weights = softmax(scores, axis=-1)
    output = np.matmul(attention_weights, V)
    return output, attention_weights


def multihead_attention(input, num_heads, W_Q, B_Q, W_K, B_K, W_V, B_V, W_O, B_O):
    """Multi-head self-attention with the HF weight layout (weights stored transposed).

    Projects to q/k/v, splits the last dim into `num_heads` heads, attends per
    head, concatenates, then applies the output projection W_O/B_O.
    """
    q = np.matmul(input, W_Q.T) + B_Q
    k = np.matmul(input, W_K.T) + B_K
    v = np.matmul(input, W_V.T) + B_V
    # Split (..., 768) into num_heads chunks of (..., 768 / num_heads).
    q = np.split(q, num_heads, axis=-1)
    k = np.split(k, num_heads, axis=-1)
    v = np.split(v, num_heads, axis=-1)
    outputs = []
    for q_h, k_h, v_h in zip(q, k, v):
        head_out, _ = scaled_dot_product_attention(q_h, k_h, v_h)
        outputs.append(head_out)
    outputs = np.concatenate(outputs, axis=-1)
    return np.matmul(outputs, W_O.T) + B_O


def layer_normalization(x, weight, bias, eps=1e-12):
    """LayerNorm over the last axis: weight * (x - mean) / std + bias.

    eps=1e-12 matches the HF ViT config's layer_norm_eps.
    """
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    normalized_x = (x - mean) / np.sqrt(variance + eps)
    return weight * normalized_x + bias


def feed_forward_layer(inputs, weight, bias, activation='relu'):
    """Dense layer (inputs @ weight + bias) followed by an optional activation.

    activation: 'relu', 'gelu' (tanh approximation), 'tanh', or anything else
    for a plain linear layer.
    """
    linear_output = np.matmul(inputs, weight) + bias
    if activation == 'relu':
        return np.maximum(0, linear_output)
    if activation == 'gelu':
        # Tanh approximation of GELU, as used by the original implementation.
        return 0.5 * linear_output * (
            1 + np.tanh(np.sqrt(2 / np.pi)
                        * (linear_output + 0.044715 * np.power(linear_output, 3))))
    if activation == 'tanh':
        return np.tanh(linear_output)
    return linear_output  # no activation


def residual_connection(inputs, residual):
    """Residual (skip) connection: element-wise sum."""
    return inputs + residual


def vit(input, num_heads=12):
    """Run the 12-layer pre-norm ViT encoder and classify via the CLS token.

    `input` is the (1, 197, 768) output of model_input(). Each block is
    LN -> MHA -> residual, then LN -> MLP(768->3072 GELU ->768) -> residual.
    Returns the argmax class index (shape (1,)).
    """
    for i in range(12):
        prefix = 'vit.encoder.layer.{}'.format(i)
        W_Q = model_data[prefix + '.attention.attention.query.weight']
        B_Q = model_data[prefix + '.attention.attention.query.bias']
        W_K = model_data[prefix + '.attention.attention.key.weight']
        B_K = model_data[prefix + '.attention.attention.key.bias']
        W_V = model_data[prefix + '.attention.attention.value.weight']
        B_V = model_data[prefix + '.attention.attention.value.bias']
        W_O = model_data[prefix + '.attention.output.dense.weight']
        B_O = model_data[prefix + '.attention.output.dense.bias']
        intermediate_weight = model_data[prefix + '.intermediate.dense.weight']
        intermediate_bias = model_data[prefix + '.intermediate.dense.bias']
        dense_weight = model_data[prefix + '.output.dense.weight']
        dense_bias = model_data[prefix + '.output.dense.bias']
        ln_before_weight = model_data[prefix + '.layernorm_before.weight']
        ln_before_bias = model_data[prefix + '.layernorm_before.bias']
        ln_after_weight = model_data[prefix + '.layernorm_after.weight']
        ln_after_bias = model_data[prefix + '.layernorm_after.bias']

        # Attention sub-block (pre-norm): LN -> MHA -> residual.
        output = layer_normalization(input, ln_before_weight, ln_before_bias)
        output = multihead_attention(output, num_heads,
                                     W_Q, B_Q, W_K, B_K, W_V, B_V, W_O, B_O)
        attn_out = residual_connection(input, output)

        # MLP sub-block (pre-norm): LN -> 768->3072 (GELU) -> 3072->768 -> residual.
        output = layer_normalization(attn_out, ln_after_weight, ln_after_bias)
        output = feed_forward_layer(output, intermediate_weight.T,
                                    intermediate_bias, activation='gelu')
        output = feed_forward_layer(output, dense_weight.T, dense_bias,
                                    activation='')
        input = residual_connection(attn_out, output)

    # Final LayerNorm, then classify using only the CLS token (position 0).
    final_ln_weight = model_data['vit.layernorm.weight']
    final_ln_bias = model_data['vit.layernorm.bias']
    cls_output = layer_normalization(input[:, 0], final_ln_weight, final_ln_bias)
    classifier_weight = model_data['classifier.weight']
    classifier_bias = model_data['classifier.bias']
    logits = feed_forward_layer(cls_output, classifier_weight.T, classifier_bias,
                                activation="")
    return np.argmax(logits, axis=-1)


folder_path = './cifar10'  # images named like "<id>_<label>.jpg"


def infer_images_in_folder(folder_path):
    """Classify every .jpg/.jpeg/.png in `folder_path`, printing each prediction."""
    for file_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file_name)
        if os.path.isfile(file_path) and file_name.endswith(('.jpg', '.jpeg', '.png')):
            # convert('RGB') guards against grayscale/RGBA files, which would
            # otherwise break the HWC -> CHW transpose below.
            image = Image.open(file_path).convert('RGB')
            image = image.resize((224, 224))
            label = file_name.split(".")[0].split("_")[1]  # ground truth from filename
            image = np.array(image) / 255.0                # same scaling as ToTensor
            image = np.transpose(image, (2, 0, 1))         # HWC -> CHW
            image = np.expand_dims(image, axis=0)          # add batch dimension
            print("file_path:", file_path, "img size:", image.shape, "label:", label)
            input = model_input(image)
            predicted_class = vit(input)
            print('Predicted class:', predicted_class)


if __name__ == "__main__":
    infer_images_in_folder(folder_path)
結果:
file_path: ./cifar10/8619_5.jpg img size: (1, 3, 224, 224) label: 5 Predicted class: [5] file_path: ./cifar10/6042_6.jpg img size: (1, 3, 224, 224) label: 6 Predicted class: [6] file_path: ./cifar10/6801_6.jpg img size: (1, 3, 224, 224) label: 6 Predicted class: [6] file_path: ./cifar10/7946_1.jpg img size: (1, 3, 224, 224) label: 1 Predicted class: [1] file_path: ./cifar10/6925_2.jpg img size: (1, 3, 224, 224) label: 2 Predicted class: [2] file_path: ./cifar10/6007_6.jpg img size: (1, 3, 224, 224) label: 6 Predicted class: [6] file_path: ./cifar10/7903_1.jpg img size: (1, 3, 224, 224) label: 1 Predicted class: [1] file_path: ./cifar10/7064_5.jpg img size: (1, 3, 224, 224) label: 5 Predicted class: [5] file_path: ./cifar10/2713_8.jpg img size: (1, 3, 224, 224) label: 8 Predicted class: [8] file_path: ./cifar10/8575_9.jpg img size: (1, 3, 224, 224) label: 9 Predicted class: [9] file_path: ./cifar10/1985_6.jpg img size: (1, 3, 224, 224) label: 6 Predicted class: [6] file_path: ./cifar10/5312_5.jpg img size: (1, 3, 224, 224) label: 5 Predicted class: [5] file_path: ./cifar10/593_6.jpg img size: (1, 3, 224, 224) label: 6 Predicted class: [6] file_path: ./cifar10/8093_7.jpg img size: (1, 3, 224, 224) label: 7 Predicted class: [7] file_path: ./cifar10/6862_5.jpg img size: (1, 3, 224, 224) label: 5
模型參數:
vit.embeddings.cls_token (1, 1, 768) vit.embeddings.position_embeddings (1, 197, 768) vit.embeddings.patch_embeddings.projection.weight (768, 3, 16, 16) vit.embeddings.patch_embeddings.projection.bias (768,) vit.encoder.layer.0.attention.attention.query.weight (768, 768) vit.encoder.layer.0.attention.attention.query.bias (768,) vit.encoder.layer.0.attention.attention.key.weight (768, 768) vit.encoder.layer.0.attention.attention.key.bias (768,) vit.encoder.layer.0.attention.attention.value.weight (768, 768) vit.encoder.layer.0.attention.attention.value.bias (768,) vit.encoder.layer.0.attention.output.dense.weight (768, 768) vit.encoder.layer.0.attention.output.dense.bias (768,) vit.encoder.layer.0.intermediate.dense.weight (3072, 768) vit.encoder.layer.0.intermediate.dense.bias (3072,) vit.encoder.layer.0.output.dense.weight (768, 3072) vit.encoder.layer.0.output.dense.bias (768,) vit.encoder.layer.0.layernorm_before.weight (768,) vit.encoder.layer.0.layernorm_before.bias (768,) vit.encoder.layer.0.layernorm_after.weight (768,) vit.encoder.layer.0.layernorm_after.bias (768,) vit.encoder.layer.1.attention.attention.query.weight (768, 768) vit.encoder.layer.1.attention.attention.query.bias (768,) vit.encoder.layer.1.attention.attention.key.weight (768, 768) vit.encoder.layer.1.attention.attention.key.bias (768,) vit.encoder.layer.1.attention.attention.value.weight (768, 768) vit.encoder.layer.1.attention.attention.value.bias (768,) vit.encoder.layer.1.attention.output.dense.weight (768, 768) vit.encoder.layer.1.attention.output.dense.bias (768,) vit.encoder.layer.1.intermediate.dense.weight (3072, 768) vit.encoder.layer.1.intermediate.dense.bias (3072,) vit.encoder.layer.1.output.dense.weight (768, 3072) vit.encoder.layer.1.output.dense.bias (768,) vit.encoder.layer.1.layernorm_before.weight (768,) vit.encoder.layer.1.layernorm_before.bias (768,) vit.encoder.layer.1.layernorm_after.weight (768,) vit.encoder.layer.1.layernorm_after.bias (768,) 
vit.encoder.layer.2.attention.attention.query.weight (768, 768) vit.encoder.layer.2.attention.attention.query.bias (768,) vit.encoder.layer.2.attention.attention.key.weight (768, 768) vit.encoder.layer.2.attention.attention.key.bias (768,) vit.encoder.layer.2.attention.attention.value.weight (768, 768) vit.encoder.layer.2.attention.attention.value.bias (768,) vit.encoder.layer.2.attention.output.dense.weight (768, 768) vit.encoder.layer.2.attention.output.dense.bias (768,) vit.encoder.layer.2.intermediate.dense.weight (3072, 768) vit.encoder.layer.2.intermediate.dense.bias (3072,) vit.encoder.layer.2.output.dense.weight (768, 3072) vit.encoder.layer.2.output.dense.bias (768,) vit.encoder.layer.2.layernorm_before.weight (768,) vit.encoder.layer.2.layernorm_before.bias (768,) vit.encoder.layer.2.layernorm_after.weight (768,) vit.encoder.layer.2.layernorm_after.bias (768,) vit.encoder.layer.3.attention.attention.query.weight (768, 768) vit.encoder.layer.3.attention.attention.query.bias (768,) vit.encoder.layer.3.attention.attention.key.weight (768, 768) vit.encoder.layer.3.attention.attention.key.bias (768,) vit.encoder.layer.3.attention.attention.value.weight (768, 768) vit.encoder.layer.3.attention.attention.value.bias (768,) vit.encoder.layer.3.attention.output.dense.weight (768, 768) vit.encoder.layer.3.attention.output.dense.bias (768,) vit.encoder.layer.3.intermediate.dense.weight (3072, 768) vit.encoder.layer.3.intermediate.dense.bias (3072,) vit.encoder.layer.3.output.dense.weight (768, 3072) vit.encoder.layer.3.output.dense.bias (768,) vit.encoder.layer.3.layernorm_before.weight (768,) vit.encoder.layer.3.layernorm_before.bias (768,) vit.encoder.layer.3.layernorm_after.weight (768,) vit.encoder.layer.3.layernorm_after.bias (768,) vit.encoder.layer.4.attention.attention.query.weight (768, 768) vit.encoder.layer.4.attention.attention.query.bias (768,) vit.encoder.layer.4.attention.attention.key.weight (768, 768) vit.encoder.layer.4.attention.attention.key.bias 
(768,) vit.encoder.layer.4.attention.attention.value.weight (768, 768) vit.encoder.layer.4.attention.attention.value.bias (768,) vit.encoder.layer.4.attention.output.dense.weight (768, 768) vit.encoder.layer.4.attention.output.dense.bias (768,) vit.encoder.layer.4.intermediate.dense.weight (3072, 768) vit.encoder.layer.4.intermediate.dense.bias (3072,) vit.encoder.layer.4.output.dense.weight (768, 3072) vit.encoder.layer.4.output.dense.bias (768,) vit.encoder.layer.4.layernorm_before.weight (768,) vit.encoder.layer.4.layernorm_before.bias (768,) vit.encoder.layer.4.layernorm_after.weight (768,) vit.encoder.layer.4.layernorm_after.bias (768,) vit.encoder.layer.5.attention.attention.query.weight (768, 768) vit.encoder.layer.5.attention.attention.query.bias (768,) vit.encoder.layer.5.attention.attention.key.weight (768, 768) vit.encoder.layer.5.attention.attention.key.bias (768,) vit.encoder.layer.5.attention.attention.value.weight (768, 768) vit.encoder.layer.5.attention.attention.value.bias (768,) vit.encoder.layer.5.attention.output.dense.weight (768, 768) vit.encoder.layer.5.attention.output.dense.bias (768,) vit.encoder.layer.5.intermediate.dense.weight (3072, 768) vit.encoder.layer.5.intermediate.dense.bias (3072,) vit.encoder.layer.5.output.dense.weight (768, 3072) vit.encoder.layer.5.output.dense.bias (768,) vit.encoder.layer.5.layernorm_before.weight (768,) vit.encoder.layer.5.layernorm_before.bias (768,) vit.encoder.layer.5.layernorm_after.weight (768,) vit.encoder.layer.5.layernorm_after.bias (768,) vit.encoder.layer.6.attention.attention.query.weight (768, 768) vit.encoder.layer.6.attention.attention.query.bias (768,) vit.encoder.layer.6.attention.attention.key.weight (768, 768) vit.encoder.layer.6.attention.attention.key.bias (768,) vit.encoder.layer.6.attention.attention.value.weight (768, 768) vit.encoder.layer.6.attention.attention.value.bias (768,) vit.encoder.layer.6.attention.output.dense.weight (768, 768) 
vit.encoder.layer.6.attention.output.dense.bias (768,) vit.encoder.layer.6.intermediate.dense.weight (3072, 768) vit.encoder.layer.6.intermediate.dense.bias (3072,) vit.encoder.layer.6.output.dense.weight (768, 3072) vit.encoder.layer.6.output.dense.bias (768,) vit.encoder.layer.6.layernorm_before.weight (768,) vit.encoder.layer.6.layernorm_before.bias (768,) vit.encoder.layer.6.layernorm_after.weight (768,) vit.encoder.layer.6.layernorm_after.bias (768,) vit.encoder.layer.7.attention.attention.query.weight (768, 768) vit.encoder.layer.7.attention.attention.query.bias (768,) vit.encoder.layer.7.attention.attention.key.weight (768, 768) vit.encoder.layer.7.attention.attention.key.bias (768,) vit.encoder.layer.7.attention.attention.value.weight (768, 768) vit.encoder.layer.7.attention.attention.value.bias (768,) vit.encoder.layer.7.attention.output.dense.weight (768, 768) vit.encoder.layer.7.attention.output.dense.bias (768,) vit.encoder.layer.7.intermediate.dense.weight (3072, 768) vit.encoder.layer.7.intermediate.dense.bias (3072,) vit.encoder.layer.7.output.dense.weight (768, 3072) vit.encoder.layer.7.output.dense.bias (768,) vit.encoder.layer.7.layernorm_before.weight (768,) vit.encoder.layer.7.layernorm_before.bias (768,) vit.encoder.layer.7.layernorm_after.weight (768,) vit.encoder.layer.7.layernorm_after.bias (768,) vit.encoder.layer.8.attention.attention.query.weight (768, 768) vit.encoder.layer.8.attention.attention.query.bias (768,) vit.encoder.layer.8.attention.attention.key.weight (768, 768) vit.encoder.layer.8.attention.attention.key.bias (768,) vit.encoder.layer.8.attention.attention.value.weight (768, 768) vit.encoder.layer.8.attention.attention.value.bias (768,) vit.encoder.layer.8.attention.output.dense.weight (768, 768) vit.encoder.layer.8.attention.output.dense.bias (768,) vit.encoder.layer.8.intermediate.dense.weight (3072, 768) vit.encoder.layer.8.intermediate.dense.bias (3072,) vit.encoder.layer.8.output.dense.weight (768, 3072) 
vit.encoder.layer.8.output.dense.bias (768,) vit.encoder.layer.8.layernorm_before.weight (768,) vit.encoder.layer.8.layernorm_before.bias (768,) vit.encoder.layer.8.layernorm_after.weight (768,) vit.encoder.layer.8.layernorm_after.bias (768,) vit.encoder.layer.9.attention.attention.query.weight (768, 768) vit.encoder.layer.9.attention.attention.query.bias (768,) vit.encoder.layer.9.attention.attention.key.weight (768, 768) vit.encoder.layer.9.attention.attention.key.bias (768,) vit.encoder.layer.9.attention.attention.value.weight (768, 768) vit.encoder.layer.9.attention.attention.value.bias (768,) vit.encoder.layer.9.attention.output.dense.weight (768, 768) vit.encoder.layer.9.attention.output.dense.bias (768,) vit.encoder.layer.9.intermediate.dense.weight (3072, 768) vit.encoder.layer.9.intermediate.dense.bias (3072,) vit.encoder.layer.9.output.dense.weight (768, 3072) vit.encoder.layer.9.output.dense.bias (768,) vit.encoder.layer.9.layernorm_before.weight (768,) vit.encoder.layer.9.layernorm_before.bias (768,) vit.encoder.layer.9.layernorm_after.weight (768,) vit.encoder.layer.9.layernorm_after.bias (768,) vit.encoder.layer.10.attention.attention.query.weight (768, 768) vit.encoder.layer.10.attention.attention.query.bias (768,) vit.encoder.layer.10.attention.attention.key.weight (768, 768) vit.encoder.layer.10.attention.attention.key.bias (768,) vit.encoder.layer.10.attention.attention.value.weight (768, 768) vit.encoder.layer.10.attention.attention.value.bias (768,) vit.encoder.layer.10.attention.output.dense.weight (768, 768) vit.encoder.layer.10.attention.output.dense.bias (768,) vit.encoder.layer.10.intermediate.dense.weight (3072, 768) vit.encoder.layer.10.intermediate.dense.bias (3072,) vit.encoder.layer.10.output.dense.weight (768, 3072) vit.encoder.layer.10.output.dense.bias (768,) vit.encoder.layer.10.layernorm_before.weight (768,) vit.encoder.layer.10.layernorm_before.bias (768,) vit.encoder.layer.10.layernorm_after.weight (768,) 
vit.encoder.layer.10.layernorm_after.bias (768,) vit.encoder.layer.11.attention.attention.query.weight (768, 768) vit.encoder.layer.11.attention.attention.query.bias (768,) vit.encoder.layer.11.attention.attention.key.weight (768, 768) vit.encoder.layer.11.attention.attention.key.bias (768,) vit.encoder.layer.11.attention.attention.value.weight (768, 768) vit.encoder.layer.11.attention.attention.value.bias (768,) vit.encoder.layer.11.attention.output.dense.weight (768, 768) vit.encoder.layer.11.attention.output.dense.bias (768,) vit.encoder.layer.11.intermediate.dense.weight (3072, 768) vit.encoder.layer.11.intermediate.dense.bias (3072,) vit.encoder.layer.11.output.dense.weight (768, 3072) vit.encoder.layer.11.output.dense.bias (768,) vit.encoder.layer.11.layernorm_before.weight (768,) vit.encoder.layer.11.layernorm_before.bias (768,) vit.encoder.layer.11.layernorm_after.weight (768,) vit.encoder.layer.11.layernorm_after.bias (768,) vit.layernorm.weight (768,) vit.layernorm.bias (768,) classifier.weight (10, 768) classifier.bias (10,)
Hugging Face模型訓練代碼 對cifar10訓練,保存模型參數爲numpy格式,方便numpy的模型加載:
"""Fine-tune a pretrained ViT (google/vit-base-patch16-224) on CIFAR-10 and
export the resulting parameters as a .npz file for the NumPy inference script."""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTModel, ViTForImageClassification
from tqdm import tqdm
import numpy as np

# Fix the random seed for reproducibility.
torch.manual_seed(42)

# Hyperparameters.
batch_size = 64
num_epochs = 1
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing: resize CIFAR-10's 32x32 images to the ViT input size and
# scale to [0, 1]. NOTE(review): no mean/std normalization is applied — the
# NumPy inference script must match this (it only divides by 255).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load the CIFAR-10 dataset (downloaded on first run).
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform)

# Data loaders: shuffle only the training split.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Load the pretrained ViT classification model.
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)

# Replace the classification head: ImageNet's 1000 classes -> CIFAR-10's 10.
num_classes = 10
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device)

# Loss and optimizer (all parameters are fine-tuned, not just the new head).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)

# Fine-tune the ViT model.
for epoch in range(num_epochs):
    print("epoch:", epoch)
    vit_model.train()
    train_loss = 0.0
    train_correct = 0
    bar = tqdm(train_loader, total=len(train_loader))
    for images, labels in bar:
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass.
        outputs = vit_model(images)
        loss = criterion(outputs.logits, labels)
        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs.logits, 1)
        train_correct += (predicted == labels).sum().item()
    # Accuracy on the training split.
    train_accuracy = 100.0 * train_correct / len(train_dataset)

    # Evaluate on the test split.
    vit_model.eval()
    test_loss = 0.0
    test_correct = 0
    with torch.no_grad():
        bar = tqdm(test_loader, total=len(test_loader))
        for images, labels in bar:
            images = images.to(device)
            labels = labels.to(device)
            outputs = vit_model(images)
            loss = criterion(outputs.logits, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.logits, 1)
            test_correct += (predicted == labels).sum().item()
    # Accuracy on the test split.
    test_accuracy = 100.0 * test_correct / len(test_dataset)

    # Report per-epoch training loss and both accuracies.
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%')
    torch.save(vit_model.state_dict(), 'vit_model_parameters.pth')

# Print the shape of every ViT weight tensor.
for name, param in vit_model.named_parameters():
    print(name, param.data.shape)

# Save the model parameters in NumPy format for the NumPy inference script.
model_params = {name: param.data.cpu().numpy() for name, param in vit_model.named_parameters()}
np.savez('vit_model_params.npz', **model_params)
Epoch [1/1], Train Loss: 97.7498, Train Accuracy: 96.21%, Test Accuracy: 96.86%