Hello Computer Vision

torchvision.dataset 조금 뜯어보기 본문

딥러닝/파이토치

torchvision.dataset 조금 뜯어보기

지웅쓰 2023. 12. 21. 14:59

이번에 FixMatch관련하여 코드를 뜯어보던 중 배울 점도 있고 기록해야겠다는 생각을 했다.

 

1. torchvision dataset 을 상속받는 class를 생성할 때 들어가는 파라미터들을 줄 수 있다.

사실 너무나도 당연한 것일 수 있지만 그동안 코드를 뜯어보지 않고 기계적으로 수행하다보니 몰랐던 거 같다.

class CIFAR10SSL(datasets.CIFAR10):
    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)

예를 들어 CIFAR10 데이터셋을 불러오는 과정에서 내가 설정한 index만 불러와야하므로 커스텀 데이터셋을 만들어야 하는 상황이다. 따라서 기존에는 그냥 datasets.CIFAR10(root, train = True, transform = transform_) 라는 코드를 사용하여 데이터셋을 만들었겠지만 현재 내가 필요한건 class를 생성하고 커스텀 데이터셋을 만드는 것이다. 

따라서 커스텀 데이터셋을 만들기위해 class를 생성하고 datasets.CIFAR10을 상속받고  super().__init__() 안에 root, 훈련여부, download여부 등 부모클래스에 대해 들어가는 파라미터에 대해서 채워넣을 필요가 있다.

 

2. dataset.data와 dataset자체의 이미지, 라벨의 타입 차이

dataset = datasets.CIFAR10(root, train = True, download = True)

데이터셋 하나를 정의해보자

img, label = dataset[0]
print(type(img))
print(type(label))

<class 'PIL.Image.Image'>
<class 'int'>

dataset에 대하여 slicing 을 하게 된다면 각 데이터당 image, label을 품고 있고 각각의 type은 PIL, int임을 알 수 있다. 그래서 PIL 타입에 대하여 우리는 ToTensor를 활용하여 훈련에 사용되는 텐서로 바꿔주는 과정을 기계적으로 하게된다.

 

그런데 이번에 코드를 살펴보면서 조금 특이한 점을 발견했다.

print(len(dataset.targets))
print(len(dataset.data))

print(type(dataset.targets))
print(type(dataset.data))

50000
50000
<class 'list'>
<class 'numpy.ndarray'>

데이터셋 객체에 대해 직접적으로 슬라이싱을 통해 접근할 수도 있지만 이렇게 한번에 불러오는 방법도 있다. 조금 신기했던 점은 기존의 dataset 안에 있는 이미지들은 PIL타입이었지만 dataset.data로 접근한다면 numpy 타입으로 자동변환되었다는 것을 알 수 있다. 즉 pytorch의 dataset은 numpy를 기반으로 사용되는 것이다.

 

즉 훈련을 수행하기 전 데이터 전처리를 하는 과정을 거칠 때는 굳이 datatloader를 설정할 필요도 없고 dataset 에 대하여 하나하나 불러올 필요 없이 그냥 dataset.data / dataset.targets 를 활용하면 쉽게 접근할 수 있다는 것이다

 

훈련을 사용할 때 사용하는 DataLoader는 iterable 객체를 만드는 것이므로 dataset 자체를 감싸는 것이라고 할 수 있다.