Hello Computer Vision

torch.topk 함수 공부해보기 본문

딥러닝

torch.topk 함수 공부해보기

지웅쓰 2023. 8. 22. 17:08

논문에서 결과를 비교할 때 top1, top5의 결과를 비교할 때가 많은데 이때 사용하는 함수가 torch.topk함수이다.

https://pytorch.org/docs/stable/generated/torch.topk.html

 

torch.topk — PyTorch 2.0 documentation

Shortcuts

pytorch.org

예시를 통해 한번 알아보자

x = torch.arange(1, 10)
value, pred = torch.topk(x, 3)
print(x)
print(pred)
print(value)
tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([8, 7, 6])
tensor([9, 8, 7])

상위 3개의 값을 뽑아봤다. 추가적인 파라미터로는 설정으로 작은 k개의 값을 뽑을 수도 있으며 정렬 여부도 결정할 수 있다. 기본값은 다 큰 k개의 값이며 정렬을 해주는 것이 기본 값이다.

 

그리고 보통 torch.topk 이렇게 사용하는 것이 아닌 output.topk 이렇게 사용하는 경우도 있는데 다음 예시와 같다.

x = torch.rand(3,5)
print(x)
print('-' * 30)
value, pred = x.topk(3, 1, True, True)
print(value)
print('-' * 30)
print(pred)
tensor([[0.5453, 0.6736, 0.9698, 0.0100, 0.4278],
        [0.2417, 0.4822, 0.1055, 0.4558, 0.8578],
        [0.7234, 0.5069, 0.9546, 0.7297, 0.5313]])
------------------------------
tensor([[0.9698, 0.6736, 0.5453],
        [0.8578, 0.4822, 0.4558],
        [0.9546, 0.7297, 0.7234]])
------------------------------
tensor([[2, 1, 0],
        [4, 1, 3],
        [2, 3, 0]])

첫번째 파라미터는 상위 k개의 값, 두번째 파라미터는 dimension이다.