Hello Computer Vision

F.cross entropy 구현 및 설명해보기 본문

딥러닝/파이토치

F.cross entropy 구현 및 설명해보기

지웅쓰 2023. 5. 12. 19:09

cross entropy이 놈은 공부해도 공부해도 계속 알쏭달쏭하다. 그래서 지금 나름 정리했다고 생각했는데 또 나중에 궁금해 할 수도 있긴한데 일단 지금 기준으로 내가 정리한 걸 적어보려고 한다.

y = torch.FloatTensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 0])
x = torch.Tensor([0.1, 0.03, 0.05, 0.2, 0.9, 0.0, 0.1, 0.2, 0.12, 0.03])
def cross_entropy_loss(x, y):
    delta = 1e-7
    return -torch.sum(y*torch.log(x+delta))
def softmax(a):
    #c = torch.max(a)
    exp_a = torch.exp(a)
    sum_exp_a = torch.sum(exp_a)
    y = exp_a / sum_exp_a

    return y

기본적으로 F.cross_entropy 에서는 softmax가 취해서 나오므로 비슷하게 구현하기 위해 softmax까지 구현했다.

class Loss:
    def __init__(self):
        self.loss = None
        self.x = None
        self.y = None

    def forward(self, x, y):
        self.y = y #정답 레이블
        self.x = softmax(x)
 #logit 값을 softmax 변환
        self.loss = cross_entropy_loss(self.x, self.y)
        return self.loss
loss = Loss()

print(loss.forward(x, y))

print(F.cross_entropy(x, y))
tensor(1.6139)
tensor(1.6139)

그리고 구현한 cross entropy와 F.cross entropy를 비교했을 때 정확히 값이 똑같은 것을 확인할 수 있다.

내가 사실 이렇게 구현까지 한 이유는 과연 F.cross_entropy는 정답값이 아닌 다른 값에대해서도 손실값을 취할까인데 여기서는 아니다이다. 다른 분들 글을 보면은 이게 정확히 안나와있어서 계속 헷갈렸는데 위 코드에서는 아니다. 그리고 또 한가지 F.cross entropy를 사용할 때 원핫인코딩으로 주지 않고 라벨 값으로 주어도 원핫인코딩을 한 결과와 같은 loss를 준다.

만약 target 값이 one-hot 이 아니라 예를 들어 label smoothing 을 줄 경우 그만큼의 추가 loss가 발생한다.

y1 = torch.FloatTensor([0, 0, 0, 0.05, 0.95, 0, 0, 0, 0, 0])
x1= torch.Tensor([0.1, 0.03, 0.05, 0.2, 0.9, 0.0, 0.1, 0.2, 0.12, 0.03])

y2 = torch.FloatTensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 0])
x2 = torch.Tensor([0.1, 0.03, 0.05, 0.2, 0.9, 0.0, 0.1, 0.2, 0.12, 0.03])

print(F.cross_entropy(x1, y1))  # tensor(1.6489)
print(F.cross_entropy(x2, y2))  # tensor(1.6139)

근 1,2일간 앓았던 문제 어느정도 해결이다. 시간이 지나면 또 헷갈리겠지만...