일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
Tags
- 백준 알고리즘
- CycleGAN
- simclrv2
- conjugate pseudo label paper
- dcgan
- UnderstandingDeepLearning
- mocov3
- semi supervised learnin 가정
- CoMatch
- 딥러닝손실함수
- shrinkmatch
- remixmatch paper
- ConMatch
- Meta Pseudo Labels
- adamatch paper
- Entropy Minimization
- 최린컴퓨터구조
- shrinkmatch paper
- Pseudo Label
- Pix2Pix
- dann paper
- SSL
- tent paper
- BYOL
- CGAN
- 컴퓨터구조
- mme paper
- WGAN
- cifar100-c
- GAN
Archives
- Today
- Total
Hello Computer Vision
torch detach 실험해보기 본문
Pytorch를 사용하다보면은 detach가 쓰이는 코드를 종종만난다. 사실 이 전까지는 gradient에 대해서 별 생각이 없었다. 그런데 요즘 실험 및 논문을 읽으면서 gradient에 대한 이야기가 많이 나와 조금 정리해보려고 한다.
나는 detach를 사용하면 해당 데이터의 gradient를 사용하지 않는 것으로 알고 있었다.
import torch
import torch.nn as nn
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(10, 10)
self.layer2 = nn.Linear(10, 10)
self.layer3 = nn.Linear(10, 10)
def forward(self, x):
out1 = self.layer1(x)
out2 = self.layer2(out1)
out3 = self.layer3(out2.detach())
return out3
model = TestModel()
x = torch.randn(1, 10)
a = model(x)
a.mean().backward()
print(model.layer1.weight.grad)
print(model.layer2.weight.grad)
print(model.layer3.weight.grad)
만약 위와 같은 코드가 있다고 할 때 결과값을 어떻게 나올까? layer1, layer2는 detach가 붙지 않은 상태로 layer를 통과했으니 gradient가 저장이 되었을까?
None
None
tensor([[-0.0885, -0.0279, -0.0294, -0.0322, 0.0120, -0.0045, 0.0308, 0.0477,
0.0594, 0.0726],
[-0.0885, -0.0279, -0.0294, -0.0322, 0.0120, -0.0045, 0.0308, 0.0477,
0.0594, 0.0726],
[-0.0885, -0.0279, -0.0294, -0.0322, 0.0120, -0.0045, 0.0308, 0.0477,
0.0594, 0.0726],
[-0.0885, -0.0279, -0.0294, -0.0322, 0.0120, -0.0045, 0.0308, 0.0477,
0.0594, 0.0726],
[-0.0885, -0.0279, -0.0294, -0.0322, 0.0120, -0.0045, 0.0308, 0.0477,
0.0594, 0.0726],
[-0.0885, -0.0279, -0.0294, -0.0322, 0.0120, -0.0045, 0.0308, 0.0477,
0.0594, 0.0726],
[-0.0885, -0.0279, -0.0294, -0.0322, 0.0120, -0.0045, 0.0308, 0.0477,
0.0594, 0.0726],
[-0.0885, -0.0279, -0.0294, -0.0322, 0.0120, -0.0045, 0.0308, 0.0477,
0.0594, 0.0726],
[-0.0885, -0.0279, -0.0294, -0.0322, 0.0120, -0.0045, 0.0308, 0.0477,
0.0594, 0.0726],
[-0.0885, -0.0279, -0.0294, -0.0322, 0.0120, -0.0045, 0.0308, 0.0477,
0.0594, 0.0726]])
정답은 layer3의 input에 대하여 detach를 수행한다면 그 전까지의 gradient를 다 삭제함을 알 수 있다.
++ 2024.03.01 수정)
detach를 번역된 공식 문서를 살펴보면 새로운 tensor를 생성하는데 기존 텐서에서 연관된 gradient는 지워지는 tensor 를 재생성한다고 볼 수 있다. 위의 예시를 보면 out2는 원래 layer1, layer2 함수들을 거쳐 out2의 발자취가 남아있지만 out3를 생성하는 과정에서 out2의 발자취를 없애고 새로운 out2(연결 그래프 없이 깔끔한 텐서)를 layer3 함수에 적용되었기 때문에 layer3 에만 gradient가 남아있는 것을 확인할 수 있는 것이다.
'딥러닝 > 파이토치' 카테고리의 다른 글
torch stack, cat 차이 (0) | 2024.01.17 |
---|---|
함수 이용해 이미지 rotate 하기 (0) | 2024.01.17 |
torch SummaryWriter 에 관한 자그만한 정보 (0) | 2024.01.05 |
[pytorch] 처리하고자 하는 배치사이즈가 다를 때 해결방법 (1) | 2023.12.29 |
__name__ == '__main__' 쓰는 이유 (0) | 2023.12.21 |