Hello Computer Vision

[pytorch] 데이터셋 합치기 본문

딥러닝/파이토치

[pytorch] 데이터셋 합치기

지웅쓰 2024. 2. 17. 19:25

corrupted 된 이미지들을 다루다보면은 여러 개의 데이터셋을 다루다보니 코드가 더러워질 수 있다.

 

이에 대해서 코드의 간편성을 위해서는 데이터셋을 합칠 수 있다.

 

mnist = datasets.MNIST(root = './sample_data', train = True, download =True, transform = transforms.ToTensor())
cifar = datasets.CIFAR10(root = './sample_data', train = True, download = True, transform = transforms.ToTensor())

concat = torch.utils.data.ConcatDataset([cifar, mnist])

print(len(concat))
print(len(mnist))
print(len(cifar))

dl = torch.utils.data.DataLoader(concat, batch_size = 64)
110000
60000
50000

이렇게 합친 concat 데이터의 개수는 총 110,000개인 것을 확인할 수 있고 데이터가 들어가는 위치는 리스트 인덱스 순으로 들어가있다. 예를 들어 ConCat 인자로 [cifar, mnist] 순으로 들어가있으니 처음 50,000개의 데이터는 cifar이고 나머지 데이터는 mnist인 것이다. 그렇다면 이것을 iterator로 돌리면 어떻게 될까?

for i, (x, y) in enumerate(dl):
  if x.shape[1] == 1:
    print('mnist!')
    print(i)
    break
  if i % 50 == 0:
    print('not yet')
not yet
not yet
not yet
not yet
not yet
not yet
not yet
not yet
not yet
not yet
not yet
not yet
not yet
not yet
not yet
not yet
RuntimeError: stack expects each tensor to be equal size, but got [3, 32, 32] at entry 0 and [1, 28, 28] at entry 16

아쉽게도 합친 데이터셋의 크기가 다르다면 중간에 에러가 난다. 따라서 만약 합치려는 데이터셋의 크기가 다르다면 같도록 전처리를 해주어야한다. 

 

++ DataLoader의 인자로 shuffle = True를 준다면 차례로 들어가지 않는다.