Hello Computer Vision

torch.triplet margin distance loss 살펴보기 본문

딥러닝

torch.triplet margin distance loss 살펴보기

지웅쓰 2023. 5. 4. 16:46

이번에 SSL공부하면서 loss를 좀 제대로 정의해야겠다고 생각해서 파이토치 문서에서 찾아보고 한번 뜯어볼라고한다.

https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html

 

TripletMarginLoss — PyTorch 2.0 documentation

Shortcuts

pytorch.org

 

우선 triplet margin loss와 triplet margin distance loss가 있던데 지금 공부하는건 후자이다.

import torch


def triplet_margin_with_distance_loss(
                        anchor:Tensor,
                        positive:Tensor,
                        negative:Tensor,
                        distance_fucntion:Optional = None,
                        margin :float = 1.0,
                        swap:bool = False
                        reduction = 'mean'
                                      )

    a_dim = anchor.ndim
    p_dim = positive.ndim
    n_dim = negative.ndim

    if not (a_dim == p_dim and p_dim == n_dim):
        raise RuntimeError

    if distance_function is None:
        distance_function = torch.pairwise_distance

    dist_pos = distance_function(anchor, positive)
    dist_neg = distance_function(anchor, negative)

    if swap:
        dist_swap = distance_function(positive, negative)
        dist_neg = torch.minimum(dist_neg, dist_swap)

    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)

    if reduction == 'sum':
        return torch.sum(loss)
    elif reduction == 'mean':
        return torch.mean(loss)
    else:
        return loss

코드를 살펴보면 다음과 같다. 이건 파이토치 문서를 참고해 내가 작성한 것이고 문서에는 추가적으로 주석이나 코드 호환성 관련하여 있다.

 

위에서부터 뜯어보자!

    if distance_function is None:
        distance_function = torch.pairwise_distance

우선 거리함수를 따로 정의하지 않으면 자체로 pairwise distance 로 거리함수로 정의하는 것을 알 수 있다.

https://pytorch.org/docs/stable/generated/torch.nn.PairwiseDistance.html

 

PairwiseDistance — PyTorch 2.0 documentation

Shortcuts

pytorch.org

pairwise 문서는 다음과 같고 기본값은 l2 distance 인 것을 확인할 수 있다.

 

    dist_pos = distance_function(anchor, positive)
    dist_neg = distance_function(anchor, negative)

    if swap:
        dist_swap = distance_function(positive, negative)
        dist_neg = torch.minimum(dist_neg, dist_swap)

기준이 되는 anchor와 positive, negative sample간의 거리를 구한다. 여기서 swap 이라는 하이퍼파라미터에 True 값을 준다면 기존의 negative 거리를 (positive, negative)와 비교하여 더 작은 것을 넣어주는 것을 확인할 수 있다. 

 

loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)

loss는 margin 값을 더하여 구하는데 여기서 loss가 음수로 가지 않게 clamp_min 을 활용해준다.

여기서 음수로 간다 --> positive끼리의 거리는 붙어있고 negative끼리의 거리는 많이 떨어져있다 --> 잘되고 있다

    if reduction == 'sum':
        return torch.sum(loss)
    elif reduction == 'mean':
        return torch.mean(loss)
    else:
        return loss

마지막으로 sum, mean 이 있는데 sum을 인자로 준다면 모든 loss를 합하는 것이고 mean을 준다면 loss에 대해 평균값을내준다. 근데 다른 코드들 보면 보통 mean값을 주더라. 아마 loss가 너무 커서 explode하는 것을 방지하는 것일수도?

 

 

--추가

거리를 계산하는 과정에서 데이터셋안에서 데이터를 sampling할 때 위 이미지에서 보여주는 것처럼 이미 거리가 멀리 있는 easy negative 데이터들만 학습할 경우 loss는 낮을 것이고 학습이 제대로 이루어지지 않을 것이다. 따라서 학습을 원할하게 하기 위해서는 적절한 semi hard negative 와 hard negative sample들을 뽑는 것이 중요하다.

 

틀린점 있다면 지적해주시면 감사하겠습니다.

 

 


References

https://velog.io/@iissaacc/Triplet-Loss