본문 바로가기
A.I./PyTorch

PyTorch 문서) PyTorch Recipes - What is a state_dict in PyTorch

by 채소장사 2024. 10. 8.

state_dict의 의미

  • 원글 : What is a state_dict in PyTorch?
  • torch.nn.Module을 통해 구현된 파이토치 모델의 학습가능한 매개변수들
    • model.parameters()로 접근할 수 있는 모델의 매개변수(=파라미터, parameter) 안에 포함되어 있다.
    • 참고) 학습가능한 매개변수(learnable parameters)에는 가중치(weight)편향(bias) 값이 있다.
  • 이 때, state_dict는 모델의 레이어들과 각각의 매개변수 텐서를 대응시켜주는 파이썬 딕셔너리 객체다.

Introduction

  • state_dict는 PyTorch에서 모델을 저장하거나 불러올 때, 사용할 수 있다.
  • 파이썬 딕셔너리 객체이기 때문에 쉽게 저장, 갱신, 변경될 수 있다. 
    • 파이토치의 모델과 옵티마이저에 모듈성(modularity)을 제공한다.
    • 파이토치 모델의 state_dict에는
      • 학습 가능한 파라미터를 갖는 레이어(예, 컨볼루션 레이어, 선형 레이어 등)와
      • 배치정규화(batchnorm)층의 이동 평균(running_mean)과 같은 registered buffer가 저장되고
    • 옵티마이저 객체의 state_dict는
      • 옵티마이저의 상태(state) 정보와
      • 훈련과정에서 설정된 하이퍼 파라미터(초 매개변수, hyperparameter) 정보를 갖고 있다.
  • 이 글에서는 간단한 모델을 통해 state_dict가 사용되는 방법을 살펴본다.

Setup

  • 시작하기 전에 torch를 설치할 필요가 있다.
pip install torch

Steps

  1. 데이터를 로드하기 위해 필요한 라이브러리를 불러온다.
  2. 신경망을 정의하고 초기화한다.
  3. 옵티마이저를 초기화한다.
  4. 모델과 옵티마이저의 state_dict에 접근한다.

#단계 1.  데이터를 로드하기 위해, 필요한 라이브러리 불러오기

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

 

#단계 2. 신경망 정의와 초기화

  • 여기서는 예시를 위해, 이미지를 학습하는 (컨볼루션) 신경망을 생성한다.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5, 120)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
        
net = Net()

 

#단계 3. 옵티마이저 초기화

  • 여기서는 모멘텀(momentum)을 갖는 SGD 옵티마이저를 사용한다.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

 

#단계 4. 모델과 옵티마이저의 state_dict에 접근

  • 생성된 모델과 옵티마이저를 통해 각각의 state_dict 특성 안에 저장된 정보를 확인할 수 있다.
print("Model's state_dict::")
for param_tensor in net.state_dict():
    print(param_tensor, "\t", net.state_dict()[param_tensor].size())
    
print()

print("Optimizer's state_dict::")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

  • state_dict에 저장된 정보는 모델과 옵티마이저의 저장 및 로드에 관계있고, 나중에 재사용될 수 있다.
  • 참고) 위에서 옵티마이저는 훈련되지 않았기 때문에, 상태 정보가 존재하지 않았다.

댓글