When you first dip into large-model papers, does a screen full of "CheckPoint" make your head spin? Don't panic! It works just like an autosave in a video game: it can save your "life" at critical moments and lets the model "get smarter with practice". Today, in the plainest terms possible, let's break down how a CheckPoint gives model training a "save file".
I. Concept Overview

A checkpoint is a snapshot of the training state taken at a chosen moment: typically the model's weights (its state_dict), the optimizer's state, the current epoch, and the loss. If training crashes or gets interrupted, you reload the snapshot and pick up where you left off instead of starting from scratch, exactly like reloading a game save. Because each checkpoint is a separate file, you can also keep several versions of a model and roll back to the best one.
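To make that concrete, here is a minimal sketch of what a state_dict actually contains. The toy network below is invented purely for illustration and is not from the original example:

```python
import torch
import torch.nn as nn

# A toy network, used only to show what a state_dict looks like.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# state_dict maps each parameter name to its weight tensor.
for name, tensor in net.state_dict().items():
    print(name, tuple(tensor.shape))
# prints: 0.weight (8, 4), 0.bias (8,), 2.weight (2, 8), 2.bias (2,)
```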
II. Technical Implementation

How does PyTorch implement checkpoints? PyTorch saves and loads state dictionaries (state_dict) manually with torch.save and torch.load. A common convention is to encode training metadata in the checkpoint filename (e.g. checkpoint_epoch1_loss0.5.pth). The example below saves a checkpoint every 10 epochs, or whenever the average loss drops to 0.5 or below:
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Minimal placeholder network; the original post assumes a Model class
# is defined elsewhere.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# Define the model and optimizer
model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Dummy data so the loop is runnable; replace with your real dataloader.
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=16)

# Training loop
best_loss = float('inf')  # track the best (lowest) average loss so far
for epoch in range(100):
    model.train()
    total_loss = 0.0
    # Simulated training steps
    for batch in dataloader:
        inputs, labels = batch
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    if avg_loss < best_loss:
        best_loss = avg_loss  # remember the best average loss

    # Decide whether to save (by epoch count or by loss value)
    save_flag = False
    if (epoch + 1) % 10 == 0:  # save every 10 epochs
        save_flag = True
    elif avg_loss <= 0.5:  # save when loss reaches 0.5 or below
        save_flag = True

    if save_flag:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }
        save_path = f"checkpoint_epoch{epoch+1}_loss{avg_loss:.2f}.pth"
        torch.save(checkpoint, save_path)
        print(f"Checkpoint saved: {save_path}")
```