模型和权重文件管理之PyTorch

模型和权重

在机器学习和深度学习中,保存训练结果通常涉及以下两大类:

  1. 检查点 (Checkpoints): 检查点是在训练过程中定期保存的模型状态快照。它们包含了模型的权重和参数、优化器状态、学习率调度器状态等,以便可以从中断的地方恢复训练或者用于评估和推理。检查点对于长时间运行的训练任务尤其重要,因为它们可以防止因系统故障或其他原因导致的训练进度丢失。
  2. 最终模型文件: 在训练完成后,通常会保存一个最终的模型文件,这个文件包含了训练完成后的模型权重和参数。这个最终模型通常用于部署、推理或进一步的研究。在某些情况下,如果训练过程中的某个检查点在验证集上表现得最好,那么这个检查点也可以被视为最终模型。

在 PyTorch 中,模型和权重文件通常以 .pth.pt 格式保存。在 PyTorch 中,管理和使用模型及其权重文件是一个关键的技能,尤其是在进行训练后保存、加载模型、或者迁移学习时。

1. 定义模型

首先,你需要定义一个 PyTorch 模型。这通常包括继承 torch.nn.Module 类并实现 forward 方法。

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
   def __init__(self):
       super(SimpleModel, self).__init__()
       self.fc = nn.Linear(10, 2)
       
   def forward(self, x):
       return self.fc(x)

2. 训练模型

在训练模型时,你通常会进行正向传播、计算损失、反向传播和优化。训练完成后,你需要保存模型的权重。

import torch.optim as optim

# 初始化模型和优化器
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 假设你已经训练了模型
# ...

# 保存模型的权重
torch.save(model.state_dict(), 'model_weights.pth')

3. 保存模型权重

在 PyTorch 中,最常用的方法是保存模型的 权重 而非整个模型对象。保存权重是为了确保在不同的环境下,模型的结构和训练过程一致。

torch.save(model.state_dict(), 'model_weights.pth')
  • state_dict():这是一个包含模型所有参数(如权重、偏置等)的字典。
  • 保存文件:一般情况下保存 .pth.pt 后缀的文件。

4. 加载模型权重

当你想恢复训练好的模型时,可以使用 load_state_dict 来加载模型的权重。

4.1 加载权重到相同结构的模型

model = SimpleModel()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()  # 将模型设置为评估模式
  • torch.load():加载保存的权重文件。
  • model.load_state_dict():将加载的权重应用到模型。
  • model.eval():切换模型到评估模式,禁用一些只在训练时启用的操作(如 dropout)。

4.2 加载权重到不同结构的模型

如果你修改了模型的结构(例如,修改了层数或神经网络层),你可以选择性地加载某些层的权重。这个时候,你可以使用 strict=False 参数来忽略不匹配的层。

model = SimpleModel()
checkpoint = torch.load('model_weights.pth')
model.load_state_dict(checkpoint, strict=False)

5. 保存和加载整个模型

虽然更常用的是只保存模型的权重,但也可以选择保存整个模型,包括模型的结构和权重。这样做的缺点是对模型的代码依赖性较强,并且不推荐用于生产环境。你可以使用 torch.save()torch.load() 来保存和加载整个模型。

5.1 保存整个模型

torch.save(model, 'model.pth')

5.2 加载整个模型

model = torch.load('model.pth')
model.eval()  # 设置为评估模式

6. 使用预训练模型

PyTorch 提供了一些常用的预训练模型(例如 ResNet、VGG 等),可以通过 torchvision.models 模块轻松加载这些模型并进行迁移学习。

import torchvision.models as models

# 加载一个预训练的 ResNet-50 模型
model = models.resnet50(pretrained=True)
model.eval()  # 切换到评估模式

7. 保存和加载优化器的状态

除了模型权重,有时我们还需要保存和恢复优化器的状态。这样可以从上次的训练进度继续训练。

7.1 保存优化器状态

torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, 'checkpoint.pth')

7.2 加载优化器状态

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval() # 如果在评估模式下

8. 模型转换

如果你需要将训练好的模型导出到其他框架(如 TensorFlow 或 ONNX),PyTorch 也提供了相应的转换工具。

8.1 转换为 ONNX 格式

import torch.onnx

# 假设你有一个输入样本 input_tensor
torch.onnx.export(model, input_tensor, "model.onnx")

这样,你可以将模型保存为 .onnx 格式,并在其他框架中使用。

小结

  • 保存模型权重:使用 torch.save(model.state_dict(), 'model_weights.pth')
  • 加载模型权重:使用 model.load_state_dict(torch.load('model_weights.pth'))
  • 保存整个模型:torch.save(model, 'model.pth')
  • 加载整个模型:model = torch.load('model.pth')
  • 保存和加载优化器状态:通过 optimizer.state_dict() 保存和恢复优化器状态。

这应该涵盖了大部分你在管理和使用 PyTorch 模型和权重文件时可能遇到的常见操作。

发表评论