본문 바로가기
[AI] - Machine Learning

# 2. Stratified K 폴드

by Bebsae 2021. 4. 1.

이전 포스트에서 K 폴드에 대해서 알아봤다. K 폴드의 단점은 치중되어 있는 레이블 데이터 집합을 고려하지 못한 것이다. 이게 무슨 소리인지 극단적인 예시를 들어보겠다.

 

10 폴드인 상태에서 1억개의 레이블 데이터 세트(분류는 O와 X)가 있고 그중 1번째부터 10번째 데이터는 O, 나머지 9999만 9990개는 X이라고 가정하자.

첫 번째 폴드를 검증 데이터 세트로 사용했을 때 1천만개의 데이터중 10개의 O가 있다.

하지만, 두 번째 폴드부터는 검증 데이터 세트로 사용할 때 전부 X 밖에 없을 것이다.

 

이러한 상태를 imbalanced 하다고 표현한다. imbalanced 상태를 위해 적은 레이블을 고르게 분포한 뒤에 K 폴드를 적용한다. 

 

import pandas as pd
import numpy as np

from sklearn.datasets import load_iris
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

iris = load_iris()
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df['label'] = iris.target
print(iris_df['label'].value_counts())
print(iris_df['label'])

붓꽃 데이터 세트를 imblanced 예시로 보자

0    50
1    50
2    50

Name: label, dtype: int64
0      0
1      0
2      0
3      0
4      0
      ..
145    2
146    2
147    2
148    2
149    2

0, 1, 2번 클래스의 갯수가 각각 50개씩 있으며 실제로 0~49 = 0 / 50~99 = 1 / 100~149 = 2 인 레이블 데이터 세트로 구성되어 있다.

 

k_fold = KFold(n_splits=3)
for train_index, test_index in k_fold.split(iris_df):
    print(train_index, test_index)
    print(iris_df['label'].iloc[train_index])
    print(iris_df['label'].iloc[test_index])
# K=1
[ 50  51  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67
  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85
  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 101 102 103
 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
 140 141 142 143 144 145 146 147 148 149] 
 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49]
# train_index
50     1
51     1
52     1
53     1
54     1
      ..
145    2
146    2
147    2
148    2
149    2
# test_index
0     0
1     0
2     0
3     0
4     0
5     0
6     0
7     0
8     0
9     0
10    0
11    0
12    0
13    0
14    0
15    0
16    0
17    0
18    0
19    0
20    0
21    0
22    0
23    0
24    0
25    0
26    0
27    0
28    0
29    0
30    0
31    0
32    0
33    0
34    0
35    0
36    0
37    0
38    0
39    0
40    0
41    0
42    0
43    0
44    0
45    0
46    0
47    0
48    0
49    0

# k=2
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49 100 101 102 103
 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
 140 141 142 143 144 145 146 147 148 149] 
 [50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
 98 99]
# train_index
0      0
1      0
2      0
3      0
4      0
      ..
145    2
146    2
147    2
148    2
149    2
# test_index
50    1
51    1
52    1
53    1
54    1
55    1
56    1
57    1
58    1
59    1
60    1
61    1
62    1
63    1
64    1
65    1
66    1
67    1
68    1
69    1
70    1
71    1
72    1
73    1
74    1
75    1
76    1
77    1
78    1
79    1
80    1
81    1
82    1
83    1
84    1
85    1
86    1
87    1
88    1
89    1
90    1
91    1
92    1
93    1
94    1
95    1
96    1
97    1
98    1
99    1
 
 # k=3
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99] 
 [100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
 136 137 138 139 140 141 142 143 144 145 146 147 148 149]
# train_index
0     0
1     0
2     0
3     0
4     0
     ..
95    1
96    1
97    1
98    1
99    1
# test_index
100    2
101    2
102    2
103    2
104    2
105    2
106    2
107    2
108    2
109    2
110    2
111    2
112    2
113    2
114    2
115    2
116    2
117    2
118    2
119    2
120    2
121    2
122    2
123    2
124    2
125    2
126    2
127    2
128    2
129    2
130    2
131    2
132    2
133    2
134    2
135    2
136    2
137    2
138    2
139    2
140    2
141    2
142    2
143    2
144    2
145    2
146    2
147    2
148    2
149    2

K=1일 때는 검증 데이터세트에 0, 학습 데이터세트에는 1, 2

K=2일 때는 검증 데이터세트에 1, 학습 데이터세트에는 0, 2

K=3일 때는 검증 데이터세트에 2, 학습 데이터세트에는 0, 1

분포가 고르지 않은 모습을 보여준다.

이런 방식으로 학습을 하고 검증을 할 경우 정확도는 0이 될 수 밖에 없다. 

 

