Hello Computer Vision

argmax의 keepdim = True, False 본문

딥러닝/파이토치

argmax의 keepdim = True, False

지웅쓰 2023. 5. 1. 11:30

코드를 공부하는데 argmax의 keepdim = True라는 것을 보았다.

이전부터 보았던 건데 왜 공부안했는지 싶어서 남겨두려고 한다.

 

일단 직관적으로 유지할 수 있다는 것을 알 수 있다.

arr = torch.rand(2, 3, 4)
pred = arr.argmax(2, keepdim = False)
print(pred)
print(pred.shape)

tensor([[0, 2, 1],
        [0, 2, 1]])
torch.Size([2, 3])

크기가 (2, 3, 4)인 텐서에 argmax를 취해주면 가장 높은 값을 기준으로 주어진 인자에 따라 차원 한개가 사라지는 것을 볼 수 있다.

 

arr = torch.rand(2, 3, 4)
pred = arr.argmax(2, keepdim = True)
print(pred)
print(pred.shape)

tensor([[[1],
         [3],
         [0]],

        [[3],
         [3],
         [1]]])
torch.Size([2, 3, 1])

그러나 만약 keepdim = True로 설정한다면 차원이 사라지지 않고 그대로 유지되는 것을 볼 수 있다.