模型和权重
在机器学习和深度学习中,保存训练结果通常涉及以下两大类:
- 检查点 (Checkpoints): 检查点是在训练过程中定期保存的模型状态快照。它们包含了模型的权重和参数、优化器状态、学习率调度器状态等,以便可以从中断的地方恢复训练或者用于评估和推理。检查点对于长时间运行的训练任务尤其重要,因为它们可以防止因系统故障或其他原因导致的训练进度丢失。
- 最终模型文件: 在训练完成后,通常会保存一个最终的模型文件,这个文件包含了训练完成后的模型权重和参数。这个最终模型通常用于部署、推理或进一步的研究。在某些情况下,如果训练过程中的某个检查点在验证集上表现得最好,那么这个检查点也可以被视为最终模型。
在 PyTorch 中,模型和权重文件通常以 .pth
或 .pt
格式保存。在 PyTorch 中,管理和使用模型及其权重文件是一个关键的技能,尤其是在进行训练后保存、加载模型、或者迁移学习时。
1. 定义模型
首先,你需要定义一个 PyTorch 模型。这通常包括继承 torch.nn.Module
类并实现 forward
方法。
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 2)
def forward(self, x):
return self.fc(x)
2. 训练模型
在训练模型时,你通常会进行正向传播、计算损失、反向传播和优化。训练完成后,你需要保存模型的权重。
import torch.optim as optim
# 初始化模型和优化器
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 假设你已经训练了模型
# ...
# 保存模型的权重
torch.save(model.state_dict(), 'model_weights.pth')
3. 保存模型权重
在 PyTorch 中,最常用的方法是保存模型的 权重 而非整个模型对象。保存权重是为了确保在不同的环境下,模型的结构和训练过程一致。
torch.save(model.state_dict(), 'model_weights.pth')
state_dict()
:这是一个包含模型所有参数(如权重、偏置等)的字典。- 保存文件:一般情况下保存
.pth
或.pt
后缀的文件。
4. 加载模型权重
当你想恢复训练好的模型时,可以使用 load_state_dict
来加载模型的权重。
4.1 加载权重到相同结构的模型
model = SimpleModel()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval() # 将模型设置为评估模式
torch.load()
:加载保存的权重文件。model.load_state_dict()
:将加载的权重应用到模型。model.eval()
:切换模型到评估模式,禁用一些只在训练时启用的操作(如 dropout)。
4.2 加载权重到不同结构的模型
如果你修改了模型的结构(例如,修改了层数或神经网络层),你可以选择性地加载某些层的权重。这个时候,你可以使用 strict=False
参数来忽略不匹配的层。
model = SimpleModel()
checkpoint = torch.load('model_weights.pth')
model.load_state_dict(checkpoint, strict=False)
5. 保存和加载整个模型
虽然更常用的是只保存模型的权重,但也可以选择保存整个模型,包括模型的结构和权重。这样做的缺点是对模型的代码依赖性较强,并且不推荐用于生产环境。你可以使用 torch.save()
和 torch.load()
来保存和加载整个模型。
5.1 保存整个模型
torch.save(model, 'model.pth')
5.2 加载整个模型
model = torch.load('model.pth')
model.eval() # 设置为评估模式
6. 使用预训练模型
PyTorch 提供了一些常用的预训练模型(例如 ResNet、VGG 等),可以通过 torchvision.models
模块轻松加载这些模型并进行迁移学习。
import torchvision.models as models
# 加载一个预训练的 ResNet-50 模型
model = models.resnet50(pretrained=True)
model.eval() # 切换到评估模式
7. 保存和加载优化器的状态
除了模型权重,有时我们还需要保存和恢复优化器的状态。这样可以从上次的训练进度继续训练。
7.1 保存优化器状态
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, 'checkpoint.pth')
7.2 加载优化器状态
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval() # 如果在评估模式下
8. 模型转换
如果你需要将训练好的模型导出到其他框架(如 TensorFlow 或 ONNX),PyTorch 也提供了相应的转换工具。
8.1 转换为 ONNX 格式
import torch.onnx
# 假设你有一个输入样本 input_tensor
torch.onnx.export(model, input_tensor, "model.onnx")
这样,你可以将模型保存为 .onnx
格式,并在其他框架中使用。
小结
- 保存模型权重:使用
torch.save(model.state_dict(), 'model_weights.pth')
。 - 加载模型权重:使用
model.load_state_dict(torch.load('model_weights.pth'))
。 - 保存整个模型:
torch.save(model, 'model.pth')
。 - 加载整个模型:
model = torch.load('model.pth')
。 - 保存和加载优化器状态:通过
optimizer.state_dict()
保存和恢复优化器状态。
这应该涵盖了大部分你在管理和使用 PyTorch 模型和权重文件时可能遇到的常见操作。