Hello Computer Vision

Semi Supervised Learning에서 Pseudo label의 정확성 본문

Self,Semi-supervised learning

Semi Supervised Learning에서 Pseudo label의 정확성

지웅쓰 2023. 12. 29. 19:38

여러 Semi Supervised Learning(SSL) 논문을 읽으면서 느꼈던 점은 

"Softmax를 사용하는데 잘못된 클래스로 오분류하고 이 값이 threshold값을 넘으면 계속 오류가 나고 성능이 안좋아지지 않을까? " 라는 생각을 했고 이를 실험해봐야겠다는 생각을 했다.

 

기본 세팅은 FixMatch에 있는 세팅을 따랐고, Unlabeled 데이터에 대한 argmax softmax 값이 Threshold에 넘든 말든 얼마나 많이 오분류를 하는지 한번 살펴보았다.

acc_list = []
#correct_list = []

for i in range(epochs):
  start = time.time()
  labeled_iter = iter(labeled_trainloader)
  unlabeled_iter = iter(unlabeled_trainloader)

  model.train()

기본 Epoch 값은 1024이다. 그러나 시간이 오래걸려 1epoch동안의batch에서 정확도가 얼마나 되는지 측정했다.

  for batch_idx in range(eval_step):
    try:
      inputs_x, targets_x = next(labeled_iter)
    except:
      labeled_iter = iter(labeled_trainloader)
      inputs_x, targets_x = next(labeled_iter)

    try:
      (inputs_u_w, inputs_u_s), output_u = next(unlabeled_iter)
    except:
      unlabeled_iter = iter(unlabeled_trainloader)
      (inputs_u_w, inputs_u_s), output_u = next(unlabeled_iter)

여기서 eval step은 1024이다. 즉 1epoch 동안 1024개의 배치를 돌며 여러번 데이터셋을 훈련한다(기존의 supervised 세팅과 많이 다르다). 원래는 unlabeled dataloader에 대하여 y를 굳이 정의할 필요 없지만 여기서는 실험을 위해서 정의해줬다.

 

    batch_size = inputs_x.shape[0]
    #print(inputs_x.shape)
    #print(inputs_u_w.shape)
    #print(inputs_u_s.shape)
    inputs = interleave(
        torch.cat((inputs_x, inputs_u_w, inputs_u_s)), 2*7+1).to(device) #(batch * mu, 3, 32, 32)
    targets_x = targets_x.to(device)
    output_u = output_u.to(device)
    logits = model(inputs)
    logits = de_interleave(logits, 2*7+1)
    logits_x = logits[:batch_size]
    logits_u_w, logits_u_s = logits[batch_size:].chunk(2)

차원을 이리저리 만져서 모델에 한번에 넣어주고 이를 다시 분리하는 과정을 거친다.

    Lx = F.cross_entropy(logits_x, targets_x, reduction = 'mean')

    pseudo_label = torch.softmax(logits_u_w.detach() / T, dim = -1)
    max_probs, targets_u = torch.max(pseudo_label, dim = -1)
    mask = max_probs.ge(threshold).float()

    Lu = (F.cross_entropy(logits_u_s, targets_u, reduction = 'none')*mask).mean()

    loss = Lx + Lu

    loss.backward()
    optimizer.step()
    scheduler.step()

이제 weak augmentation을 적용한 unlabel데이터에 대하여 argmax값을 뽑는다. optimizer는 SGD를 사용한다.

 

    model.zero_grad()

    targets_u = targets_u.reshape(1, -1)
    correct = targets_u.eq(output_u.reshape(1, -1).expand_as(targets_u)).float().sum()
    ratio = correct /output_u.size(0)
    print("Epoch : {}, Batch : {}, Loss : {:.2f}, Accuracy : {} / {}, ratio : {}".format(i+1, batch_idx+1, loss.item(), correct, output_u.size(0), ratio))
    acc_list.append(ratio)
    correct_list.append(correct)
    #test_model = ema_model.ema

1개의 배치를 돌 때마다 448개의 배치에서 얼마나 맞혔는지 한번 살펴본다. 다시 한번 말하지만 이 개수는 threshold를 넘은 것이 아니라 그냥 오분류를 얼마나하는지에 대한 비율이다.

 

result = [acc.cpu().numpy() for acc in acc_list]
plt.plot(result)

결과는 위와 같다. 1 epoch가 돌 동안 unlabeled 데이터에 대해 40%정도의 정확도를 가지고 있다. 내 예상은 아주 낮은 비율을 기록할 줄 알았는데 완전히 빗나갔으며 1epoch동안에 여러번 데이터셋을 경험하면서 unlabeled 데이터에 대해서도 꽤 잘 맞추는 모습이다. 물론 여기서 epoch를 진행할 수록 정확도가 더 떨어질 수 있다는 가능성도 있으며 이는 더 실험해봐야겠지만 지금 실험으로는 1 epoch만으로도 모델이 unlabel 데이터에 대해서도 잘 예측하는 모습을 보여준다.

 

Reference

https://github.com/kekmodel/FixMatch-pytorch