Hello Computer Vision

비전공생의 CCSSL(2022) 논문리뷰 본문

Self,Semi-supervised learning

비전공생의 CCSSL(2022) 논문리뷰

지웅쓰 2023. 8. 2. 18:44

논문의 풀 제목은 Class-Aware Contrastive Semi Supervised Learning이다.

https://arxiv.org/pdf/2203.02261.pdf

 

Introduction

모델을 훈련시킬 때 중요한 요소가 있는데 바로 Raw data에 대한 접근이다. Real world에서는 Raw data가 많을 것이며 이 중 일부는 훈련에서 경험하지 못한 클래스들의 데이터들도 있을 것이다. 이것을 out of distribution data라고 하며 훈련과정에서 경험한 데이터를 in distribution data라고 한다.

 

이를 잘 보여주는 사진이다. SSL을 통해 raw data 에 대한 대비 없이 훈련하게 된다면 confirmation bias, 즉 확증편향이 일어나게 되는데 이는 pseudo label에 대해 지나친 확신을 가지고 있는 것을 말한다. 확증편향에 대한 결과로 실제 추론 과정에서 만나보지 못한 클래스에 대해 잘못된 클래스로 확신해버리는 결과를 초래한다. 따라서 이 논문에서는 이러한 문제점을 해결하기 위해 class aware contrastive module을 추가해 이를 완화하려고 한다. 추가로 군집화를 실행하는데 있어 방해가 되는 노이즈에 대해서도 re-weighting 방법을 통해 이를 줄이려고 한다.

논문의 contribution 은 다음과 같다.

1. 새로운 SSL방법을 통해 out of distribution data 대처

2. 확증편향 문제에 도움이 된다. 이는 real world 문제에 있어 실용적이다.

3.  CCSSL을 통해 SOTA 달성에 도움이 된다.

 

Method

 

CCSSL 의 전체적인 프레임워크는 다음과 같다. Pseudo label을 생성하고 이를 unsupervised loss를 계산하는 것은 FixMatch와 동일하다고 하며 supervised loss도 이와 같다고 하며 따로 언급을 하지는 않는다. 그리고 저자들이 가장 중요하게 생각하는 class aware contrastive module을 설명한다.

이 논문에서는 여러가지 matrix가 나오는데 제일 먼저 contrastive matrix을 설명해보자면, 배치 안에 있는 각각의 이미지들은 strong augmentation을 거쳐 2개씩 positive pair를 갖게되며, 이를 encoder를 통해 low dimensional embedding인 z를 얻게 된다. 이러한 각각의 값들을 내적하고 이를 matrix로 나열하게 되는데 수식으로 나타내면 다음과 같다. 여기서 S 행렬은 각 이미지의 embedding vecotor를 내적한 것을 말하며 즉 similiarity를 나타내며, W행렬은 pseudo label과 augmentation에 대해서 positive, negative sample을 나타내는 행렬이다.

이렇게 자기 자신과 positive pair끼리는 1을 갖게되며 그 외에는 negative sample이므로 0을 갖게되는 행렬을 얻게된다. 그리고 이러한 행렬에서 loss를 산출하게 된다면, 

infoNCE loss를 사용할 수 있다. 이러한 방식을 통해 positive pair는 당기고 negative pair는 밀어내며 모델의 표현을 더 배울 수 있으나 이러한 방식은 구체적인 클래스를 필요로 하는 classification task 와 양립할 수 없다고 한다. 이러한 문제점을 해결하기 위해 저자는 class를 고려하는 class aware contrastive module을 만든다.

 

저자들은 위에서 언급한 문제점을 해결하기 위해 positive pair를 같은 클래스의 이미지라고 생각하며 pseudo label의 unsafe함을 방지하기 위해 T 임계값을 설정한다.  이를 적용해 새롭게 정의한 행렬은 다음과 같다.

이전과 비슷하지만 조금 다른 부분은 같은 증강에서 나온 이미지라 하더라도 해당 이미지의 feature 값이 임계값을 넘지 못하면 이미지의 같은 클래스라고 여기지 않는데 이는 pseudo label 의 unsafe, confirmation bias를 완화한 기법이다.  이를 위 이미지에서 나온 CCSSL의 프레임워크 이미지를 본다면 class aware matrix에서는 조금 더 엄격하게 행렬에 대해 positive 를 설정함을 알 수 있다.

 

그리고 여기서 하나의 문제점을 추가로 제시하는데, 위 행렬을 정의하기 위해 사용한 T는 out of distribution data를 정의한 것뿐만 아니라 pseudo label이 가지는 confirmation bias를 억제함을 알 수 있는데 이는 문제가 될 수 있다고 한다(정확히 어떤 문제인지는 언급하지 않지만 T라는 한 값을 이용해 두가지 요소의 효과를 보는건 아무래도 문제가 있을 것 같다는 생각은 든다). 따라서 저자는 re-weighting 을 통해 T가 가진 역할을 나누려고  한다. 

해당 weight 인 q는 confidence score를 나타내는데, weak augmentation을 통해 나온 분포값이 [0.9, 0.1] 이라면 confidence score 는 0.9가 된다. 이러한 re-weighting 은 pseudo label의 confirmation bias를 조금 완화하는 것으로 볼 수 있는데, 어떤 한 이미지의 값이 잘못예측했지만 높은 확률로 부여받았을 때 re-weighting을 통해 조금 완화될 수 있다. 그리고 이렇게 만들어진 행렬과 각각의 이미지를 strong augmentation 2가지를 통해 나온 embedding matrix와의 Cross entropy를 구한다. 

여기서 P는 positive sample의 cardinality라고 한다. 해당 Loss를 정확히 이해를 하지는 못했지만 한번 이해한대로 써보자면 첫번째 loss term 에서는 augmentation된 이미지들에 대해서 손실값을 계산하며,  두번째 loss term에서는 행렬을 활용해 positive sample들끼리의 손실값을 계산하지만 이는 pseudo label의 결과이기 때문에 어느 정도 보정해주는 weight를 추가하는 것으로 이해하였다(언제든 태클해주시면 감사하겠습니다). 

전체 loss는 다음과 같다. 

위에서도 말했다시피 supervised, unsupervised loss는 따로 언급되어 있지 않으며 FixMatch와 같이 동일하게 사용되었다고 한다. 

 

Result

noise가 많은 데이터셋인 CIFAR100에서 CCSSL을 적용한 FixMatch에서 좋은 성능을 보였다고 말한다. 

real world dataset으로 실험해보았을 때 CCSSL을 추가한다면 더 좋은 성능을 나오는 것을 확인할 수 있다.

 

그리고 많은 역할을 하는 임계값 T값에 대해서는 0.6~0.9까지 비슷한 성능을 보임을 알 수 있다. 

 

각 loss term에 대해서 부여되는 가중치는 따로 언급되지 않는다.

 


예전 SSL방법에서는 고정된 데이터셋에 대해서 어떻게 성능을 올릴까? 가 문제였다면 최근 SSL논문을 보면 real world에 적용해서 성능을 올리는 것에 주력하는 것을 알 수 있다. 해당 논문도 조금 어려움 있었지만 재밌게 읽었다.