일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- simclrv2
- GAN
- ConMatch
- Entropy Minimization
- CycleGAN
- remixmatch paper
- shrinkmatch paper
- CGAN
- mocov3
- dcgan
- 컴퓨터구조
- WGAN
- Pseudo Label
- adamatch paper
- UnderstandingDeepLearning
- shrinkmatch
- semi supervised learnin 가정
- Meta Pseudo Labels
- tent paper
- 딥러닝손실함수
- mme paper
- BYOL
- conjugate pseudo label paper
- 최린컴퓨터구조
- CoMatch
- dann paper
- Pix2Pix
- SSL
- 백준 알고리즘
- cifar100-c
- Today
- Total
Hello Computer Vision
DANN(2016) 논문리뷰 본문
논문의 풀 제목은 Domain-Adversarial Training of Neural Networks이다. DANN이라고도 불리며 Semi supervised Domain Adaptation에서 시초격인 거 같아서 리뷰해보려고 한다.
https://arxiv.org/pdf/1505.07818.pdf
Introduction
딥러닝에서 겪을 수 있는 흔한 문제는 train, test 데이터 셋에서 흔하게 일어날 수 있는 shift문제이다. 따라서 이러한 shfited distribution에서의 classifier가 잘 분류할 수 있도록 훈련하는 것이 Domain Adaptation(DA) 라고 할 수 있다. 여기서 target distribution은 모두 unlabeled 데이터이다.
기존의 방법들은 위의 목적을 수행하기 위해 fixed representation을 사용한다고 하는데 한번만 언급되고 이후에 언급되지는 않지만 한번 해석해보면, source distribution에 대해 fixed representation을 가지고 있다. 따라서 domain shift에 대해 잘 적응하지 못한다. 라고만 이해했다. 따라서 이 논문의 목적은 이러한 것보다는 domian invariant하게 잘 구분하는 것을 목적으로 한다. -->모델의 알고리즘이 domain shift를 눈치채지 못하고 잘 분류한다.
이를 수행하기 위해 저자는 domain invariance하게 구분을 잘하는 것이 목적이다. 이를 위해 2개의 classifier를 가지는데, label predictor와 domain classifier를 가진다. 여기서 둘의 역할이 다른데, label predictor의 loss는 줄이면서 domain classifier에 대한 loss는 최대로 하는 것이 목적이다. 이러한 동작 방식을 adversarial 이라고 한다(이 부분에 대해서는 사실 직관적으로 이해하면 굉장히 쉽다. 그러나 코드로 살펴보면 약간 헷갈리는 점이 있는데 이 부분은 나중에 설명해보려고 한다)
Related work
기존 연구에서는 source domain 에 있는 데이터들에 대해서 selecting 을 하거나 reweighting을 하는 과정을 거쳐 source distribution을 target distribution으로 맞추는 과정을 거쳤다고 한다. 이를 한번 풀어서 말해보면 source domain 중에서도 target domain과 비슷하거나 혹은 아주 먼 데이터가 있을 수 있는데 이러한 데이터들을 measure해 reweighting 및 selecting을 했다고 이해하였다.
이 논문에서도 space distribution을 맞추려고는 노력하지만 아예 다른 방식으로 modifying feature representation이라고 말한다. 위에서 말한 fixed representation과 상반된 표현이다.
Domain Adaptation
여기서 표기를 정리하는데 $D_{S}, D_{T}$ 가 있으며 target distribution에는 label이 없다. 다시 한번 언급하자면 여기서의 목적은 unlabeled target데이터들은 기존 source distribution과 다르고 test 과정에서 이를 잘 맞추기 위해 잘 학습하자는 것이 목표이다. 그리고 target data에 대한 expected risk는 다음과 같이 나타낼 수 있다.
$$ R_{D_{T}}(\eta) = Pr_{(x, y) ~ D_{T}} \left( \eta (x) \ne y \right)$$
여기서 $ \eta $ 는 x의 클래스라고 이해해도 무방할 거 같으며 물론 target dataset에는 y가 없지만 loss를 나타내기 위해 명시한 모습이다.
이전 논문들에서도 그렇고 이 논문에서도 Domain Adaptation 을 효과적으로 대응하기 위해 target error(target distribution에 대한 error)를 source error + distance between distribution으로 정의한다. 만약 모델이 distribution shfit에 대한 것을 잘 인식하지 못한다면 target error = source error와 같다고 할 수 있다. source error는 labeled 데이터들을 활용해 줄일 수 있으니 집중해야 할 부분은 distribution에 대한 거리인 것인데 이를 H-divergence를 활용한 최적화 알고리즘을 제안한다.
여기서 d는 distance를 가리키며 H는 hypothesis class를 가르킨다. 직관적으로 이해하면 그냥 클래스의 집합이라고 이해할 수도 있지만 이 공간은 너무 크면 안된다고 한다. 뒤에도 나오겠지만 이를 최적화하기 위해서 이 공간이 크면은 최적화하기 어렵다고 한다. 이렇게 설명하니 앞에서 말한 클래스의 집합이라고 하면 조금 이상하긴하다. 그냥 차원이라고 이해하고 $ \eta $ 는 각각의 distribution에 나온 x가 같은 공간에 위치할 확률이라고 보면 더 쉬운 거 같다. 위의 divergence식은 아래와 같이 다시 쓸 수 있다고 한다.
같을 확률 = 1 - 틀릴확률 이므로 이렇게 다시 쓴 듯하다. 이러한 식은 최적화 시키기 일반적으로 어렵지만 근사할 수 있다고하는데,
source data를 0으로 분류하고 target data를 1로 분류하는 새로운 n+N개의 데이터셋을 제시한다(한마디로 그냥 classifier로 훈련하는 것이 아니라 domain classifier를 추가로 붙이는 것에 대한 근거이다). 따라서 이를 근사해보면
$$ \hat{d}_{A} = 2 (1 - 2 \epsilon )$$
로 나타내고, A는 proxy A distance를 나타내는데 H-divergence를 근사했기 때문에 이렇게 나타낸 듯 싶고 여기서의 A역시 차원이 너무 크면 좋지 않을 거 같다.
따라서 위에서 정의한 target risk의 상한을 정의해보면
기존의 식에서 차원수에 대한 제약식과 $ \beta $가 추가로 들어간 것을 확인할 수 있는데 $ \beta $가 왜 추가로 들어갔는지 정확하게 모르겠지만 최적으로 분류했을 때 하한값을 가진다 라고 표기되어있긴한데 이 부분은 직관적으로 잘 이해가 되지 않는다. 이 부분에서 Rs 와 divergence간의 trade off가 생긴다고 한다. 그리고 DANN 의 알고리즘은 위와 같은 수식을 활용했다고 한다.
사실 뒤에 내용은 더 많은데 이정도만 정리하고 코드를 살펴보려고 한다. 공식코드는 아니지만 굉장히 star를 많이 받은 코드라서 가져왔다.
input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
class_label = torch.LongTensor(batch_size)
domain_label = torch.zeros(batch_size)
domain_label = domain_label.long()
class_output, domain_output = my_net(input_data=input_img, alpha=alpha)
err_s_label = loss_class(class_output, class_label)
err_s_domain = loss_domain(domain_output, domain_label)
먼저 source 이미지에 대해서 예측을 수행하고 해당 이미지에 대해 domain loss도 발생시키는 것을 확인할 수 있다. source domain label은 0으로 준다.
data_target = data_target_iter.next()
t_img, _ = data_target
batch_size = len(t_img)
input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
domain_label = torch.ones(batch_size)
domain_label = domain_label.long()
if cuda:
t_img = t_img.cuda()
input_img = input_img.cuda()
domain_label = domain_label.cuda()
input_img.resize_as_(t_img).copy_(t_img)
_, domain_output = my_net(input_data=input_img, alpha=alpha)
err_t_domain = loss_domain(domain_output, domain_label)
err = err_t_domain + err_s_domain + err_s_label
err.backward()
optimizer.step()
target 이미지에 대해서는 분류학습을 할 수 없으므로 domain loss만 발생시키는 모습이다. 그리고 최종적으로 이 loss를 최적화하는 것이 목적이다.
class ReverseLayerF(Function):
@staticmethod
def forward(ctx, x, alpha):
ctx.alpha = alpha
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
output = grad_output.neg() * ctx.alpha
return output, None
domain loss에 적용되는 gradient reversal layer이다.
def forward(self, input_data, alpha):
input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28)
feature = self.feature(input_data)
feature = feature.view(-1, 50 * 4 * 4)
reverse_feature = ReverseLayerF.apply(feature, alpha)
class_output = self.class_classifier(feature)
domain_output = self.domain_classifier(reverse_feature)
return class_output, domain_output
gradient의 방향이 반대가 된 것을 확인할 수 있는데 여기서 최적화 알고리즘은 모두 nll을 사용하므로 +가 되는 것이다. 그렇다면 loss를 작게하기 위해 최적화하면 할 수록 domain loss값에 대한 gradient는 반대로 이동하는 것이다. 그렇다면 여기서 domain loss가 커지는 방향으로 훈련되는 것이 어떤 것을 의미할까? source data는 1로 예측하고 target domain은 0으로 예측하는 것인데 그렇다면 결국 모델은 두 개의 data의 분포를 잘 구분하는 것 아닌가라는 생각을 하였다.
코드출처: https://github.com/fungtion/DANN/tree/master
설명참고: https://jaejunyoo.blogspot.com/2017/01/domain-adversarial-training-of-neural-2.html
'Domain Adaptation' 카테고리의 다른 글
APE(2020) 논문리뷰 (1) | 2024.03.14 |
---|---|
MME(2019) 논문리뷰 (0) | 2024.03.12 |