Hello Computer Vision

비전공생의 Unsupervised Out-of-Distribution Detection by Maximum Classifier Discrepancy 본문

Out of Distribution

비전공생의 Unsupervised Out-of-Distribution Detection by Maximum Classifier Discrepancy

지웅쓰 2023. 9. 29. 22:37

링크: https://arxiv.org/pdf/1908.04951.pdf

OOD를 unsupervised 방식으로 향상시킨 방법이다.

 

Introduction

OOD 에 대한 문제점을 가볍게 언급하고 있다. 그리고 해당 문제점을 새로운 방법인 2 head classifier, unsupervised 방식으로 풀어냈다고 소개하고 있다. 그리고 실험 세팅에서는 unlabeled 데이터에 대해 해당 데이터가 in distribution인지 ood 인지 구분하지 못하지만 성능향상에 도움이 된다고 한다.

 

Method

우선 OOD detection 을 풀어내는 방식의 시초가 되는 MSP논문의 경우 딥러닝 모델의 분류기가 ID의 경우 조금 더 높은 probability를 가지는 것을 이용했다. 이 논문의 저자도 이러한 것을 바탕으로 문제를 해결하려고 한다. 기본적인 구조로는 1개의 feature extractor E와 2개의 classifier f1, f2가 있다. 이렇게 각각의 classfier 에서 나온 softmax probability를 이렇게 표현한다. $p_{1}(y|x), p_{2}(y|x) $ 

 

위 figure를 보면 해당 방법의 motivation을 쉽게 알 수 있다. 2개의 probability에 대하여 L1 distance 를 측정했을 때 ID 데이터에 대해서는 각 분류기의 값들이 차이가 없었지만 OOD데이터에서는 거리가 차이가 있음을 알 수 있다. 

 

Training Procedure

우선 labeled 데이터를 활용하여 network를 훈련한다.

해당 과정을 Pre-Training 으로 논문에서는 표기한다.

 

그리고 Fine Tuning과정에서는 labeled, unlabeled 데이터를 모두 활용한다. 기본적으로 labeled 데이터에 대해서는 Pre-Training 과정에서 사용했던 방식을 똑같이 사용한다. 그리고 unlabeled 데이터에 대해서는 위에서 언급한 discrepancy 를 활용한다.

supervised loss는 흔히 아는 cross entropy이니 생략하고 unsupervised loss만 설명해본다면, 우선 해당 loss를 줄이는 것이 목적이다(m은 margin이다). 여기서 d 는 

다음과 같이 정의되며 p1, p2에서 나온 softmax 에 대하여 각각의 entropy의 차를 구하는 방식이다. 다시 unsupervised loss로 돌아와서, unlabeled 데이터들의 거리의 평균이 margin 보다 크다면 이를 0으로 수렴시켜 overfitting 을 방지했다고 한다. 이 부분을 다시 설명해보자면, 우선 우리는 unlabeled 데이터가 ood인지 id인지 모르는 상태이다. 만약 margin 이 1이고, unlabeled 데이터가 모두 ood라 해보자. 그러면 평균거리 d는 높게 나올 것이고 이건 우리가 바라는 바이기 때문에 loss가 0으로 설정된다. unlabeled 데이터가 모두 id일 경우를 생각해보자. 그러면 평균거리 d는 거의 0에 수렴할 것이고 그러면 loss가 1 발생할 것이다. 이것도 우리가 바라는 바로 거리를 0으로 설정했는데 왜 loss가 설정됐냐 함은, 우리는 기본적으로 loss를 줄이는 것이 목표이다. 그러면 id데이터에 대해서도 거리를 늘려 loss를 0으로 설정할 수 있기 때문이다(따로 설명은 없지만 개인적인 생각이다). 만약 절반이 ood, 절반이 id이고 ood 데이터에 대해 거리를 0.5, id에 대해서 0으로 했다면 전체 평균거리는 0.25가 될것이고 loss가 발생한다. 추가적으로 해당 방법을 설명하면서 margin m은 오버피팅을 방지한다 했으니 아마 내 이해가 맞다고 생각하는 중이다.

 

모든 훈련이 끝나고 inference과정 역시 2개의 classifier를 활용한 L1 distance를 구해 해당 데이터가 ood인지 구분한다

 

 

Result