Hello Computer Vision

비전공생의 GradNorm(2018) 논문리뷰 본문

딥러닝

비전공생의 GradNorm(2018) 논문리뷰

지웅쓰 2023. 12. 14. 01:13

해당 논문은 2018 ICML 에서 억셉되었고 가중치에 관련한 논문이다. 풀 제목은 GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks 이다.

https://arxiv.org/pdf/1711.02257.pdf

 

 

Introduction

많은 딥러닝 모델들이 높은 성능을 보이고 있지만 vision에 관련하여 full understanding 하기 위해서는 multit task들을 동시에, 효율적으로 할 수 있어야한다고 언급한다(단순히 한가지만 잘하는 것이 아니라 여러개를 동시에 잘해야 좋은 것이라는 가정이 깔려있는 거 같다). 그러나 multitask network를 훈련시키기는 어려운데 그 이유는 각기 다른 task들에 대해 적절히 balancing 하는 것이 어렵기 때문이다. 따라서 해당 논문에서는 GradNorm이라는 것을 소개하여 이러한 문제점을 해결하려고 한다. 해당 논문의 key  contribution은 다음과 같다.

1. An efficient algorithm for multitask loss balancing which directly tunes gradient magnitudes.

2. A method which matches or surpasses the performance of very expensive exhaustive grid search procedures, but which only requires tuning a single hyperparameter.

3. A demonstration that direct gradient interaction provides a powerful way of controlling multitask learning

 

The GradNorm Algorithm

우선 아래 figure를 보면 논문에서 어떤 것을 말하고 싶은지 알 수 있다.

왼쪽같은 경우 Loss에 대한 gradient 를 고려하지 앟ㄴ고 고정적으로 주게 된다면 어느 한 task를 network가 잘 푸는데도 불구하고 다른 task에 비해 더 많은 가중치를 줄 수 있고, 동일하게 줄 수 있다. 오른쪽을 보면은 이를 고려하여 푸는 난이도에 따라 gradient가 동일하게 흐름을 알 수 있다.

 

Multitask loss는 다음과 같이 이루어져있다. $L(t) = \sum{w_{i}(t)} L_{i}(t)$. 해당 논문에서의 표기들에 대한 설명이다. 이를 완전히 숙지하는 것이 중요하다.

하나부터 천천히 살펴보면, W는 cost에 대한 weight이다.  $G_{W}^{i}(t)$ 는 t시점의 i task에 대한  L2 norm이다. $G{W}(t)$ 는 t시점의 모든 loss에 대한 평균이다. $L_{i}(t)$ 는 t시점의 i task에 대한 loss ratio이다. 보통 0시점보다 t시점으로 갈 수록 loss는 낮아지므로 값이 높을수록 훈련이 잘 안되는 것을 의미하며 높은 가중치를 줄 필요있다. $r_{i}(t)$ 는 t시점에서의 전체 loss의 평균에 대한 i task의 비율을 나타내는데 여기서 L은 0시점에 대한 t시점의 비율을 나타내므로 나는 현재 t시점에서 다른 task에 비해 잘 풀어내고 있는가? 정도로 해석했다. 상대적으로 i task가 어느정도 비중을 차지하는지 알 수 있다.

 

결과적으로 t시점의 i task에 대한 loss를 보면은 다음과 같이 나타낼 수 있다.

전체 loss가 합으로 둘러쌓여있으니 이 식만 잘 들여다보면 된다. 이를 해석해보면

i task  gradient크기 - (전체gradient 의 평균 * i task gradient의 task진행상황정도?))

 

Result

흥미로운 논문인 거 같다. MTL을 한다면 시도해볼만한 거 같다.