본문 바로가기
[AI] - Machine Learning

# 1. 교차 검증 (K-폴드 교차 검증)

by Bebsae 2021. 4. 1.

교차 검증은 오버피팅(Overfitting)을 방지하기 위한 프로세스 중 하나이다. 오버피팅이 발생하는 이유는 모델의 학습이 Training-Set에만 너무 의존되어 있기 때문에 일반화가 잘 이루어지지 않아 다른 데이터가 들어오면 성능이 떨어진다. 그래서 Training-Set을 Training-Set과 Validation-Set으로 세분화하는 것이 교차 검증이다.

 

이를 세분화해서 무엇을 하느냐? Training-Set으로 학습된 모델을 Test-Set으로 평가하기 전에 Validation-Set으로 평가하는 것이다. 즉, 모의고사를 본다고 생각하면 된다.

 

우선 K-폴드 교차 검증부터 확인해보겠다.

Training-Set을 K등분한다. 예를들어, 100개의 Training-Set이 있다면..

K = 1(1~20), 2(21~40), 3(41~60), 4(61~80), 5(81~100) 등분이 된다.

그리고, K = 1~5를 Iterative하게 Validation-Set으로써 평가를 한다. 총 5번의 평가를 하게 된다. 그리고, 이 5개의 평가의 평균을 K-폴드 교차 검증의 결과로 반영한다.

 

import numpy as np

from sklearn.datasets import load_iris
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
features = iris.data
label = iris.target
dt_clf = DecisionTreeClassifier(random_state=156)

k_fold = KFold(n_splits=5)
cv_accuracy = []
print(features.shape)  # (150, 4)

n_iter = 0
for train_index, test_index in k_fold.split(features):
    print(train_index, test_index)

위 코드를 보자. k_fold 인스턴스의 인자로 n_splits=5가 들어갔다. 즉, 5등분을 하겠다는 의미이다. 붓꽃 데이터 세트의 갯수가 150개임을 위에서 확인했다. 이 데이터 세트를 kfold를 통해 세분화한 인덱스를 살펴보자

 

(150, 4)

# K=1
# Train-Set
[ 30  31  32  33  34  35  36  37  38  39  40  41  42  43  44  45  46  47
  48  49  50  51  52  53  54  55  56  57  58  59  60  61  62  63  64  65
  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83
  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 101
 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
 138 139 140 141 142 143 144 145 146 147 148 149] 
 # Test-Set
 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29]
 
 # K=2
 # Train-Set
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  60  61  62  63  64  65
  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83
  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 101
 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
 138 139 140 141 142 143 144 145 146 147 148 149] 
 # Test-Set
 [30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
 54 55 56 57 58 59]
 
 # K=3
 # Train-Set
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  90  91  92  93  94  95  96  97  98  99 100 101
 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
 138 139 140 141 142 143 144 145 146 147 148 149] 
 # Test-Set
 [60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
 84 85 86 87 88 89]
 
 # K=4
 # Train-Set
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
 138 139 140 141 142 143 144 145 146 147 148 149] 
 # Test-Set
 [ 90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119]
 
 # K=4
 # Train-Set
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119] 
 # Test-Set
 [120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
 138 139 140 141 142 143 144 145 146 147 148 149]

 

다음으로는 이 인덱스를 활용하여 예측을 해보겠다.

for train_index, test_index in k_fold.split(features):
    # train : 학습, test : 검증
    X_train, X_test = features[train_index], features[test_index]
    y_train, y_test = label[train_index], label[test_index]

    dt_clf.fit(X_train, y_train)
    pred = dt_clf.predict(X_test)
    n_iter += 1
    print(pred)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]  # k=1 (0~29) 예측
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 2 1 1]  # k=2 (29~59) 예측
[1 1 1 1 1 1 1 1 1 1 2 1 2 1 1 1 1 2 1 1 1 1 1 2 1 1 1 1 1 1]  # k=3 (60~89) 예측
[1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 1]  # k=4 (90~119) 예측
[2 2 2 1 2 2 1 1 2 1 2 2 2 1 1 2 2 2 1 2 2 2 2 2 2 2 1 2 2 2]  # k=5 (120~149) 예측

 

이어서 정확도를 확인해보겠다.

for train_index, test_index in k_fold.split(features):
    # train : 학습, test : 검증
    X_train, X_test = features[train_index], features[test_index]
    y_train, y_test = label[train_index], label[test_index]

    dt_clf.fit(X_train, y_train)
    pred = dt_clf.predict(X_test)
    n_iter += 1

    accuracy = np.round(accuracy_score(y_test, pred), 4)
    cv_accuracy.append(accuracy)

print(cv_accuracy)
print(np.mean(cv_accuracy))
[1.0, 0.9667, 0.8667, 0.9333, 0.7333]  # K당 각 정확도
0.9  # 평균 (최종 결과)

댓글