추론을 위한 모델의 저장과 로드
- 원글 : Saving and Loading models for inference in PyTorch
- 파이토치에서 모델을 저장하고 불러와 사용하는 방법에는 2가지가 있다.
- 첫번째는 state_dict를 저장 및 로드하는 것이고
- 두번째는 전체 모델을 저장하거나 불러오는 방법이다.
Introduction
- torch.save() 함수를 통해 모델의 state_dict를 저장하는 것은 나중에 모델을 다시 사용하기 위한 가장 유연한 방법이다.
- 이 방식이 추천되는 이유는 훈련된 모델에서 학습된 파라미터만을 저장하기 때문이다.
- 전체 모델의 저장과 로드는 모델의 전체 모듈을 파이썬의 pickle 모듈을 사용해 저장하는 방법이다.
- 참고) pickle은 파이썬 객체를 저장하고 불러오는데 사용되는 모듈로서, 직렬화(serialization) 및 역-직렬화(de-serialization)를 위하여 객체를 바이너리 형태로 변환하는 것과 관련된다.
- 전체 모델의 저장/로드에 관한 방법은
- 직관적인 문법과 최소한의 코드를 사용하는 장점이 있지만,
- 직렬화된 데이터, 즉 저장된 파일이 모델이 저장될 때 사용된 특정 클래스나 디렉토리 구조에 종속된다는 것이 단점이다.
- 이는 pickle이 클래스 자체를 저장하는 것이 아니라, 클래스를 포함한 파일의 경로를 저장하기 때문이며
- 이 때문에 리팩토링을 거친 다른 프로젝트에서 전체 모델이 저장된 pickle파일을 사용할 경우, 동작하지 않을 수 있는 위험이 있다.
Setup
- 시작하기 전에 torch를 설치할 필요가 있다.
pip install torch
Setup
- 데이터를 로드하기 위해 필요한 라이브러리를 불러온다.
- 신경망을 정의하고 초기화한다.
- 옵티마이저를 초기화한다.
- state_dict를 통해 모델을 저장하고 로드한다.
- 전체 모델을 저장하고 로드한다.
#단계 1. 데이터를 로드하기 위해, 필요한 라이브러리 불러오기
import torch
import torch.nn as nn
import torch.optim as optim
#단계 2. 모델의 정의와 초기화
- 여기서는 예시를 위해, 이미지를 학습하는 (컨볼루션) 신경망을 생성한다.
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
#단계 3. 옵티마이저 초기화
- 여기서는 모멘텀(momentum)을 갖는 SGD 옵티마이저를 사용한다.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#단계 4. state_dict를 통해 모델을 저장하고 로드하기
- 다음은 state_dict를 통한 모델의 저장과 로드 방식이다.
PATH = "state_dict_model.pt"
# 저장
torch.save(net.state_dict(), PATH)
# 저장된 모델 불러오기
model = Net()
model.load_state_dict(torch.load(PATH, weights_only=True))
# 추론을 위한 모드 전환
model.eval()
- 파이토치에서는 통상
.pt
나.pth
확장자를 이용하여 모델을 저장한다. - load_state_dict()는 저장된 파일의 경로를 바로 사용하는 것이 아니라, 딕셔너리 객체를 입력으로 받는다.
- 따라서, 저장된 state_dict를 먼저 역-직렬화(de-serialize)하고나서 load_state_dict()에 넘겨줄 수 있다.
- 참고) 파일에서 저장된 객체를 불러오는 torch.load 문서를 참고할 수 있다. 이 문서에 따르면 weights_only 인자는 역직렬화를 수행하는 unpickler가 오직 텐서, 함수 타입, 딕셔너리 등을 로드하도록 제한할 수 있다고 한다. 여기서는 state_dict를 통해 저장한 딕셔러니 형태의 각 레이어 가중치를 불러오도록 weights_only를 사용하였다고 보인다.
- 마지막에 쓰인 model.eval()은 추론(inference)에 모델을 사용하기 위하여 호출되었다.
- 이는 추론 연산 전에 드롭아웃(dropout)이나 배치 정규화(batch normalization) 층이 있는 경우 평가 모드(evaluation mode)로 전환하기 위해서이다.
- 몇몇 레이어들은 훈련과 평가/추론 단계에서 그 동작이 다르다.
- model.eval()을 선언하지 않으면, 추론의 결과가 일관적이지 않을 수 있다.
- 이는 추론 연산 전에 드롭아웃(dropout)이나 배치 정규화(batch normalization) 층이 있는 경우 평가 모드(evaluation mode)로 전환하기 위해서이다.
#단계 5. 전체 모델을 저장하고 로드하기
PATH = "entire_model.pt"
# 저장
torch.save(net, PATH)
# 불러오기
model = torch.load(PATH)
# 추론을 위한 모드 전환
model.eval()
댓글