StratifiedKFold를 사용하면 이러한 문제는 해결할 수 있다.

k_fold = StratifiedKFold(n_splits=3)
for train_index, test_index in k_fold.split(iris_df, iris_df['label']):
    print(train_index, test_index)
    print(iris_df['label'].iloc[train_index])
    print(iris_df['label'].iloc[test_index])
# k=1
[ 17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34
  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  67  68  69
  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87
  88  89  90  91  92  93  94  95  96  97  98  99 116 117 118 119 120 121
 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
 140 141 142 143 144 145 146 147 148 149] 
 [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  50
  51  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66 100 101
 102 103 104 105 106 107 108 109 110 111 112 113 114 115]
# train_index
17     0
18     0
19     0
20     0
21     0
      ..
145    2
146    2
147    2
148    2
149    2
# test_index
0      0
1      0
2      0
3      0
4      0
5      0
6      0
7      0
8      0
9      0
10     0
11     0
12     0
13     0
14     0
15     0
16     0
50     1
51     1
52     1
53     1
54     1
55     1
56     1
57     1
58     1
59     1
60     1
61     1
62     1
63     1
64     1
65     1
66     1
100    2
101    2
102    2
103    2
104    2
105    2
106    2
107    2
108    2
109    2
110    2
111    2
112    2
113    2
114    2
115    2

# k=2
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  34
  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52
  53  54  55  56  57  58  59  60  61  62  63  64  65  66  83  84  85  86
  87  88  89  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104
 105 106 107 108 109 110 111 112 113 114 115 133 134 135 136 137 138 139
 140 141 142 143 144 145 146 147 148 149] 
 [ 17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  67
  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82 116 117 118
 119 120 121 122 123 124 125 126 127 128 129 130 131 132]
# train_set
0      0
1      0
2      0
3      0
4      0
      ..
145    2
146    2
147    2
148    2
149    2
# test_set
17     0
18     0
19     0
20     0
21     0
22     0
23     0
24     0
25     0
26     0
27     0
28     0
29     0
30     0
31     0
32     0
33     0
67     1
68     1
69     1
70     1
71     1
72     1
73     1
74     1
75     1
76     1
77     1
78     1
79     1
80     1
81     1
82     1
116    2
117    2
118    2
119    2
120    2
121    2
122    2
123    2
124    2
125    2
126    2
127    2
128    2
129    2
130    2
131    2
132    2

# k=3
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  50  51
  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69
  70  71  72  73  74  75  76  77  78  79  80  81  82 100 101 102 103 104
 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
 123 124 125 126 127 128 129 130 131 132] 
 [ 34  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  83  84
  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 133 134 135
 136 137 138 139 140 141 142 143 144 145 146 147 148 149]
# train_index
0      0
1      0
2      0
3      0
4      0
      ..
128    2
129    2
130    2
131    2
132    2
# test_index
34     0
35     0
36     0
37     0
38     0
39     0
40     0
41     0
42     0
43     0
44     0
45     0
46     0
47     0
48     0
49     0
83     1
84     1
85     1
86     1
87     1
88     1
89     1
90     1
91     1
92     1
93     1
94     1
95     1
96     1
97     1
98     1
99     1
133    2
134    2
135    2
136    2
137    2
138    2
139    2
140    2
141    2
142    2
143    2
144    2
145    2
146    2
147    2
148    2
149    2

위와 같이 학습, 검증 데이터 세트에 고르게 클래스가 분포한 것을 볼 수 있다.

 

import pandas as pd
import numpy as np

from sklearn.datasets import load_iris
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

iris = load_iris()
features = iris.data
label = iris.target
print(features, label)

stk_fold = StratifiedKFold(n_splits=3)
dt_clf = DecisionTreeClassifier(random_state=156)
cv_accuracy = list()

for train_index, test_index in stk_fold.split(features, label):
    X_train, X_test = features[train_index], features[test_index]
    y_train, y_test = label[train_index], label[test_index]

    dt_clf.fit(X_train, y_train)
    pred = dt_clf.predict(X_test)
    accuracy = np.round(accuracy_score(y_test, pred), 4)
    cv_accuracy.append(accuracy)

print(np.mean(cv_accuracy))
0.9666666666666667

여기서 중요한 점은 K 폴드와 다르게 split()의 인자에 학습 데이터(features)만 들어가는 것이 아니라 정답(label)도 들어가야 한다. 왜냐하면 label의 분포에 따라 고르게 분포해야 하기 때문이다.

Stratified K Fold를 사용하여 분포 시킨 후, 각 폴드별 평가를 cv_accuracy에 append 한 뒤 평균을 내어 최종적인 정확도는 97퍼센트에 임박했다.

댓글