state_dict의 의미
- 원글 : What is a state_dict in PyTorch?
- torch.nn.Module을 통해 구현된 파이토치 모델의 학습가능한 매개변수들은
- model.parameters()로 접근할 수 있는 모델의 매개변수(=파라미터, parameter) 안에 포함되어 있다.
- 참고) 학습가능한 매개변수(learnable parameters)에는 가중치(weight)와 편향(bias) 값이 있다.
- 이 때, state_dict는 모델의 레이어들과 각각의 매개변수 텐서를 대응시켜주는 파이썬 딕셔너리 객체다.
Introduction
- state_dict는 PyTorch에서 모델을 저장하거나 불러올 때, 사용할 수 있다.
- 파이썬 딕셔너리 객체이기 때문에 쉽게 저장, 갱신, 변경될 수 있다.
- 파이토치의 모델과 옵티마이저에 모듈성(modularity)을 제공한다.
- 파이토치 모델의 state_dict에는
- 학습 가능한 파라미터를 갖는 레이어(예, 컨볼루션 레이어, 선형 레이어 등)와
- 배치정규화(batchnorm)층의 이동 평균(running_mean)과 같은 registered buffer가 저장되고
- 옵티마이저 객체의 state_dict는
- 옵티마이저의 상태(state) 정보와
- 훈련과정에서 설정된 하이퍼 파라미터(초 매개변수, hyperparameter) 정보를 갖고 있다.
- 이 글에서는 간단한 모델을 통해 state_dict가 사용되는 방법을 살펴본다.
Setup
- 시작하기 전에 torch를 설치할 필요가 있다.
pip install torch
Steps
- 데이터를 로드하기 위해 필요한 라이브러리를 불러온다.
- 신경망을 정의하고 초기화한다.
- 옵티마이저를 초기화한다.
- 모델과 옵티마이저의 state_dict에 접근한다.
#단계 1. 데이터를 로드하기 위해, 필요한 라이브러리 불러오기
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#단계 2. 신경망 정의와 초기화
- 여기서는 예시를 위해, 이미지를 학습하는 (컨볼루션) 신경망을 생성한다.
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5, 120)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
#단계 3. 옵티마이저 초기화
- 여기서는 모멘텀(momentum)을 갖는 SGD 옵티마이저를 사용한다.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#단계 4. 모델과 옵티마이저의 state_dict에 접근
- 생성된 모델과 옵티마이저를 통해 각각의 state_dict 특성 안에 저장된 정보를 확인할 수 있다.
print("Model's state_dict::")
for param_tensor in net.state_dict():
print(param_tensor, "\t", net.state_dict()[param_tensor].size())
print()
print("Optimizer's state_dict::")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
- state_dict에 저장된 정보는 모델과 옵티마이저의 저장 및 로드에 관계있고, 나중에 재사용될 수 있다.
- 참고) 위에서 옵티마이저는 훈련되지 않았기 때문에, 상태 정보가 존재하지 않았다.
댓글