PyTorch入门：手写数字识别

#导入需要的包 import numpy as np import torch from torch import nn from PIL import Image import torchvision import matplotlib.pyplot as plt import os from torchvision import datasets, transforms,utils Step1:准备数据。 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5],std=[0.5])]) train_data = datasets.MNIST(root = "./data/", transform=transform, train = True, download = True) test_data = datasets.MNIST(root="./data/", transform = transform, train = False) print(len(train_data)) print(len(test_data)) 60000 10000 train_data 的个数：60000个训练样本；test_data 的个数：10000个测试样本。 train_loader = torch.utils.data.DataLoader(train_data,batch_size=128, shuffle=True,num_workers=2) test_loader = torch.utils.data.DataLoader(test_data,batch_size=128, shuffle=True,num_workers=2) print(len(train_loader)) print(len(test_loader)) 469 79 len(DataLoader) 返回的是 batch 的个数：60000/128 向上取整为 469，10000/128 向上取整为 79。加载到 DataLoader 中后，DataLoader 每次迭代取出的一个元素就是一个 batch 的数据...

April 9, 2021 · wuyangzz