여러 개의 모델을 하나의 파일로 저장해서 활용하기
- 원글 : Saving and Loading multiple models in one file using PyTorch
- 다수의 모델을 하나의 파일로 저장하거나 불러오는 일은 이전에 학습된 모델을 재사용할 때 유용하다.
Introduction
- 적대적 생성 신경망(GAN), 시퀀스-투-시퀀스(Sequence-to-Sequence) 모델이나 앙상블모델(ensemble of models)처럼 여러 개의 torch.nn.Modules로 구성된 모델을 저장할 때는,
- 구성 모델 각각의 state_dict와 대응되는 옵티마이저의 state_dict 들을 딕셔너리에 모두 저장해야 한다.
- 훈련을 재개하는데 도움을 줄 수 있는 다른 정보가 있다면, 단순히 딕셔너리에 추가하여서 손쉽게 활용할 수 있다.
- 여러 모델을 불러올 때는,
- 각 모델과 옴티마이저를 초기화한 뒤에
- torch.load()를 사용해서 각각에 필요한 정보를 딕셔너리에서 로드할 수 있다.
Setup
pip install torch
Steps
- 데이터를 불러오기 위해 필요한 라이브러리들을 로드한다.
- 신경망을 정의하고 초기화한다.
- 옵티마이저를 초기화한다.
- 여러 모델을 저장한다.
- 저장된 여러모델을 불러온다.
#단계 1. 데이터를 불러오기 위해 필요한 라이브러리를 로드
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#단계 2. 신경망을 정의하고 초기화
- 여기서는 설명을 위해 이미지를 훈련하는 신경망을 정의하고, 저장할 두 개의 모델 객체를 생성한다.
class Net(nn.Modules):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
netA = Net()
netB = Net()
#단계 3. 옵티마이저 초기화
- 두 개의 모델에 각각 쓰일 SGD 옵티마이저를 사용한다.
optimizerA = optim.SGD(netA.parameters(), lr=0.001, momentum=0.9)
optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.9)
#단계 4. 여러 모델을 함께 저장
- 관련된 정보를 함께 모아서, 딕셔너리로 저장한다.
PATH = "model.pt"
torch.save({
'modelA_state_dict': netA.state_dict(),
'modelB_state_dict': netB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
}, PATH)
#단계 5. 여러 모델을 불러오기
- 모델 및 옵티마이저를 먼저 생성 후 초기화하는 단계가 필요하다.
modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH, weights_only=True)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
# 추론시
modelA.eval()
modelB.eval()
# 훈련 재개시
modelA.train()
modelB.train()
댓글