본문 바로가기
A.I./PyTorch

PyTorch 문서) PyTorch Recipes - Warmstarting model using parameters from a different model in PyTorch

by 채소장사 2024. 10. 11.

다른 모델의 매개변수를 사용해, 모델을 초기화하기

  • 원글 : Warmstarting model using parameters from a different model in PyTorch
  • 전이학습(transfer learning)을 하거나 새롭게 복잡한 모델을 학습할 때, 모델을 처음부터 학습하는 대신(from scratch)
    • 기존 모델의 일부 매개변수를 가져오거나,
    • 전체 모델을 구성하는 부분 모델을 불러와 사용하는 것은 흔한 일이다.
  • 이미 훈련된 매개변수를 활용하는 일은 (비록 매개변수의 일부분만을 사용하게 되더라도)
    • 훈련과정을 (완전 처음부터(coldstart)가 아닌) 기존의 훈련과정 위에서(warmstart) 시작할 수 있게 해주며
    • 처음부터 학습할 때보다 훨씬 빠르게 모델이 수렴하는 것을 기대할 수 있다.

Introduction

  • load_state_dict()를 사용할 때 strict 인자를 False로 두게 되면
    • 기존에 존재하던 state_dict의 모델이 새로운 모델에 비해 몇몇 키 값이 없는 상황이나
    • 새로운 모델보다 더 많은 키를 가지는 경우에도 state_dict를 부분적으로 가져오는 것을 가능하게 해준다.
    • strict=False는 대응되지 않는 키를 무시하도록 해준다.

Setup

pip install torch

Steps

  1. 데이터를 로드하기 위해 필요한 라이브러리들을 불러온다.
  2. 신경망 A와 신경망 B를 정의하고 초기화한다.
  3. 모델 A를 저장한다.
  4. 저장된 모델A의 값을 모델 B로 불러온다.

#단계 1. 데이터를 로드하기 위해 필요한 라이브러리를 불러오기

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

 

#단계 2.

  • 예시를 위해 이미지를 훈련시키는 신경망을 생성한다.
class NetA(nn.Module):
    def __init__(self):
        super(NetA, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netA = NetA()

class NetB(nn.Module):
    def __init__(self):
        super(NetB, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netB = NetB()

 

#단계 3. 모델 A를 저장

PATH = "model.pt"

torch.save(netA.state_dict(), PATH)

#단계 4. 저장된 모델 A의 매개변수를 모델 B로 불러오기

  • 한 레이어의 매개변수를 다른 레이어로 불러오려고 할 때, 
    • 몇몇 키가 일치하지 않는다면, 단순히 state_dict의 키의 이름을 바꿔서
    • 저장시킬 모델의 키의 이름과 일치시킬 수 있다.
netB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)

댓글