Hello Computer Vision

[머신러닝]KFold, StratifiedKFold에 대한 이해 본문

머신러닝

[머신러닝]KFold, StratifiedKFold에 대한 이해

지웅쓰 2023. 1. 28. 16:47

딥러닝 모델을 훈련시킬 때는 한번도 사용해본적은 없지만 머신러닝에서 자주 쓰이는 교차검증 방법들을 한번 정리해보려고 한다.

잘 정리해놓으면 나중에 딥러닝 모델에서도 충분히 활용할 수 있겠지?

 

KFold란?

데이터가 많으면 문제 없겠지만 데이터가 한정된 상황에서 Test 데이터셋을 어떻게 설정하냐에 따라 정확도가 달라질 수 있다.

이러한 문제와 불확실성을 해결하기 위해 데이터셋을 k개로 쪼개 모든 데이터가 검증 과정을 거치도록 분할하는 것이다.

이를 잘 나타내는 이미지는 다음과 같다. 이미지출처

장점은 알고리즘의 정량적인 성능을 평가할 수 있지만 k번 훈련,검증 과정을 거쳐야하므로 한번 훈련할 때보다 훈련시간이 증가한다. 

 

파이썬 코드

from sklearn.datasets import load_iris
from sklearn.model_selection import KFold

dataset = load_iris()
x = dataset.data

kfold = KFold(n_splits = 3, shuffle = True)

for i, (train_index, test_index) in enumerate(kfold.split(x)):
    print('{}번째 Fold'.format(i+1))
    print(train_index)
    print(test_index)
1번째 Fold
[  0   1   3   4   6   8   9  10  11  12  13  14  17  18  20  22  25  28
  29  30  31  32  33  34  35  38  39  40  43  45  47  48  50  51  52  54
  57  58  61  62  64  65  66  67  68  69  70  72  73  74  76  78  79  80
  81  83  84  86  87  89  90  91  93  94  95  96  97  98  99 101 102 105
 106 107 108 109 110 113 115 117 118 119 121 122 124 126 128 130 131 132
 135 136 137 138 140 142 143 146 148 149]
[  2   5   7  15  16  19  21  23  24  26  27  36  37  41  42  44  46  49
  53  55  56  59  60  63  71  75  77  82  85  88  92 100 103 104 111 112
 114 116 120 123 125 127 129 133 134 139 141 144 145 147]
2번째 Fold
[  0   2   3   4   5   7   8   9  12  14  15  16  17  19  21  23  24  25
  26  27  29  30  34  35  36  37  38  39  41  42  43  44  46  49  50  51
  53  54  55  56  57  58  59  60  61  63  64  65  66  69  70  71  72  75
  77  80  82  83  84  85  87  88  90  92  94  95  98 100 102 103 104 105
 108 109 111 112 113 114 116 117 118 120 121 123 125 127 129 132 133 134
 135 136 138 139 140 141 142 144 145 147]
[  1   6  10  11  13  18  20  22  28  31  32  33  40  45  47  48  52  62
  67  68  73  74  76  78  79  81  86  89  91  93  96  97  99 101 106 107
 110 115 119 122 124 126 128 130 131 137 143 146 148 149]
3번째 Fold
[  1   2   5   6   7  10  11  13  15  16  18  19  20  21  22  23  24  26
  27  28  31  32  33  36  37  40  41  42  44  45  46  47  48  49  52  53
  55  56  59  60  62  63  67  68  71  73  74  75  76  77  78  79  81  82
  85  86  88  89  91  92  93  96  97  99 100 101 103 104 106 107 110 111
 112 114 115 116 119 120 122 123 124 125 126 127 128 129 130 131 133 134
 137 139 141 143 144 145 146 147 148 149]
[  0   3   4   8   9  12  14  17  25  29  30  34  35  38  39  43  50  51
  54  57  58  61  64  65  66  69  70  72  80  83  84  87  90  94  95  98
 102 105 108 109 113 117 118 121 132 135 136 138 140 142]

이렇게  전체데이터셋을 3번 반복함을 알 수 있다.

KFold 의 주요 인자인 n_splits 의 기본값은 5이며 shuffle은 False이다. 위 코드는 shuffle값을 True를 준 결과인데 False를 준 결과는 다음과 같다.

1번째 Fold
[ 50  51  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67
  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85
  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 101 102 103
 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
 140 141 142 143 144 145 146 147 148 149]
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49]
2번째 Fold
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49 100 101 102 103
 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
 140 141 142 143 144 145 146 147 148 149]
[50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
 98 99]
3번째 Fold
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]
[100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
 136 137 138 139 140 141 142 143 144 145 146 147 148 149]

검증 데이터셋이 순차적으로 할당됨을 알 수 있다.

 

StratifiedKFold란?

stratified 의 뜻은 사전적 의미로 '층상의'를 가르킨다.

가장 큰 특징이라면 KFold에서 나눈 각각의 데이터셋에서 각각의 클래스가 동일한 비중으로 나누는 것이다.

KFold에서 나눈 데이터셋이 한쪽 클래스에 몰려있을 가능성이 있으니 StratifiedKFold를 사용한다.

파이썬 코드

from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold

dataset = load_iris()
x = dataset.data
y = dataset.target

str_kfold = StratifiedKFold(n_splits = 3, shuffle = False)

for i, (train_index, test_index) in enumerate(str_kfold.split(x, y)):
    print('{}번째 Fold'.format(i+1))
    print(train_index)
    print(test_index)

 

1번째 Fold
[ 17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34
  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  67  68  69
  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87
  88  89  90  91  92  93  94  95  96  97  98  99 116 117 118 119 120 121
 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
 140 141 142 143 144 145 146 147 148 149]
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  50
  51  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66 100 101
 102 103 104 105 106 107 108 109 110 111 112 113 114 115]
2번째 Fold
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  34
  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52
  53  54  55  56  57  58  59  60  61  62  63  64  65  66  83  84  85  86
  87  88  89  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104
 105 106 107 108 109 110 111 112 113 114 115 133 134 135 136 137 138 139
 140 141 142 143 144 145 146 147 148 149]
[ 17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  67
  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82 116 117 118
 119 120 121 122 123 124 125 126 127 128 129 130 131 132]
3번째 Fold
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  50  51
  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69
  70  71  72  73  74  75  76  77  78  79  80  81  82 100 101 102 103 104
 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
 123 124 125 126 127 128 129 130 131 132]
[ 34  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  83  84
  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 133 134 135
 136 137 138 139 140 141 142 143 144 145 146 147 148 149]

shuffle값을 False로 주었음에도 첫번째 Fold를 살펴보자면 순차적으로 선별되지 않음을 알 수 있다. 이는 클래스 비중을 맞추기 위해서이다.