일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- ConMatch
- mme paper
- Pix2Pix
- cifar100-c
- remixmatch paper
- 최린컴퓨터구조
- 컴퓨터구조
- Meta Pseudo Labels
- simclrv2
- shrinkmatch paper
- CycleGAN
- SSL
- CGAN
- Entropy Minimization
- dann paper
- CoMatch
- dcgan
- UnderstandingDeepLearning
- tent paper
- BYOL
- conjugate pseudo label paper
- mocov3
- 백준 알고리즘
- Pseudo Label
- WGAN
- semi supervised learnin 가정
- GAN
- shrinkmatch
- 딥러닝손실함수
- adamatch paper
- Today
- Total
Hello Computer Vision
CIFAR10-C 데이터셋 다루기 본문
이번에 실험하면서 CIFAR10-C 데이터셋을 다루게되었는데 시행착오도 많아가지고 공유해보려고 한다.
우선 https://zenodo.org/records/2535967 사이트에서 CIFAR10-C 데이터 다운로드 링크를 복사 후 터미널에
wget {dataest link} 를 하면된다. colab환경이라면 !wget하면 된다
그렇다면 tar를 확장자로 하는 압축파일을 다운로드 받게 되는데 압축을 푸는데 다양한 방법이 있지만
import tarfile
tar = tarfile.open('/content/CIFAR-10-C.tar?download=1', 'r')
tar.extractall()
tar.close()
이렇게 한다면 압축 파일을 해제할 수 있다(위 환경은 colab환경이다)
그렇다면
이렇게 많은 corruption들이 들어간 파일을 얻게된다. npy파일은 numpy를 이용한 파일이고 이를 load해본다면,
for name in corruption_list:
file_ = os.path.join('/content/CIFAR-10-C',name+ '.npy')
data_array = np.load(file_)
print(data_array.shape)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
이렇게 얻을 수 있다. 물론 파일 중 label도 있다. 여기서 알 수 있는 점은 각각의 증강마다 50,000개의 이미지가 있다는 것이다. 즉, 이 CIFAR10-C를 100%활용하기 위해서는 증강마다 정확도를 따로 기록해야한다는 것이다(근데 나는 바보같이 한개의 증강만을 활용했다...).
https://github.com/tanimutomo/cifar10-c-eval/blob/master/src/test.py
위 github에서 이를 잘 활용하고 있으니 참고하면 좋다.
그렇다면 위 github에서 dataset을 어떻게 정의했는지 보자
class CIFAR10C(datasets.VisionDataset):
def __init__(self, root :str, name :str,
transform=None, target_transform=None):
assert name in corruptions
super(CIFAR10C, self).__init__(
root, transform=transform,
target_transform=target_transform
)
data_path = os.path.join(root, name + '.npy')
target_path = os.path.join(root, 'labels.npy')
self.data = np.load(data_path)
self.targets = np.load(target_path)
def __getitem__(self, index):
img, targets = self.data[index], self.targets[index]
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
targets = self.target_transform(targets)
return img, targets
def __len__(self):
return len(self.data)
root는 기본 root, name은 증강을 말하며, npy 파일에 대하여 load함을 알 수 있다.
추가적으로 cifar10, cifar10-c 데이터셋에 대하여 각각의 index의 이미지들이 같을까? 이거에 대해서는 시각화 해봤을 때 다르다. 그리고 보통 train할 때 shuffle하는 경우가 다르므로 label값을 따로 불러내어 (labels.npy파일에 있다) 정확도를 측정해야한다.
++2024.02.20 수정
만약 1개의 corrupted dataset 을 만들 때면 위와 같이 수행하면 된다. 그러나 어떤 논문들에서는 다수의 corrupted이미지를 넣어 구성하는 경우도 있다.
class CIFAR10_C_all(datasets.VisionDataset):
def __init__(self, root: str, corruption_list, indexs, transform=None, target_transform=None):
super(CIFAR10_C_all, self).__init__(root, transform=transform, target_transform=target_transform)
data_stack = np.empty([0, 32, 32, 3])
target_stack = np.empty([0])
self.indexs = indexs
for i, name in enumerate(corruption_list):
data_path = os.path.join(root, name + '.npy')
target_path = os.path.join(root, 'labels.npy')
indexs = np.random.choice(self.indexs, 2300, False)
data = np.load(data_path)
data = data[indexs]
target = np.load(target_path)
target = target[indexs]
data_stack = np.append(data_stack, data, axis=0)
target_stack = np.append(target_stack, target, axis=0)
self.data = data_stack.astype(np.uint8)
self.target = target_stack.astype(np.int64)
def __getitem__(self, index):
img, target = self.data[index], self.target[index]
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.data)
위와 같이 구성하고 인자로 원하는 corruption list를 넣으면 된다. 코드 설명은 간단하니 따로하지 않으려고 한다.
'딥러닝 > 파이토치' 카테고리의 다른 글
[pytorch] Nan, inf알아보기 (0) | 2024.02.06 |
---|---|
[pytorch] load state dict 후 업데이트 (0) | 2024.01.22 |
torch stack, cat 차이 (0) | 2024.01.17 |
함수 이용해 이미지 rotate 하기 (0) | 2024.01.17 |
torch detach 실험해보기 (0) | 2024.01.12 |