본문 바로가기
A.I./PyTorch

PyTorch 문서) PyTorch Recipes - Saving and loading multiple models in one file using PyTorch

by 채소장사 2024. 10. 10.

여러 개의 모델을 하나의 파일로 저장해서 활용하기

Introduction

  • 적대적 생성 신경망(GAN), 시퀀스-투-시퀀스(Sequence-to-Sequence) 모델이나 앙상블모델(ensemble of models)처럼 여러 개의 torch.nn.Modules로 구성된 모델을 저장할 때는,
    • 구성 모델 각각의 state_dict와 대응되는 옵티마이저의 state_dict 들을 딕셔너리에 모두 저장해야 한다.
  • 훈련을 재개하는데 도움을 줄 수 있는 다른 정보가 있다면, 단순히 딕셔너리에 추가하여서 손쉽게 활용할 수 있다.
  • 여러 모델을 불러올 때는,
    • 각 모델과 옴티마이저를 초기화한 뒤에
    • torch.load()를 사용해서 각각에 필요한 정보를 딕셔너리에서 로드할 수 있다.

Setup

pip install torch

 

Steps

  1. 데이터를 불러오기 위해 필요한 라이브러리들을 로드한다.
  2. 신경망을 정의하고 초기화한다.
  3. 옵티마이저를 초기화한다.
  4. 여러 모델을 저장한다.
  5. 저장된 여러모델을 불러온다.

#단계 1. 데이터를 불러오기 위해 필요한 라이브러리를 로드

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

 

#단계 2. 신경망을 정의하고 초기화

  • 여기서는 설명을 위해 이미지를 훈련하는 신경망을 정의하고, 저장할 두 개의 모델 객체를 생성한다.
class Net(nn.Modules):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
        
netA = Net()
netB = Net()

 

#단계 3. 옵티마이저 초기화

  • 두 개의 모델에 각각 쓰일 SGD 옵티마이저를 사용한다.
optimizerA = optim.SGD(netA.parameters(), lr=0.001, momentum=0.9)
optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.9)

 

#단계 4. 여러 모델을 함께 저장

  • 관련된 정보를 함께 모아서, 딕셔너리로 저장한다.
PATH = "model.pt"

torch.save({
        'modelA_state_dict': netA.state_dict(),
        'modelB_state_dict': netB.state_dict(),
        'optimizerA_state_dict': optimizerA.state_dict(),
        'optimizerB_state_dict': optimizerB.state_dict(),
        }, PATH)

 

#단계 5. 여러 모델을 불러오기

  • 모델 및 옵티마이저를 먼저 생성 후 초기화하는 단계가 필요하다.
modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH, weights_only=True)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

# 추론시
modelA.eval()
modelB.eval()
# 훈련 재개시
modelA.train()
modelB.train()

댓글