다른 모델의 매개변수를 사용해, 모델을 초기화하기
- 원글 : Warmstarting model using parameters from a different model in PyTorch
- 전이학습(transfer learning)을 하거나 새롭게 복잡한 모델을 학습할 때, 모델을 처음부터 학습하는 대신(from scratch)
- 기존 모델의 일부 매개변수를 가져오거나,
- 전체 모델을 구성하는 부분 모델을 불러와 사용하는 것은 흔한 일이다.
- 이미 훈련된 매개변수를 활용하는 일은 (비록 매개변수의 일부분만을 사용하게 되더라도)
- 훈련과정을 (완전 처음부터(coldstart)가 아닌) 기존의 훈련과정 위에서(warmstart) 시작할 수 있게 해주며
- 처음부터 학습할 때보다 훨씬 빠르게 모델이 수렴하는 것을 기대할 수 있다.
Introduction
- load_state_dict()를 사용할 때 strict 인자를 False로 두게 되면
- 기존에 존재하던 state_dict의 모델이 새로운 모델에 비해 몇몇 키 값이 없는 상황이나
- 새로운 모델보다 더 많은 키를 가지는 경우에도 state_dict를 부분적으로 가져오는 것을 가능하게 해준다.
- strict=False는 대응되지 않는 키를 무시하도록 해준다.
Setup
pip install torch
Steps
- 데이터를 로드하기 위해 필요한 라이브러리들을 불러온다.
- 신경망 A와 신경망 B를 정의하고 초기화한다.
- 모델 A를 저장한다.
- 저장된 모델A의 값을 모델 B로 불러온다.
#단계 1. 데이터를 로드하기 위해 필요한 라이브러리를 불러오기
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#단계 2.
- 예시를 위해 이미지를 훈련시키는 신경망을 생성한다.
class NetA(nn.Module):
def __init__(self):
super(NetA, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
netA = NetA()
class NetB(nn.Module):
def __init__(self):
super(NetB, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
netB = NetB()
#단계 3. 모델 A를 저장
PATH = "model.pt"
torch.save(netA.state_dict(), PATH)
#단계 4. 저장된 모델 A의 매개변수를 모델 B로 불러오기
- 한 레이어의 매개변수를 다른 레이어로 불러오려고 할 때,
- 몇몇 키가 일치하지 않는다면, 단순히 state_dict의 키의 이름을 바꿔서
- 저장시킬 모델의 키의 이름과 일치시킬 수 있다.
netB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)
댓글