Hello Computer Vision

[pytorch] transpose, view 의 차이 알아보기 본문

딥러닝/파이토치

[pytorch] transpose, view 의 차이 알아보기

지웅쓰 2023. 3. 16. 15:43

경량화 모델 중 하나인 ShuffleNet의 코드를 살펴보던 중 view, transpose를 사용하는 코드들을 발견하였는데 제대롤 이해해야 나중에 잘 활용할 수 있겠다 생각해 알아보았다.

 

내가 직접 짜지않고 보고 이해하는 것은 할 수 있었으나 실제로 어떤 기능으로 활용되는지는 잘 알 수 없었다.

 

view 예시

우리가 네트워크를 짜는 도중에 input값을 맞출 때 자주 사용된다.

x = torch.rand(2,3,4)
print(x)

tensor([[[0.1796, 0.9432, 0.8833, 0.7571],
         [0.9830, 0.5509, 0.6370, 0.2307],
         [0.9010, 0.7599, 0.1854, 0.5038]],

        [[0.8308, 0.9059, 0.2282, 0.1945],
         [0.0027, 0.0368, 0.9602, 0.8460],
         [0.0226, 0.6881, 0.7628, 0.4756]]])
y = x.view(2, 4, 3)
print(y)
print(y.shape)

tensor([[[0.1796, 0.9432, 0.8833],
         [0.7571, 0.9830, 0.5509],
         [0.6370, 0.2307, 0.9010],
         [0.7599, 0.1854, 0.5038]],

        [[0.8308, 0.9059, 0.2282],
         [0.1945, 0.0027, 0.0368],
         [0.9602, 0.8460, 0.0226],
         [0.6881, 0.7628, 0.4756]]])
torch.Size([2, 4, 3])

(2, 3, 4) 크기의 텐서가 view를 활용하여 (2, 4, 3) 으로 변환되었음을 확인할 수 있다.

 

transpose 예시

보통 우리가 사용하는 텐서의 크기는 (batch, channel, h, w) 이다

그러나 모종의 이유로 h,w 의 위치를 바꾸고 싶을 때가 있다. 그럴 때 transpose가 사용된다.

z = x.transpose(1,2)
print(z)
print(z.shape)

tensor([[[0.1796, 0.9830, 0.9010],
         [0.9432, 0.5509, 0.7599],
         [0.8833, 0.6370, 0.1854],
         [0.7571, 0.2307, 0.5038]],

        [[0.8308, 0.0027, 0.0226],
         [0.9059, 0.0368, 0.6881],
         [0.2282, 0.9602, 0.7628],
         [0.1945, 0.8460, 0.4756]]])
torch.Size([2, 4, 3])

 

방법에서 차이가 있지만 텐서를 다룰 수 있는 유용한 함수들이다.

추가로 이러한 차이점을 제외하고도 contiguous 부분에서 차이가 생기는데 contiguous란 메모리 부분에서의 연속성의 문제이다. 코드를 통해 한번 알아보자.

for i in range(2):
    for j in range(4):
        for k in range(3):
            print(y[i][j][k].data_ptr())
            
1758990581888
1758990581892
1758990581896
1758990581900
1758990581904
1758990581908
1758990581912
1758990581916
1758990581920
1758990581924
1758990581928
1758990581932
1758990581936
1758990581940
1758990581944
1758990581948
1758990581952
1758990581956
1758990581960
1758990581964
1758990581968
1758990581972
1758990581976
1758990581980

view 결과 값인 y원소들의 메모리 주소를 살펴보면 4씩 증가함을 알 수 있다.

 

for i in range(2):
    for j in range(4):
        for k in range(3):
            print(z[i][j][k].data_ptr())
            
1758990581888
1758990581904
1758990581920
1758990581892
1758990581908
1758990581924
1758990581896
1758990581912
1758990581928
1758990581900
1758990581916
1758990581932
1758990581936
1758990581952
1758990581968
1758990581940
1758990581956
1758990581972
1758990581944
1758990581960
1758990581976
1758990581948
1758990581964
1758990581980

그러나 transpose한 결과값인 z의 메모리주소들을 살펴보면 이와 다름을 확인할 수 있다.

 

print(y.is_contiguous())
print(z.is_contiguous())

True
False

이러한 부분에서 차이가 생깁니다. 만약 contiguous 의 부분에서 오류가 생긴다면 

 

z = z.contiguous()
print(z.is_contiguous())

True

이러한 변환을 통해 연속적으로 만들 수 있습니다.


 

물론 view / reshape 의 차이와 transpose / permute 의 차이도 존재함을 알고 있지만 아직은 이에 대해 문제점이 발생하지 않았으므로 나중에 나오면 공부해보겠습니다.