Hello Computer Vision

Consistency Regularization 설명 본문

Self,Semi-supervised learning

Consistency Regularization 설명

지웅쓰 2024. 4. 28. 16:19

Consistency regularzation은 SSL 비롯하여 GAN에서도 사용된 기법으로 이해하기는 쉽지만 깊이 들어가보려고 한다. 해당 내용은 Kevin P. Murphy 의 Probabilistic Machine Learning의 내용을 바탕으로 제 개인적인 의견을 추가하여 작성했습니다. 

 

Consistency Regularzation 이 무엇을 수행하는지는 간단하다. 데이터 (여기서 말하는 데이터는 다 이미지 데이터)에 대해 일정한 노이즈 및 증강을 가해도 데이터의 본질 및 결과가 바뀌지 않는다는 것이다. 예를 들어 강아지 이미지에 대해 회전, 혹은 crop을 하더라도 강아지이다. 당연할 수 있으나 우리가 다루고 있는 모델은 이를 모를 수 있으니 이러한 부분에 대해 regularization 을 주는 것이 핵심이라고 할 수 있다. 이것을 수식으로 쓰면은 

$ L = || p_{\theta} (y | x) - p_{\theta} (y | x') ||^{2}  $ 

이렇게 작성할 수 있다. 수식 작성의 편리함을 위해 한개의 데이터에 대해서만 loss를 작성했다. 해당 loss는 단독으로는 사용할 수는 없고 주로 Semi Supervised Learning 쪽에서 supervised loss와 같이 사용되는데, 그 이유는 두 개의 output이 같기만 하다고 모델을 잘 학습할 수 없기 때문이다. 예를 들어 다음과 같은 [0.2, 0.8], [0.4, 0.6]  두 개의 output에 대해 우리는 결과적으로 모델이 정답인 label 확률 값이 1로 수렴하도록 하는 것이 목적이다. 그러나 위의 loss만 최적화 하게 된다면 model은 해당 데이터에 대해 loss를 최소화 되도록 [0.5, 0.5] 라는 output을 산출하는 것인데 이는 우리의 목적이 아니다. 

따라서 supervised loss와 혼합해서 사용하고, 추가적으로 loss를 조절하는 $ \lambda $ 를 넣는 것이 일반적이다. 

 

해당 방법에 대한 variation은 굉장히 많은데 그 이유는 어떠한 증강을 사용하냐에 따라 성능이 달라질 수 있기 때문이며 이를 domain specific 하다고 할 수 있다. 추가적으로 Mean teacher 논문에서는 모델 한개에 두개의 데이터를 넣는 것이 아닌 2개의 모델을 사용하는데, 1개는 gradient가 흐르지 않고 EMA 방식으로 gradient를 받는 방법을 택한다. 

 

 

그렇다면 Consistency Regularization을 수행하는데 KL divergence 혹은 Squared loss, 두 가지 모두 적용할 수 있는데 어떤 것을 사용해야할까?  

해당 그래프를 보면 x축은 노이즈 및 증강이 적용된 데이터의 logit(시그모이드 함수 들어가기 전의 함수 값)을 나타내고 y축은 노이즈 데이터가 정상 데이터인지 판별하는 binary loss를 뜻한다. 노이즈 데이터의 logit이 너무 커진다는 것은 모델이 해당 데이터가 정상 데이터라고 상당한 confidence를 가지고 있다고 볼 수 있고 logit 이 너무 낮아지는 것은 해당 데이터가 정상 데이터가 아니라는 confidence를 가지게 된다. 

이 두개의 데이터는 원래 같은 데이터인데, logit값이 왼쪽으로 간다는 것은 모델이 노이즈 데이터에 대해 잘 파악하지 못했다는 것인데 그렇다는 것은 모델이 이 둘을 다르다고 생각한다는 것이다. 이쪽의 gradient를 본다면 0이라고 볼 수 있는데,  모델이 같은 input이 다르다고 생각할 수록 업데이트를 수행하지 않는 것이며, 다르게 말하면 모델의 output이 unstable하다면 업데이트 하지 않는다는 것이다. 같은 데이터인데 다르게 예측하는 것 자체를 모델이 unstable하다고 정의하는 거 같다.

그러나 regularizer loss로 kl divergence를 선택할 경우 다르게 판단할 경우 loss값이 발산하는 것을 알 수 있다. 위에서 설명한 것처럼 모델의 output이 unstable할 때 더 많은 update 를 진행하는 것이라고 할 수 있다. 이 그래프를 통해서 어떤 것이 regularizer loss가 좋다고는 말할 수는 없는 것이 우선 우리가 주로 수행하는 것은 multi class예측이고 마지막에 softmax를 활용하는데 여기서는 binary 에 대해 sigmoid를 사용했기 때문에 단적으로 비교하기는 어렵다. 따라서 주어진 task에 대해 직접 실험을 해본 후 어떤 것이 더 좋은지 판별할 필요가 있다

 

여기까지가 원래 책의 설명이고 이후의 내용은 내 의견이다. 과연 Consistency regularization을 수행할 때 증강을 어떻게 적용해야할까? SSL 에서는 strong, weak 증강을 나누어 consistecy를 수행한다. 그러나 증강의 조합은 무수히 많으며 이것을 하나하나 수행해 어떤 조합이 좋은지는 사실 불가능하다. 실험결과 SSL에서 weak 만 사용할 경우 성능이 좋지 않았지만 Mean teacher에서의 결과를 보면 2개의 모델을 사용할 때는 엄청나게 많은 증강을 사용하지 않았으며 translation, horizontal flip만을 사용하였다. 그리고 내 실험결과에서 SSL에서 한 개의 모델을 사용했을 때의 Train loss는 weak만 사용했을 때가 더 낮았다(물론 Test accuracy는 낮다). 그렇다는 것은 별다른 증강이 적용되지 않았을 때 한 개의 모델만 사용했을 때는 같은 이미지에 대해서 수행할 때도 있었다는 것이며 이는 update가 이루어지지 않으니 unlabeled 에 대한 정보를 활용하지 못하는 것이라고 할 수 있다. 즉 weak만 사용할 거면 한개의 이미지에는 적용을 하지 말든지, 2개의 모델을 사용하든지 결정을 할 필요가 있다는 것이다.

 


Reference

Kevin P. Murphy , Probabilistic Machine Learning

Mean teacher, https://arxiv.org/pdf/1703.01780