일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- mocov3
- dann paper
- 컴퓨터구조
- 딥러닝손실함수
- Entropy Minimization
- CoMatch
- Pseudo Label
- simclrv2
- dcgan
- semi supervised learnin 가정
- CGAN
- shrinkmatch
- WGAN
- SSL
- cifar100-c
- 백준 알고리즘
- conjugate pseudo label paper
- GAN
- shrinkmatch paper
- BYOL
- mme paper
- Meta Pseudo Labels
- Pix2Pix
- 최린컴퓨터구조
- tent paper
- CycleGAN
- ConMatch
- remixmatch paper
- adamatch paper
- UnderstandingDeepLearning
- Today
- Total
Hello Computer Vision
Semi Supervised Learning에서 Pseudo label의 정확성 본문
여러 Semi Supervised Learning(SSL) 논문을 읽으면서 느꼈던 점은
"Softmax를 사용하는데 잘못된 클래스로 오분류하고 이 값이 threshold값을 넘으면 계속 오류가 나고 성능이 안좋아지지 않을까? " 라는 생각을 했고 이를 실험해봐야겠다는 생각을 했다.
기본 세팅은 FixMatch에 있는 세팅을 따랐고, Unlabeled 데이터에 대한 argmax softmax 값이 Threshold에 넘든 말든 얼마나 많이 오분류를 하는지 한번 살펴보았다.
acc_list = []
#correct_list = []
for i in range(epochs):
start = time.time()
labeled_iter = iter(labeled_trainloader)
unlabeled_iter = iter(unlabeled_trainloader)
model.train()
기본 Epoch 값은 1024이다. 그러나 시간이 오래걸려 1epoch동안의batch에서 정확도가 얼마나 되는지 측정했다.
for batch_idx in range(eval_step):
try:
inputs_x, targets_x = next(labeled_iter)
except:
labeled_iter = iter(labeled_trainloader)
inputs_x, targets_x = next(labeled_iter)
try:
(inputs_u_w, inputs_u_s), output_u = next(unlabeled_iter)
except:
unlabeled_iter = iter(unlabeled_trainloader)
(inputs_u_w, inputs_u_s), output_u = next(unlabeled_iter)
여기서 eval step은 1024이다. 즉 1epoch 동안 1024개의 배치를 돌며 여러번 데이터셋을 훈련한다(기존의 supervised 세팅과 많이 다르다). 원래는 unlabeled dataloader에 대하여 y를 굳이 정의할 필요 없지만 여기서는 실험을 위해서 정의해줬다.
batch_size = inputs_x.shape[0]
#print(inputs_x.shape)
#print(inputs_u_w.shape)
#print(inputs_u_s.shape)
inputs = interleave(
torch.cat((inputs_x, inputs_u_w, inputs_u_s)), 2*7+1).to(device) #(batch * mu, 3, 32, 32)
targets_x = targets_x.to(device)
output_u = output_u.to(device)
logits = model(inputs)
logits = de_interleave(logits, 2*7+1)
logits_x = logits[:batch_size]
logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
차원을 이리저리 만져서 모델에 한번에 넣어주고 이를 다시 분리하는 과정을 거친다.
Lx = F.cross_entropy(logits_x, targets_x, reduction = 'mean')
pseudo_label = torch.softmax(logits_u_w.detach() / T, dim = -1)
max_probs, targets_u = torch.max(pseudo_label, dim = -1)
mask = max_probs.ge(threshold).float()
Lu = (F.cross_entropy(logits_u_s, targets_u, reduction = 'none')*mask).mean()
loss = Lx + Lu
loss.backward()
optimizer.step()
scheduler.step()
이제 weak augmentation을 적용한 unlabel데이터에 대하여 argmax값을 뽑는다. optimizer는 SGD를 사용한다.
model.zero_grad()
targets_u = targets_u.reshape(1, -1)
correct = targets_u.eq(output_u.reshape(1, -1).expand_as(targets_u)).float().sum()
ratio = correct /output_u.size(0)
print("Epoch : {}, Batch : {}, Loss : {:.2f}, Accuracy : {} / {}, ratio : {}".format(i+1, batch_idx+1, loss.item(), correct, output_u.size(0), ratio))
acc_list.append(ratio)
correct_list.append(correct)
#test_model = ema_model.ema
1개의 배치를 돌 때마다 448개의 배치에서 얼마나 맞혔는지 한번 살펴본다. 다시 한번 말하지만 이 개수는 threshold를 넘은 것이 아니라 그냥 오분류를 얼마나하는지에 대한 비율이다.
result = [acc.cpu().numpy() for acc in acc_list]
plt.plot(result)
결과는 위와 같다. 1 epoch가 돌 동안 unlabeled 데이터에 대해 40%정도의 정확도를 가지고 있다. 내 예상은 아주 낮은 비율을 기록할 줄 알았는데 완전히 빗나갔으며 1epoch동안에 여러번 데이터셋을 경험하면서 unlabeled 데이터에 대해서도 꽤 잘 맞추는 모습이다. 물론 여기서 epoch를 진행할 수록 정확도가 더 떨어질 수 있다는 가능성도 있으며 이는 더 실험해봐야겠지만 지금 실험으로는 1 epoch만으로도 모델이 unlabel 데이터에 대해서도 잘 예측하는 모습을 보여준다.
Reference
'Self,Semi-supervised learning' 카테고리의 다른 글
Label smoothing 효과 공부해보기 (0) | 2024.03.04 |
---|---|
SelfMatch(2021) 논문리뷰 (0) | 2024.02.14 |
semi supervised learning(준지도학습)에 사용되는 가정 및 방법 (0) | 2023.12.25 |
비전공생의 Whitening for Self-Supervised Representation Learning(2021) 논문리뷰 (1) | 2023.12.19 |
비전공생의 FreeMatch(2023) 논문리뷰 (1) | 2023.12.18 |