일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- WGAN
- BYOL
- cifar100-c
- remixmatch paper
- 최린컴퓨터구조
- Pix2Pix
- ConMatch
- shrinkmatch
- CGAN
- dann paper
- mocov3
- 딥러닝손실함수
- Entropy Minimization
- shrinkmatch paper
- tent paper
- dcgan
- Pseudo Label
- UnderstandingDeepLearning
- CoMatch
- simclrv2
- conjugate pseudo label paper
- mme paper
- adamatch paper
- GAN
- SSL
- semi supervised learnin 가정
- Meta Pseudo Labels
- 백준 알고리즘
- CycleGAN
- 컴퓨터구조
- Today
- Total
Hello Computer Vision
torchvision.dataset 조금 뜯어보기 본문
이번에 FixMatch관련하여 코드를 뜯어보던 중 배울 점도 있고 기록해야겠다는 생각을 했다.
1. torchvision dataset 을 상속받는 class를 생성할 때 들어가는 파라미터들을 줄 수 있다.
사실 너무나도 당연한 것일 수 있지만 그동안 코드를 뜯어보지 않고 기계적으로 수행하다보니 몰랐던 거 같다.
class CIFAR10SSL(datasets.CIFAR10):
def __init__(self, root, indexs, train=True,
transform=None, target_transform=None,
download=False):
super().__init__(root, train=train,
transform=transform,
target_transform=target_transform,
download=download)
예를 들어 CIFAR10 데이터셋을 불러오는 과정에서 내가 설정한 index만 불러와야하므로 커스텀 데이터셋을 만들어야 하는 상황이다. 따라서 기존에는 그냥 datasets.CIFAR10(root, train = True, transform = transform_) 라는 코드를 사용하여 데이터셋을 만들었겠지만 현재 내가 필요한건 class를 생성하고 커스텀 데이터셋을 만드는 것이다.
따라서 커스텀 데이터셋을 만들기위해 class를 생성하고 datasets.CIFAR10을 상속받고 super().__init__() 안에 root, 훈련여부, download여부 등 부모클래스에 대해 들어가는 파라미터에 대해서 채워넣을 필요가 있다.
2. dataset.data와 dataset자체의 이미지, 라벨의 타입 차이
dataset = datasets.CIFAR10(root, train = True, download = True)
데이터셋 하나를 정의해보자
img, label = dataset[0]
print(type(img))
print(type(label))
<class 'PIL.Image.Image'>
<class 'int'>
dataset에 대하여 slicing 을 하게 된다면 각 데이터당 image, label을 품고 있고 각각의 type은 PIL, int임을 알 수 있다. 그래서 PIL 타입에 대하여 우리는 ToTensor를 활용하여 훈련에 사용되는 텐서로 바꿔주는 과정을 기계적으로 하게된다.
그런데 이번에 코드를 살펴보면서 조금 특이한 점을 발견했다.
print(len(dataset.targets))
print(len(dataset.data))
print(type(dataset.targets))
print(type(dataset.data))
50000
50000
<class 'list'>
<class 'numpy.ndarray'>
데이터셋 객체에 대해 직접적으로 슬라이싱을 통해 접근할 수도 있지만 이렇게 한번에 불러오는 방법도 있다. 조금 신기했던 점은 기존의 dataset 안에 있는 이미지들은 PIL타입이었지만 dataset.data로 접근한다면 numpy 타입으로 자동변환되었다는 것을 알 수 있다. 즉 pytorch의 dataset은 numpy를 기반으로 사용되는 것이다.
즉 훈련을 수행하기 전 데이터 전처리를 하는 과정을 거칠 때는 굳이 datatloader를 설정할 필요도 없고 dataset 에 대하여 하나하나 불러올 필요 없이 그냥 dataset.data / dataset.targets 를 활용하면 쉽게 접근할 수 있다는 것이다
훈련을 사용할 때 사용하는 DataLoader는 iterable 객체를 만드는 것이므로 dataset 자체를 감싸는 것이라고 할 수 있다.
'딥러닝 > 파이토치' 카테고리의 다른 글
[pytorch] 처리하고자 하는 배치사이즈가 다를 때 해결방법 (1) | 2023.12.29 |
---|---|
__name__ == '__main__' 쓰는 이유 (0) | 2023.12.21 |
torch.cuda.amp GradScaler 공부해보기 (0) | 2023.07.11 |
pretrained 모델에 fc레이어 추가해보기 (0) | 2023.05.14 |
F.cross entropy 구현 및 설명해보기 (0) | 2023.05.12 |