ViT 訓練 CIFAR10 數據集:凍結所有層,只保留最後的全連接(分類)層可訓練
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import resnet101
from tqdm import tqdm

# Number of target classes in CIFAR10.
NUM_CLASSES = 10

# Select the compute device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the CIFAR10 dataset.
# No transform is applied on purpose: the dataset then yields raw PIL images,
# and ViTImageProcessor does the resizing, rescaling and normalisation itself.
# Feeding it ToTensor() output (floats already in [0, 1]) makes the processor
# rescale a second time, which squashes the pixel values towards zero.
train_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=True, download=True)
test_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=False)


def collate_fn(batch):
    """Collate dataset items into a batch the image processor can consume.

    :param batch: list of ``(image, label)`` pairs as returned by the
        dataset's ``__getitem__``.
    :return: ``(images, labels)`` where ``images`` is a list holding the
        images unchanged and ``labels`` is a 1-D int64 tensor of class ids.
    """
    images, labels = zip(*batch)
    # Class indices must be int64 for CrossEntropyLoss and for comparison
    # with argmax predictions, so build them as such from the start.
    return list(images), torch.tensor(labels, dtype=torch.long)


# Create the data loaders.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Freeze every pretrained weight first, then attach a fresh classification
# head. The new layer's parameters require gradients by default, so it is the
# only trainable part of the model; this avoids relying on parameter ordering.
for param in model.parameters():
    param.requires_grad = False
model.classifier = nn.Linear(model.config.hidden_size, NUM_CLASSES)
print(model.classifier)

model.to(device)

# Define the loss function and the optimizer. Only the classifier head is
# trainable, so only its parameters are handed to the optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)


def train(model, dataloader, optimizer, criterion):
    """Run one training epoch and return the mean per-sample loss.

    :param model: image classification model whose output exposes ``.logits``.
    :param dataloader: yields ``(images, labels)`` batches; images are passed
        to the module-level ``processor``.
    :param optimizer: optimizer stepped once per batch.
    :param criterion: loss called as ``criterion(logits, labels)``.
    :return: summed batch losses (weighted by batch size) divided by
        ``len(dataloader.dataset)``; ``0.0`` when the dataset is empty.
    """
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(dataloader, desc="Training"):
        pixel_values = processor(images=images, return_tensors="pt")["pixel_values"].to(device)
        # .long() keeps this working for loaders that still yield float labels.
        labels = labels.to(device).long()

        optimizer.zero_grad()
        logits = model(pixel_values=pixel_values).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # criterion averages over the batch; weight by batch size so the
        # epoch figure is a true per-sample mean even with a short last batch.
        running_loss += loss.item() * pixel_values.size(0)

    num_samples = len(dataloader.dataset)
    if num_samples == 0:
        return 0.0
    return running_loss / num_samples


def evaluate(model, dataloader):
    """Compute classification accuracy over a data loader.

    :param model: image classification model whose output exposes ``.logits``.
    :param dataloader: yields ``(images, labels)`` batches.
    :return: accuracy as a percentage in ``[0, 100]``; ``0.0`` when the
        loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            pixel_values = processor(images=images, return_tensors="pt")["pixel_values"].to(device)
            labels = labels.to(device).long()

            logits = model(pixel_values=pixel_values).logits
            predicted = logits.argmax(-1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0
    return correct / total * 100


# Train and evaluate.
num_epochs = 10
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train_loss = train(model, train_loader, optimizer, criterion)
    print(f"Training Loss: {train_loss:.4f}")
    test_acc = evaluate(model, test_loader)
    print(f"Test Accuracy: {test_acc:.2f}%")