Hello Computer Vision

Test Time Augmentation 알아보기 본문

딥러닝

Test Time Augmentation 알아보기

지웅쓰 2023. 7. 10. 20:41

Test Time Augmentation, 즉 TTA 는 train 할 때 augmentation을 사용하는 것이 아닌 test시 augmentation을 수행해 augmentation된 이미지들의 예측값을 평균내어 최종적으로 예측하는 것이다.

 

inference시에 한 이미지에 대하여 여러 이미지들로 만들고 이를 평균해 예측하는 것이므로 Ensemble기법이라고 할 수 있다. 또한 Test시 한 이미지가 아닌 여러 이미지를 생성해 평균을 내 결과를 생성하므로 사용하지 않을 때보다 더 강건할 수 있다. 

 

일단 설명은 쉬운데 한번 코드로 수행해보았다.

import torch
import ttach as tta
import timm
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt

from google.colab import drive

transforms = tta.Compose(
  [
      tta.HorizontalFlip(),
      tta.Rotate90(angles = [0, 90]),
      tta.Multiply(factors = [0.7, 1])
  ]
)

ttach 라는 라이브러리가 따로있다. 내가 정의한 transforms을 본다면 transforms 과 똑같이 compose를 이용해 원하는 데이터증강을 넣을 수 있고, 이 모든 증강들이 수행된다는 것이 특징이다.

image = np.array(Image.open(img_path)) / 255
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(torch.float32)

fig = plt.figure(figsize=(20, 20))
columns = 3
rows = 3

for i, transform in enumerate(transforms):
    image_transformed = transform.augment_image(image)
    image_transformed = np.array(image_transformed.squeeze()).transpose(1, 2, 0)
    fig.add_subplot(rows, columns, i+1)
    plt.imshow(image_transformed)

plt.show()

반복문 보면 내가 정의한 transforms 을 넣어주면 안에 있는 데이터증강 방법들이 하나씩 돌아가면서 수행되는 방식같다.

이미지가 총 8개 나왔는데 내가 정의한 증강이 3개이므로 2^3 으로 8개가 나온 거 같다.

 


내용 참고

https://visionhong.tistory.com/26

'딥러닝' 카테고리의 다른 글

torch.topk 함수 공부해보기  (1) 2023.08.22
np.random.choice 공부해보기  (0) 2023.08.21
CAM, Grad-CAM 공부해보기  (0) 2023.06.30
GPU연산 DP, DDP 공부해보기  (0) 2023.06.29
torch.triplet margin distance loss 살펴보기  (0) 2023.05.04