Hands-on tour to deep learning with PyTorch
A first example on Colab: dogs and cats with VGG
Settings
Kaggle link:
kaggle
Data: Dogs vs. Cats
Database structure
Training set: dogscats/train/
Validation set: dogscats/valid/
Each of these directories contains two subdirectories, dogs and cats.
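As a quick sanity check on the layout described above, the following is a minimal sketch that lists the class folders of one split and counts the images in each. The helper name summarize_split and the list of image extensions are assumptions made for illustration; the sorted order mirrors how torchvision's ImageFolder assigns class indices.

```python
import os

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")  # assumed image types


def summarize_split(split_dir):
    """Return (class_to_idx, counts) for one split such as dogscats/train/."""
    # Class folders in sorted order, mirroring ImageFolder's index assignment
    classes = sorted(
        entry.name for entry in os.scandir(split_dir) if entry.is_dir()
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    # Number of image files found directly inside each class folder
    counts = {
        name: sum(
            1
            for f in os.listdir(os.path.join(split_dir, name))
            if f.lower().endswith(IMAGE_EXTENSIONS)
        )
        for name in classes
    }
    return class_to_idx, counts
```

Calling summarize_split("dogscats/train") on the layout above should give the mapping cats to 0 and dogs to 1, together with the number of images per class.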
Coding Summary
data processing
In PyTorch, much of the data processing relies on torchvision; here we mainly use torchvision.transforms (in particular transforms.Normalize) and torchvision.datasets.
- code 1
import os
from torchvision import datasets

# vgg_format is defined in code 2
dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format)
         for x in ['train', 'valid']}
- code 2
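from torchvision import transforms

# ImageNet channel means and standard deviations
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# Crop the central 224x224 patch, convert to a tensor in [0, 1], then normalize
vgg_format = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])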
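To make the arithmetic behind the ToTensor and Normalize steps concrete, here is a small pure-Python sketch for a single RGB pixel. The helper names to_unit_range and normalize_pixel are hypothetical; the real transforms operate on whole image tensors, but the per-channel formula (value - mean) / std is the same.

```python
MEAN = [0.485, 0.456, 0.406]  # per-channel means used in the pipeline above
STD = [0.229, 0.224, 0.225]   # per-channel standard deviations


def to_unit_range(rgb_uint8):
    """Scale 0..255 channel values to 0..1, as ToTensor does for 8-bit images."""
    return [v / 255.0 for v in rgb_uint8]


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply (value - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]
```

For instance, a pure white pixel (255, 255, 255) becomes 1.0 in each channel after scaling and about 2.25 in the red channel after normalization, while a pixel equal to the channel means maps to zero.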
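With ImageFolder, the label of each image is given by the name of the subdirectory it sits in (cats or dogs).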
load dataset
PyTorch uses torch.utils.data.DataLoader to load the data.
- code
load_train = torch.utils.data.DataLoader(dsets['train'], batch_size=64, shuffle=True, num_workers=6)
load_test = torch.utils.data.DataLoader(dsets['valid'], batch_size=5, shuffle=True, num_workers=6)
`DataLoader` returns an iterable object; use `next(iter(load_xx))` to get one batch of data and labels.
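The batch-fetching pattern can be tried without the image data. The sketch below assumes a stand-in TensorDataset of random tensors (not the real dogs/cats set) and a hypothetical loader name toy_loader; it shows both the one-off next(iter(...)) call and a plain loop over all batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data (an assumption for illustration): 10 random "images" of
# shape 3x8x8 with alternating labels 0/1 instead of the real dataset.
images = torch.randn(10, 3, 8, 8)
labels = torch.tensor([0, 1] * 5)
toy_loader = DataLoader(TensorDataset(images, labels), batch_size=4, shuffle=False)

# Fetch a single batch: iter() builds an iterator, next() pulls the first batch
inputs, classes = next(iter(toy_loader))

# Or loop over every batch; the last one is smaller (10 = 4 + 4 + 2)
batch_sizes = [batch_inputs.size(0) for batch_inputs, _ in toy_loader]
```

Each call to iter() starts a fresh pass over the dataset, so next(iter(...)) is convenient for inspecting one batch, while training code normally uses the for loop.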
Creating VGG Model
torchvision.models.vgg16(pretrained=True/False) (recent torchvision releases deprecate pretrained in favor of the weights argument)
Modifying the last layer and setting requires_grad to False for all other layers
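Modify the last layer of the vgg16 classifier: first turn off gradient computation for all existing parameters, then replace the final layer with one of the required output size and append a LogSoftmax layer.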
- code
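# Freeze all pretrained parameters
for param in model_vgg.parameters():
    param.requires_grad = False
# New 2-class output layer (trainable by default), followed by LogSoftmax
model_vgg.classifier._modules['6'] = nn.Linear(4096, 2)
model_vgg.classifier._modules['7'] = torch.nn.LogSoftmax(dim=1)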
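The effect of freezing first and replacing afterwards can be checked on a small model. The sketch below uses a hypothetical miniature toy_classifier in place of model_vgg.classifier, so nothing has to be downloaded; it relies on the fact that a newly constructed layer has requires_grad=True by default.

```python
import torch.nn as nn

# Hypothetical miniature classifier standing in for model_vgg.classifier
toy_classifier = nn.Sequential(
    nn.Linear(16, 8),
    nn.ReLU(),
    nn.Linear(8, 1000),
)

# Freeze everything that exists so far
for param in toy_classifier.parameters():
    param.requires_grad = False

# Replace the last layer; the new layer's parameters are trainable by default
toy_classifier[2] = nn.Linear(8, 2)

trainable = [n for n, p in toy_classifier.named_parameters() if p.requires_grad]
frozen = [n for n, p in toy_classifier.named_parameters() if not p.requires_grad]
```

Only the parameters of the replaced layer end up in trainable, which is why the order (freeze, then replace) matters: freezing after the replacement would also freeze the new layer.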
Training the fully connected module
Loss function and optimizer: torch.nn.NLLLoss and torch.optim
criterion = nn.NLLLoss()
lr = 0.001
optimizer_vgg = torch.optim.SGD(model_vgg.classifier[6].parameters(), lr = lr)
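Since the model ends in LogSoftmax and the criterion is NLLLoss, the loss for one sample is simply the negative log-probability of its true class. The following pure-Python sketch reproduces that arithmetic; the function names log_softmax and nll_loss are illustrative helpers, not the PyTorch implementations.

```python
import math


def log_softmax(scores):
    """Log-probabilities from raw scores, shifted by the max for stability."""
    shift = max(scores)
    log_sum = shift + math.log(sum(math.exp(s - shift) for s in scores))
    return [s - log_sum for s in scores]


def nll_loss(log_probs_batch, targets):
    """Mean of -log_prob[target] over the batch (the default 'mean' reduction)."""
    picked = [-lp[t] for lp, t in zip(log_probs_batch, targets)]
    return sum(picked) / len(picked)
```

With two classes and equal scores, each class gets probability 0.5 and the loss is ln 2, roughly 0.693; this is a useful reference value for the loss of an untrained two-class head.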
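During training, first move the data to the device (cuda), then feed it to the model and compute the loss; the parameter update happens in loss.backward() and optimizer.step().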
- code
# Uses the globals criterion and device
def train_model(model, dataloader, size, epochs=1, optimizer=None):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        running_corrects = 0
        for inputs, classes in dataloader:
            inputs = inputs.to(device)
            classes = classes.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, classes)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _, preds = torch.max(outputs, 1)
            # statistics: weight the batch-mean loss by the batch size
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == classes).item()
        epoch_loss = running_loss / size
        epoch_acc = running_corrects / size
        print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
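To see what a single pass through the body of train_model does to frozen and trainable weights, here is a minimal sketch on a tiny made-up model. The layer sizes, the random inputs, and the names features and head are all assumptions for illustration; only the order of the calls follows the training loop above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny model: a frozen "feature" layer plus a trainable head
features = nn.Linear(4, 3)
for p in features.parameters():
    p.requires_grad = False
head = nn.Sequential(nn.Linear(3, 2), nn.LogSoftmax(dim=1))
model = nn.Sequential(features, head)

criterion = nn.NLLLoss()
# As in the notes, the optimizer only receives the trainable parameters
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)

inputs = torch.randn(8, 4)
classes = torch.randint(0, 2, (8,))

frozen_before = features.weight.clone()
head_before = head[0].weight.clone()

# One training step, in the same order as in train_model
outputs = model(inputs)
loss = criterion(outputs, classes)
optimizer.zero_grad()
loss.backward()
optimizer.step()

frozen_unchanged = torch.equal(frozen_before, features.weight)
head_changed = not torch.equal(head_before, head[0].weight)
```

After the step, the frozen layer keeps its weights and has no gradient, while the head's weights have moved; this is the same behavior expected from the pretrained VGG layers and the new final layer.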