Hello Computer Vision

[pytorch] load state dict 후 업데이트 본문

딥러닝/파이토치

[pytorch] load state dict 후 업데이트

지웅쓰 2024. 1. 22. 13:34

2개의 동일 모델이 있고 한개는 이미 학습이 어느 정도 됐고 다른 1개는 초기화 상태이다. 학습이 어느 정도 된 모델의 파라미터를 안된 모델의 가중치로 넘겨주고, 이 모델에 대해서 학습을 시킨다.

그렇다면 기존 넘겨준 모델의 가중치는 업데이트가 될까? 직관적으로는 안될 거 같지만 실험을 해봤다.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


x_train = torch.FloatTensor([[1], [2], [3]])
y_train = torch.FloatTensor([[2], [4], [6]])

layer = nn.Linear(1, 1)
layerc = nn.Linear(1, 1)

param = layer.state_dict()
print(param)

layerc.load_state_dict(param)
optimizer = torch.optim.SGD(layerc.parameters(), lr=0.01)
for epoch in range(10):
    prediction = layerc(x_train)

    # cost 계산
    cost = F.mse_loss(prediction, y_train)

    # cost로 H(x) 개선
    optimizer.zero_grad()
    cost.backward()
    optimizer.step()
    print(f'epoch {epoch} 끝')


print(layerc.state_dict())
print(layer.state_dict())

 

OrderedDict([('weight', tensor([[0.7682]])), ('bias', tensor([-0.6635]))])
epoch 0 끝
epoch 1 끝
epoch 2 끝
epoch 3 끝
epoch 4 끝
epoch 5 끝
epoch 6 끝
epoch 7 끝
epoch 8 끝
epoch 9 끝
OrderedDict([('weight', tensor([[1.6499]])), ('bias', tensor([-0.2728]))])
OrderedDict([('weight', tensor([[0.7682]])), ('bias', tensor([-0.6635]))])

보시다시피 초기 가중치를 넘겨주고 다른 모델에 대해 학습했을 때 기존의 가중치는 변하지 않은 모습이다.

'딥러닝 > 파이토치' 카테고리의 다른 글

[pytorch] 기본적인 tensor 조작해보기  (0) 2024.02.06
[pytorch] Nan, inf알아보기  (0) 2024.02.06
CIFAR10-C 데이터셋 다루기  (0) 2024.01.20
torch stack, cat 차이  (0) 2024.01.17
함수 이용해 이미지 rotate 하기  (0) 2024.01.17