# Fine-tune ViT on the CIFAR-10 dataset: freeze all backbone layers and train only the fully connected classification head.
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import resnet101
from tqdm import tqdm

# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load CIFAR-10.  NOTE: no ToTensor() transform here on purpose — the ViT
# image processor expects PIL images and performs its own resize, rescale
# (1/255) and normalize.  Applying ToTensor() first would make the processor
# rescale a second time, collapsing pixel values into [0, 1/255].
train_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=True, download=True)
test_dataset = CIFAR10(root="/data/xinyuuliu/datas", train=False)


def collate_fn(batch):
    """Split a batch of dataset items into (images, labels).

    :param batch: list of ``(PIL.Image, int)`` pairs from CIFAR10
    :return: tuple ``(images, labels)`` where ``images`` is a tuple of PIL
        images (batched later by the ViT processor) and ``labels`` is a
        ``torch.long`` tensor ready for CrossEntropyLoss.
    """
    images, labels = zip(*batch)
    return images, torch.tensor(labels, dtype=torch.long)


# Data loaders; the custom collate keeps images as PIL for the processor.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# Pretrained ViT-Base/16 and its matching preprocessing pipeline.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Replace the 1000-class ImageNet head with a fresh 10-class CIFAR-10 head.
model.classifier = nn.Linear(768, 10)
print(model.classifier)

# Freeze the backbone by name (robust, unlike slicing the parameter list):
# only the new classifier's weight and bias remain trainable.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("classifier")

model.to(device)

# Loss and optimizer; hand the optimizer only the trainable parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=0.001
)


def train(model, dataloader, optimizer, criterion):
    """Run one training epoch.

    :param model: ViT classification model (already on ``device``)
    :param dataloader: yields ``(images, labels)`` via ``collate_fn``
    :param optimizer: optimizer over the trainable (classifier) parameters
    :param criterion: classification loss (CrossEntropyLoss)
    :return: mean loss per sample over the epoch
    """
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(dataloader, desc="Training"):
        # The processor resizes to 224x224, rescales and normalizes.
        inputs = processor(images=images, return_tensors="pt")
        inputs["pixel_values"] = inputs["pixel_values"].to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(**inputs).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is exact
        # even when the last batch is smaller.
        running_loss += loss.item() * inputs["pixel_values"].size(0)
    return running_loss / len(dataloader.dataset)


def evaluate(model, dataloader):
    """Compute top-1 accuracy (percent) of ``model`` on ``dataloader``."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            inputs = processor(images=images, return_tensors="pt")
            inputs["pixel_values"] = inputs["pixel_values"].to(device)
            labels = labels.to(device)

            logits = model(**inputs).logits
            predicted = logits.argmax(-1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total * 100


# Train the classifier head and report test accuracy after every epoch.
num_epochs = 10
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    train_loss = train(model, train_loader, optimizer, criterion)
    print(f"Training Loss: {train_loss:.4f}")
    test_acc = evaluate(model, test_loader)
    print(f"Test Accuracy: {test_acc:.2f}%")