Hello Computer Vision

[pytorch] Nan, inf알아보기 본문

딥러닝/파이토치

[pytorch] Nan, inf알아보기

지웅쓰 2024. 2. 6. 23:24

아주 가끔 loss를 찍어보면 nan을 만날 수 있다. 내 경우는 loss 미분에 대해 learning rate를 낮추니 해결되었는데 이유에 대해서 한번 간단한 코드로 실험해보려고 한다.

 

a = torch.exp(torch.tensor(100))
b = torch.tensor(100)
print(a)
print(torch.log(a))
print(a / a)

softmax에 기본적으로 exp가 사용되니 적용해보았다.

tensor(inf)
tensor(inf)
tensor(nan)

결과는 이렇게 나오는데 즉 inf는 너무 높은 값이 나오면 컴퓨터에서 처리하지 못하는 것을 알 수 있다. 음의 방향으로 너무 크다면 -inf라는 결과가 나온다. 세번째에 나온 nan같은 경우 무한대 / 무한대 이런 값을 넣으니 나온 것이다.

a = torch.exp(torch.tensor(100))
b = torch.tensor(100)
print(1 / a)
print(a / 1)
print(1 / -a)
print(-a / 1)
tensor(0.)
tensor(inf)
tensor(-0.)
tensor(-inf)

이렇게 inf값으로 일반 실수와 연산을 0 혹은 inf값이 똑같이 나오고 inf 값끼리 연산을 했을 때 nan이 나오는 것을 알 수 있다. 

a = torch.exp(torch.tensor(100))
b = torch.exp(torch.tensor(200))
print(a+b)
print(a-b)
print(a*b)
print(a/b)

tensor(inf)
tensor(nan)
tensor(inf)
tensor(nan)

그러나 모든 연산에 대해서 nan을 하는 것은 아닌데 곱셈 연산과 덧셈연산에 대해서는 똑같이 inf를 유지하는 것을 알 수 있다.

 

즉 우리가 훈련하는 과정에서 loss값이 nan이 뜨는 이유 중 한가지는 너무 높은 값 혹은 너무 낮은 값에 대해 그 값과 추가적으로 나눗셈 한번 진행했기 때문에 뜨는 것이라 예상이된다(아마 softmax쪽에서 나눗셈 연산을 했을 거 같다, 뺄셈 연산을 제외한 이유는 훈련과정에서 inf - inf 가 거의 일어나지 않을 것이라고 생각하기 때문이다)

'딥러닝 > 파이토치' 카테고리의 다른 글

[pytorch] 행렬 계산  (0) 2024.02.17
[pytorch] 기본적인 tensor 조작해보기  (0) 2024.02.06
[pytorch] load state dict 후 업데이트  (0) 2024.01.22
CIFAR10-C 데이터셋 다루기  (0) 2024.01.20
torch stack, cat 차이  (0) 2024.01.17