Hello Computer Vision

CIFAR10-C 데이터셋 다루기 본문

딥러닝/파이토치

CIFAR10-C 데이터셋 다루기

지웅쓰 2024. 1. 20. 15:25

이번에 실험하면서 CIFAR10-C 데이터셋을 다루게되었는데 시행착오도 많아가지고 공유해보려고 한다.

 

우선 https://zenodo.org/records/2535967 사이트에서 CIFAR10-C 데이터 다운로드 링크를 복사 후 터미널에

wget {dataest link} 를 하면된다. colab환경이라면 !wget하면 된다

 

그렇다면 tar를 확장자로 하는 압축파일을 다운로드 받게 되는데 압축을 푸는데 다양한 방법이 있지만

import tarfile

tar = tarfile.open('/content/CIFAR-10-C.tar?download=1', 'r')
tar.extractall()
tar.close()

이렇게 한다면 압축 파일을 해제할 수 있다(위 환경은 colab환경이다)

 

그렇다면

이렇게 많은 corruption들이 들어간 파일을 얻게된다. npy파일은 numpy를 이용한 파일이고 이를 load해본다면,

for name in corruption_list:
  file_ = os.path.join('/content/CIFAR-10-C',name+ '.npy')
  data_array = np.load(file_)
  print(data_array.shape)
  
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)
(50000, 32, 32, 3)

이렇게 얻을 수 있다. 물론 파일 중 label도 있다. 여기서 알 수 있는 점은 각각의 증강마다 50,000개의 이미지가 있다는 것이다. 즉, 이 CIFAR10-C를 100%활용하기 위해서는 증강마다 정확도를 따로 기록해야한다는 것이다(근데 나는 바보같이 한개의 증강만을 활용했다...).

 

https://github.com/tanimutomo/cifar10-c-eval/blob/master/src/test.py

위 github에서 이를 잘 활용하고 있으니 참고하면 좋다.

 

그렇다면 위 github에서 dataset을 어떻게 정의했는지 보자

class CIFAR10C(datasets.VisionDataset):
    def __init__(self, root :str, name :str,
                 transform=None, target_transform=None):
        assert name in corruptions
        super(CIFAR10C, self).__init__(
            root, transform=transform,
            target_transform=target_transform
        )
        data_path = os.path.join(root, name + '.npy')
        target_path = os.path.join(root, 'labels.npy')
        
        self.data = np.load(data_path)
        self.targets = np.load(target_path)
        
    def __getitem__(self, index):
        img, targets = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            targets = self.target_transform(targets)
            
        return img, targets
    
    def __len__(self):
        return len(self.data)

root는 기본 root, name은 증강을 말하며, npy 파일에 대하여 load함을 알 수 있다.

 

추가적으로 cifar10, cifar10-c 데이터셋에 대하여 각각의 index의 이미지들이 같을까? 이거에 대해서는 시각화 해봤을 때 다르다. 그리고 보통 train할 때 shuffle하는 경우가 다르므로 label값을 따로 불러내어 (labels.npy파일에 있다) 정확도를 측정해야한다.

 

++2024.02.20 수정

만약 1개의 corrupted dataset 을 만들 때면 위와 같이 수행하면 된다. 그러나 어떤 논문들에서는 다수의 corrupted이미지를 넣어 구성하는 경우도 있다.

class CIFAR10_C_all(datasets.VisionDataset):
    def __init__(self, root: str, corruption_list, indexs, transform=None, target_transform=None):
        super(CIFAR10_C_all, self).__init__(root, transform=transform, target_transform=target_transform)
        data_stack = np.empty([0, 32, 32, 3])
        target_stack = np.empty([0])
        self.indexs = indexs
        for i, name in enumerate(corruption_list):
            data_path = os.path.join(root, name + '.npy')
            target_path = os.path.join(root, 'labels.npy')
            indexs = np.random.choice(self.indexs, 2300, False)
            data = np.load(data_path)
            data = data[indexs]
            target = np.load(target_path)
            target = target[indexs]

            data_stack = np.append(data_stack, data, axis=0)
            target_stack = np.append(target_stack, target, axis=0)

        self.data = data_stack.astype(np.uint8)
        self.target = target_stack.astype(np.int64)

    def __getitem__(self, index):
        img, target = self.data[index], self.target[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

위와 같이 구성하고 인자로 원하는 corruption list를 넣으면 된다. 코드 설명은 간단하니 따로하지 않으려고 한다.

'딥러닝 > 파이토치' 카테고리의 다른 글

[pytorch] Nan, inf알아보기  (0) 2024.02.06
[pytorch] load state dict 후 업데이트  (0) 2024.01.22
torch stack, cat 차이  (0) 2024.01.17
함수 이용해 이미지 rotate 하기  (0) 2024.01.17
torch detach 실험해보기  (0) 2024.01.12