Hello Computer Vision

비전공생의 Distilating the Knowledge in a Neural Network(2015) 리뷰 본문

Self,Semi-supervised learning

비전공생의 Distilating the Knowledge in a Neural Network(2015) 리뷰

지웅쓰 2023. 4. 30. 16:15

https://arxiv.org/abs/1503.02531

 

Distilling the Knowledge in a Neural Network

A very simple way to improve the performance of almost any machine learning algorithm is to train many different models on the same data and then to average their predictions. Unfortunately, making predictions using a whole ensemble of models is cumbersome

arxiv.org

 

Knowledge distilation(이하 KD)을 공부하게 되면서 태초의 논문을 볼 필요가 있다고 생각해 읽어보았고 한번 정리해보려고 한다. 

Introduction

 

많은 생물(논문에서는 곤충)은 다양한 환경에 맞추어 발달하게 되는데 우리는 많은 분야,에 대해서 비슷한 모델과 비슷한 훈련과정을 거치고 있다고 한다. 또한 성능을 내는 가장 간단하면서 좋은 방법은 모델들을 앙상블 하는 것이라 말하지만 이것은 굉장히 많은 연산량과 이것을 사람들에게 배포할 때는 어려움이 있다고 말한다. 

 

따라서 무거운 모델들(cumbersome model)을 작은 모델에 지식을 나눠주는 distilation을 제시한다. 이전 논문에서 이러한 방법이 효과가 있었다고 말한다. 해당논문  

직관적으로 생각해보았을 때 어떻게 큰 모델이 작은 모델한테 지식을 나누어주지라고 생각할 수 있다. 가중치를 주는 것인가?라고 직관적으로 생각할 수 있지만 모델의 구조가 다르다면 가중치를 같이 주는 것으로 모델이 향상될 수 있을지는 미지수이다. 여기서는 지식을 나누어 준다는 것을 mapping을 가르쳐준다는 것으로 표현한다(learn mapping from input vectors to output vectors). 예를 들어 크고 복잡한 모델이 분류 문제를 풀 때 BMW에 대한 확률값을 내뱉을 때 BMW를 제외한 확률들은 굉장히 낮을 것이다. 그러나 꼭 정답인 BMW뿐만 아니라 truck에 대한 확률 정보들도 도움이 될 수 있다고 말하며 'soft target' 이라는 것을 사용한다고 한다(hard target의 반댓말). 이러한 soft target 의 entropy가 높을 수록 많은 정보를 받을 수 있고 도움이 된다고 한다. 그런데 보통 크고 복잡한 모델일수록 softmax 를 적용하기 전의 logit값은 정답값에 대해 높을 것이거 정답아닌 logit값에 대해는 낮을텐데 따라서 이를 보완하기 위해 Temperature(T) 값을 도입해 이러한 값을 완만하게 한다고 한다. 

이를 시각화하면 위 이미지와 같다. temperature값이 높을 수록 그래프가 더 완만해지며 1이면은 기존과 같고, 낮을수록 한가지 값에 대한 확실성을 더 높이 판단한다.

 

KD vs Transfer learning

우선 나는 KD와 transfer learning에 대한 개념이 헷갈렸었는데 다시 한번 정의하려고 한다.

Transfer learning 은 일단 기존에 훈련 시킨 큰 모델이 있다면 이를 다른 모델에도 학습 가중치를 적용하고, 추가적인 부분만 바꾸는 것이라고 생각할 수 있다. 예를 들어 100개의 클래스 분류 데이터셋으로 훈련을 진행하고 내가 진행할 문제는 개, 고양이 분류 문제라고 했을 때 마지막 분류 fc layer만 추가로 변경해서 훈련할 수 있다. 모델의 크기는 변하지 않는다.

 

그러나 KD 의 경우 크고 복잡한 모델을 압축한 것이라고 볼 수 있다. 다양한 방법들을 활용하여 student model이 teacher model의 값을 추론할 수 있도록 하며 이는 모델의 경량화를 한 것이라고 볼 수 있다. 

 

Distilation

이를 수식으로 나타낸다면 다음과 같다. q는 확률, z는 logit, T 는 temperature를 나타낸다. 

전체적인 구조를 나타내면 위 이미지와 같다. 단순히 soft target을 사용하는 것이 아니라 soft target, hard target을 사용하고 이 loss들을 합한다고 한다. 논문에 나온 결과로는 hard target에 대한 loss 비중을 적게 하면 결과가 더 좋았다고 한다.

 

논문에서는 각 모델에 대해 최적화하는 수식이 있긴한데 정확하게  이해를 하지 못했으므로 패스하려고 한다. 그러나 최종적으로 우리가 최소화해야하는 것은 1/2(z - v) ^ 2값이다. z는 student model의  logit값이고, v는 teacher model의 logit값이다. 추가적인 가정은 각 logit값의 평균이 0이다. T값이 낮다면은 model은 많은 다른 값들은 무시할 것인데 이는 데이터에 노이즈값이 많을 때 유용할 수 있다. T값이 높다면 다른 정보들 또한 중요하게 생각한다는 것인데 이는 경험적으로 고려해야하고 선택해야한다고 말한다.

 

이에 대한 실험으로  MNIST데이터셋(60,000)으로 훈련을 해보았을 때 teacher model이 67개의 error를 냈을 때 student model은 146개의 error를 냈다고 한다. 심지어 데이터셋에서 3을 제외하고 훈련했을 때도 206개의 error밖에 나지 않았다고 한다. 추가적인 setting으로는 regularized, bias setting 이 중요하다고 한다.

 


이 논문에서는 추가로 audio data와 ensemble specialist 에 대해서도 이야기를 나누는데 제 목적은 KD였기 때문에 여기까지 기록할 예정! ensemble specialist 에 대해 간단하게 말하면 1개의 base model과 여러개의 model(몇개의 헷갈리는 class에 대해 특화된)을 앙상블 하는 기법이라고 한다.