Hello Computer Vision

torch detach 실험해보기 본문

딥러닝/파이토치

torch detach 실험해보기

지웅쓰 2024. 1. 12. 23:00

Pytorch를 사용하다보면은 detach가 쓰이는 코드를 종종만난다. 사실 이 전까지는 gradient에 대해서 별 생각이 없었다. 그런데 요즘 실험 및 논문을 읽으면서 gradient에 대한 이야기가 많이 나와 조금 정리해보려고 한다.

 

나는 detach를 사용하면 해당 데이터의 gradient를 사용하지 않는 것으로 알고 있었다.

import torch
import torch.nn as nn

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(10, 10)
        self.layer3 = nn.Linear(10, 10)
        
    def forward(self, x):
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2.detach())
        return out3
    
model = TestModel()
x = torch.randn(1, 10)
a = model(x)
a.mean().backward()
print(model.layer1.weight.grad)
print(model.layer2.weight.grad)
print(model.layer3.weight.grad)

만약 위와 같은 코드가 있다고 할 때 결과값을 어떻게 나올까? layer1, layer2는 detach가 붙지 않은 상태로 layer를 통과했으니 gradient가 저장이 되었을까?

None
None
tensor([[-0.0885, -0.0279, -0.0294, -0.0322,  0.0120, -0.0045,  0.0308,  0.0477,
          0.0594,  0.0726],
        [-0.0885, -0.0279, -0.0294, -0.0322,  0.0120, -0.0045,  0.0308,  0.0477,
          0.0594,  0.0726],
        [-0.0885, -0.0279, -0.0294, -0.0322,  0.0120, -0.0045,  0.0308,  0.0477,
          0.0594,  0.0726],
        [-0.0885, -0.0279, -0.0294, -0.0322,  0.0120, -0.0045,  0.0308,  0.0477,
          0.0594,  0.0726],
        [-0.0885, -0.0279, -0.0294, -0.0322,  0.0120, -0.0045,  0.0308,  0.0477,
          0.0594,  0.0726],
        [-0.0885, -0.0279, -0.0294, -0.0322,  0.0120, -0.0045,  0.0308,  0.0477,
          0.0594,  0.0726],
        [-0.0885, -0.0279, -0.0294, -0.0322,  0.0120, -0.0045,  0.0308,  0.0477,
          0.0594,  0.0726],
        [-0.0885, -0.0279, -0.0294, -0.0322,  0.0120, -0.0045,  0.0308,  0.0477,
          0.0594,  0.0726],
        [-0.0885, -0.0279, -0.0294, -0.0322,  0.0120, -0.0045,  0.0308,  0.0477,
          0.0594,  0.0726],
        [-0.0885, -0.0279, -0.0294, -0.0322,  0.0120, -0.0045,  0.0308,  0.0477,
          0.0594,  0.0726]])

정답은 layer3의 input에 대하여 detach를 수행한다면 그 전까지의 gradient를 다 삭제함을 알 수 있다.

 

++ 2024.03.01 수정) 

detach를 번역된 공식 문서를 살펴보면 새로운 tensor를 생성하는데 기존 텐서에서 연관된 gradient는  지워지는 tensor 를 재생성한다고 볼 수 있다. 위의 예시를 보면 out2는 원래 layer1, layer2 함수들을 거쳐 out2의 발자취가 남아있지만 out3를 생성하는 과정에서 out2의 발자취를 없애고 새로운 out2(연결 그래프 없이 깔끔한 텐서)를 layer3 함수에 적용되었기 때문에 layer3 에만 gradient가 남아있는 것을 확인할 수 있는 것이다.