일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- CGAN
- 컴퓨터구조
- Entropy Minimization
- mme paper
- 최린컴퓨터구조
- shrinkmatch paper
- SSL
- conjugate pseudo label paper
- tent paper
- mocov3
- 딥러닝손실함수
- 백준 알고리즘
- cifar100-c
- remixmatch paper
- CycleGAN
- Pix2Pix
- CoMatch
- BYOL
- UnderstandingDeepLearning
- dcgan
- Meta Pseudo Labels
- WGAN
- GAN
- dann paper
- adamatch paper
- simclrv2
- semi supervised learnin 가정
- ConMatch
- shrinkmatch
- Pseudo Label
- Today
- Total
Hello Computer Vision
torch.triplet margin distance loss 살펴보기 본문
이번에 SSL공부하면서 loss를 좀 제대로 정의해야겠다고 생각해서 파이토치 문서에서 찾아보고 한번 뜯어볼라고한다.
https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html
우선 triplet margin loss와 triplet margin distance loss가 있던데 지금 공부하는건 후자이다.
import torch
def triplet_margin_with_distance_loss(
anchor:Tensor,
positive:Tensor,
negative:Tensor,
distance_fucntion:Optional = None,
margin :float = 1.0,
swap:bool = False
reduction = 'mean'
)
a_dim = anchor.ndim
p_dim = positive.ndim
n_dim = negative.ndim
if not (a_dim == p_dim and p_dim == n_dim):
raise RuntimeError
if distance_function is None:
distance_function = torch.pairwise_distance
dist_pos = distance_function(anchor, positive)
dist_neg = distance_function(anchor, negative)
if swap:
dist_swap = distance_function(positive, negative)
dist_neg = torch.minimum(dist_neg, dist_swap)
loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
if reduction == 'sum':
return torch.sum(loss)
elif reduction == 'mean':
return torch.mean(loss)
else:
return loss
코드를 살펴보면 다음과 같다. 이건 파이토치 문서를 참고해 내가 작성한 것이고 문서에는 추가적으로 주석이나 코드 호환성 관련하여 있다.
위에서부터 뜯어보자!
if distance_function is None:
distance_function = torch.pairwise_distance
우선 거리함수를 따로 정의하지 않으면 자체로 pairwise distance 로 거리함수로 정의하는 것을 알 수 있다.
https://pytorch.org/docs/stable/generated/torch.nn.PairwiseDistance.html
pairwise 문서는 다음과 같고 기본값은 l2 distance 인 것을 확인할 수 있다.
dist_pos = distance_function(anchor, positive)
dist_neg = distance_function(anchor, negative)
if swap:
dist_swap = distance_function(positive, negative)
dist_neg = torch.minimum(dist_neg, dist_swap)
기준이 되는 anchor와 positive, negative sample간의 거리를 구한다. 여기서 swap 이라는 하이퍼파라미터에 True 값을 준다면 기존의 negative 거리를 (positive, negative)와 비교하여 더 작은 것을 넣어주는 것을 확인할 수 있다.
loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
loss는 margin 값을 더하여 구하는데 여기서 loss가 음수로 가지 않게 clamp_min 을 활용해준다.
여기서 음수로 간다 --> positive끼리의 거리는 붙어있고 negative끼리의 거리는 많이 떨어져있다 --> 잘되고 있다
if reduction == 'sum':
return torch.sum(loss)
elif reduction == 'mean':
return torch.mean(loss)
else:
return loss
마지막으로 sum, mean 이 있는데 sum을 인자로 준다면 모든 loss를 합하는 것이고 mean을 준다면 loss에 대해 평균값을내준다. 근데 다른 코드들 보면 보통 mean값을 주더라. 아마 loss가 너무 커서 explode하는 것을 방지하는 것일수도?
--추가
거리를 계산하는 과정에서 데이터셋안에서 데이터를 sampling할 때 위 이미지에서 보여주는 것처럼 이미 거리가 멀리 있는 easy negative 데이터들만 학습할 경우 loss는 낮을 것이고 학습이 제대로 이루어지지 않을 것이다. 따라서 학습을 원할하게 하기 위해서는 적절한 semi hard negative 와 hard negative sample들을 뽑는 것이 중요하다.
틀린점 있다면 지적해주시면 감사하겠습니다.
References
'딥러닝' 카테고리의 다른 글
CAM, Grad-CAM 공부해보기 (0) | 2023.06.30 |
---|---|
GPU연산 DP, DDP 공부해보기 (0) | 2023.06.29 |
Positional Encoding 공부해보기 (0) | 2023.04.04 |
Train/Valid/Test 데이터에 대해 자세히 알아보기 (0) | 2023.04.03 |
Zero shot learning에 대해 공부해보기 (0) | 2023.04.02 |