일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
Tags
- mme paper
- UnderstandingDeepLearning
- tent paper
- CGAN
- simclrv2
- WGAN
- ConMatch
- Meta Pseudo Labels
- SSL
- 딥러닝손실함수
- conjugate pseudo label paper
- 최린컴퓨터구조
- mocov3
- GAN
- remixmatch paper
- Pseudo Label
- Pix2Pix
- cifar100-c
- CycleGAN
- 컴퓨터구조
- shrinkmatch paper
- BYOL
- dann paper
- Entropy Minimization
- shrinkmatch
- 백준 알고리즘
- adamatch paper
- CoMatch
- semi supervised learnin 가정
- dcgan
Archives
- Today
- Total
Hello Computer Vision
torch.topk 함수 공부해보기 본문
논문에서 결과를 비교할 때 top1, top5의 결과를 비교할 때가 많은데 이때 사용하는 함수가 torch.topk함수이다.
https://pytorch.org/docs/stable/generated/torch.topk.html
예시를 통해 한번 알아보자
x = torch.arange(1, 10)
value, pred = torch.topk(x, 3)
print(x)
print(pred)
print(value)
tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([8, 7, 6])
tensor([9, 8, 7])
상위 3개의 값을 뽑아봤다. 추가적인 파라미터로는 설정으로 작은 k개의 값을 뽑을 수도 있으며 정렬 여부도 결정할 수 있다. 기본값은 다 큰 k개의 값이며 정렬을 해주는 것이 기본 값이다.
그리고 보통 torch.topk 이렇게 사용하는 것이 아닌 output.topk 이렇게 사용하는 경우도 있는데 다음 예시와 같다.
x = torch.rand(3,5)
print(x)
print('-' * 30)
value, pred = x.topk(3, 1, True, True)
print(value)
print('-' * 30)
print(pred)
tensor([[0.5453, 0.6736, 0.9698, 0.0100, 0.4278],
[0.2417, 0.4822, 0.1055, 0.4558, 0.8578],
[0.7234, 0.5069, 0.9546, 0.7297, 0.5313]])
------------------------------
tensor([[0.9698, 0.6736, 0.5453],
[0.8578, 0.4822, 0.4558],
[0.9546, 0.7297, 0.7234]])
------------------------------
tensor([[2, 1, 0],
[4, 1, 3],
[2, 3, 0]])
첫번째 파라미터는 상위 k개의 값, 두번째 파라미터는 dimension이다.
'딥러닝' 카테고리의 다른 글
Understanding DeepLearning-Supervised learning (0) | 2023.09.19 |
---|---|
비전공생의 SAM(2021) 논문 리뷰 (0) | 2023.09.19 |
np.random.choice 공부해보기 (0) | 2023.08.21 |
Test Time Augmentation 알아보기 (0) | 2023.07.10 |
CAM, Grad-CAM 공부해보기 (0) | 2023.06.30 |