일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- remixmatch paper
- semi supervised learnin 가정
- Pix2Pix
- Meta Pseudo Labels
- adamatch paper
- conjugate pseudo label paper
- mme paper
- 딥러닝손실함수
- 컴퓨터구조
- simclrv2
- UnderstandingDeepLearning
- CGAN
- shrinkmatch paper
- CoMatch
- mocov3
- shrinkmatch
- CycleGAN
- Entropy Minimization
- SSL
- ConMatch
- GAN
- BYOL
- WGAN
- dcgan
- Pseudo Label
- 백준 알고리즘
- cifar100-c
- tent paper
- dann paper
- 최린컴퓨터구조
- Today
- Total
Hello Computer Vision
KL divergence, Cross Entropy 본문
GAN을 공부하다가 VAE개념을 몰라서 공부하는데 계속 나온 식이 바로
KL divergence! 그래서 정리도 할겸 이해한 내용 그대로 끄적여 볼라고 한다.
들어가기 전에 사전 지식이 있어야한다.
모든 정보에는 양이 있고 그 양은 모두 같은 값을 가지지는 않는다.
예를 들어 '해가 동쪽에서 뜬다' 라는 정보는 너무 흔하기 때문에 정보의 양을 거의 가지지 않고 '해가 서쪽에서 뜬다'라는 정보는 흔하지 않기 때문에 많은 정보의 양을 가진다.
표기를 해보자면 x라는 정보가 있고 y라는 정보가 있다.
이러한 정보의 확률 값을 p(x), p(y)라고 할 때
이에 대한 정보의 양을 h라 하였을 때
h(x) = f(p(x)), h(y) = f(p(y))
이렇게 임의의 함수 f를 이용하여 나타낼 수 있다.
그리고 다음과 같은 식들을 만족한다고 가정한다.
h(x, y) = h(x) + h(y)
p(x, y) = p(x) * p(y )
정보가(h) 겹친다고 해도 소실되거나 곱이 되지 않는다고 가정.
정보의 사건이 서로 영향이 없다고 가정.
따라서 이러한 가정으로 인해
h(x, y) = f(p(x, y)) = f(p(x) * p(y)) = f(p(x) + f(p(y))
(= h(x) + h(y))
이런 식을 도출할 수 있다.
f(p(x) * p(y)) = f(p(x) + f(p(y))
이 부분을 보면 가장 쉽게 떠오르는 식은 로그 함수이다.
log(x * y) = logx + logy 이기 때문에 함수 f는 로그함수이다.
그래서 나온 식이 h(x) = - log(p(x)) 이다.
(제일 중요한 부분, 기초 라고 생각!)
여기서 -가 붙었는데 붙은 이유는
로그 함수에 -를 붙여야 확률값이 높을수록 정보의 양이 적고
확률값이 낮을수록 정보의 양이 높은 그래프가 완성이된다.
이를 단조감소함수라고 한다.
예시)
x = 해가 동쪽에서 뜬다. y = 해가 서쪽에서 뜬다. 라고 가정하였을 때
p(x) = 0.9999 이고 p(y) = 0.0001이라고 해보자.
h(x,y) = -log(0.9999) - log(0.0001)
= 9.21044
이라는 값이 나온다.
정보의 평균 양을 구해보자면
E(h(x,y)) = p(x) * h(x) + p(y) * h(y)
=0.9999 * 0.00001 + 0.0001 * 9.21
= 0.0009
값을얻게되고
이를 일반적인 식으로 전개하면
이렇게 Entropy 식이 전개되어 나타난다!!
(h(x)의 평균을 H(x)라고 표현하였다)
(추가로 로그 밑에 2일 경우 정보의 단위가 bit가 되고 10이면은 nat(?) 가 된다)
이제 한번 KL divergence 로 들어가보자.
이를 짧게 요약해보면
원래 있던 분포와 내가 생각한(가정한)분포의 차이
라고 생각하면 된다.
위의 식은 내가 생각한 분포가 다 맞다고 생각하였다.
그러나 일반적으로 내가 생각한 분포는 실제와 같지 않기 때문에
다음과 같은 식으로 나타낼 수 있는 것이다.
p(x)는 기존의 분포이고 q(x)는 내가 생각한 분포이다.
예시)
면이 4개인 주사위를 생각해보자.
각각의 면이 나올 확률분포는 1/4로 동일하다.(원래 있던 분포)
그렇지만 나는 이 주사위의 정보를 잘 모르는 상태에서
1/2, 1/4, 1/8, 1/8로 생각하고 있다.(내가 생각한 분포)
이러한 정보를 누군가에게 보내기 위해서는 코딩이 필요한데
이를 0, 11, 111, 110으로 코딩하여 정보를 보낼라고 한다.
실제 최적의 코딩은 00, 01, 10, 11이다.
(코드의 길이가 길수록 정보를 보낼 때 cost는 증가한다)
이를 평균 정보의 양인 H(x)로 표현해보자면
내가 생각한 주사위 분포의 정보의 평균 양
= - {(1/4 * log 0.5) + (1/4 * log 0.25) + (1/4 * log 0.125) + (1/4 * log 0.125)}
= 2.25
원래의 주사위 분포의 정보 평균 양
4 *(- log 1/4) = 2
여기서 내가 착각한 분포와 기존의 분포와의 차이가 0.25인데 이를
cost라고 하고 이를 최소화 하고 기존의 분포를 따라갈 수 있도록 하는 것이 목적인 것이다.
이렇게 식을 세울 수 있고 이는
이렇게 정리할 수 있다.
어떤 식에서는 시그마 앞에 -가 붙고 로그 안에 있는 식이 거꾸로 되어있는 것도 봤는데 어차피 결과는 똑같다.
만약 x가 연속형 확률변수라면 시그마가 아닌
인티그럴이 붙어 적분을 하게 된다.
우리가 딥러닝을 할 때 손실함수를
Cross Entropy로 정의하는 경우가 많다.
이렇게 본다면 KL divergence 도 정의할 수 있도록 해야 하는거 아닌가?
라는 생각이 들지만
KL divergence식을 분해한다면 이렇게 기존의 분포와 Entropy 식으로 나누어지게 된다. 이에 대해 최적의 손실 값을 찾기 위해서는 미분이 필요한데 어차피 H(p)는 상수이고 미분을 하게 되면 사라지기 때문에 그냥 Cross Entropy를 사용하는 것이라고 할 수 있다.
2023.4.30 수정)
보니까 nn.KLDivLoss라는 것이 따로 있었다. https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
위에서 H(p)를 상수라고 생각했는데(관찰한 값들에 대한 정보량) 잘못 이해했다..
내용 출처 : https://www.youtube.com/watch?v=Dc0PQlNQhGY
https://theeluwin.postype.com/post/6080524
요약
KL divergence 는 기존의 분포와 내가 생각한 분포의 차이를 최소화 하기 위해 나타내는 식이다.
ps. 앞으로 로그에 마이너스가 붙은 식을 보더라도 무서워하지 말자.
'mathematics' 카테고리의 다른 글
Precision, Recall 구분 및 공부 (0) | 2023.07.21 |
---|---|
베이즈 정리(Bayes' Theorem) 이해하기 (0) | 2023.01.04 |
몬테카를로 시뮬레이션이란?(Monte Carlo Simulation) (0) | 2022.11.05 |
로지스틱 회귀(Logistics Regression)와 선형 회귀(Linear Regression) (1) | 2022.11.01 |
Maximum Likelihood Estimation(최대우도법) 요약 (0) | 2022.10.25 |