GridSearchCV

2021. 4. 12. 07:54·[AI] - Machine Learning

GridSearchCV API는 교차검증을 기반으로 하이퍼 파라미터의 최적값을 찾는데 도움을 준다. (Grid는 격자라는 뜻으로, 촘촘하게 파라미터를 입력하면서 테스트를 하는 방식)

 

예를 들어, 결정 트리 알고리즘의 파라미터 조합을 찾는다고 가정해보자. 그럼 다음과 같은 파라미터가 필요할 것이다.

grid_parameters = {'max_depth': [1, 2, 3],
                   'min_samples_split': [2, 3]}

 

최적의 파라미터 (max_depth, min_samples_split)의 값을 찾기 위해서는 for문을 사용하여 6번(3 * 2)에 걸쳐 하이퍼 파라미터를 변경하면서 교차 검증 데이터 세트에 수행 성능을 측정한다. 여기에 CV가 3회라면 3개의 폴딩 세트가 존재하므로 총 18번(6 *3)회의 학습 및 평가가 이루어진다.  

 

GridSearchCV 클래스의 생성자로 들어가는 파라미터는 다음과 같다.

- estimator : classifier, regressor, pipeline

- param_grid : key + list 값을 가지는 딕셔너리가 주어진다. 이는 estimator의 튜닝을 위한 파라미터명과 값이다.

- scoring : 예측 성능을 측정할 평가 방법. (e.g - 'accuracy')

- cv : 교차 검증을 위해 분할되는 학습/테스트 세트의 개수 (폴딩수)

- refit : default는 True, True일 경우 최적의 하이퍼 파라미터를 찾은 뒤 estimator 객체를 해당 하이퍼 파라미터로 학습시킨다.

 

import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV

pd.set_option('display.max_columns', None)

iris = load_iris()
data = iris.data
label = iris.target

X_train, X_test, y_train, y_test = train_test_split(data, label, test_size=0.2, random_state=121)

dt_clf = DecisionTreeClassifier()

hyper_params = {'max_depth': [1, 2, 3],
                'min_samples_split': [2, 3]}

grid_dt_clf = GridSearchCV(estimator=dt_clf, param_grid=hyper_params, scoring='accuracy', cv=3)
grid_dt_clf.fit(X_train, y_train)

scores_df = pd.DataFrame(grid_dt_clf.cv_results_)
# print(scores_df.columns)
# print(grid_dt_clf.cv_results_)
# print(scores_df)
print(scores_df[['params', 'mean_test_score', 'rank_test_score', 'split0_test_score', 'split1_test_score', 'split2_test_score']])

>>

                                     params  mean_test_score  rank_test_score  \
0  {'max_depth': 1, 'min_samples_split': 2}         0.700000                5   
1  {'max_depth': 1, 'min_samples_split': 3}         0.700000                5   
2  {'max_depth': 2, 'min_samples_split': 2}         0.958333                3   
3  {'max_depth': 2, 'min_samples_split': 3}         0.958333                3   
4  {'max_depth': 3, 'min_samples_split': 2}         0.975000                1   
5  {'max_depth': 3, 'min_samples_split': 3}         0.975000                1   

   split0_test_score  split1_test_score  split2_test_score  
0              0.700                0.7               0.70  
1              0.700                0.7               0.70  
2              0.925                1.0               0.95  
3              0.925                1.0               0.95  
4              0.975                1.0               0.95  
5              0.975                1.0               0.95  

위의 결과를 보면 하이퍼 파라미터의 조합이 6가지이므로 6번 iterate하게 동작했다. rank_test_score는 예측 성능 순위를 의미한다. 즉, 1이면 1위임을 의미한다. split0~split2는 각 폴딩 세트에서 테스트한 성능 수치이다. mean은 이 수치들의 평균이다.

'[AI] - Machine Learning' 카테고리의 다른 글

Feature Scaling and Normalization (StandardScaler, MinMaxScaler)  (0) 2021.04.16
Data Preprocessing (Label Encoder, One-Hot Encoder)  (0) 2021.04.13
Stratified K 폴드  (0) 2021.04.01
교차 검증 (K-폴드 교차 검증)  (0) 2021.04.01
Scikit-learn의 주요모듈  (0) 2021.04.01
'[AI] - Machine Learning' 카테고리의 다른 글
  • Feature Scaling and Normalization (StandardScaler, MinMaxScaler)
  • Data Preprocessing (Label Encoder, One-Hot Encoder)
  • Stratified K 폴드
  • 교차 검증 (K-폴드 교차 검증)
Bebsae
Bebsae
  • Bebsae
    뱁새zip
    Bebsae
  • 전체
    오늘
    어제
    • 분류 전체보기 (108)
      • [DevOps] - Kubernetes (5)
      • [DevOps] - AWS (1)
      • [AI] - Machine Learning (19)
      • [AI] - Neural Network (7)
      • [CS] - Network (2)
      • [CS] - Data Structure (3)
      • [CS] - Design Pattern (6)
      • [Language] - Python (15)
      • [Library] - Numpy (7)
        • Quick Start (5)
        • API (2)
      • [Framework] - Django (3)
      • [Framework] - QGIS (6)
      • [Framework] - PyQT (4)
      • [Mathematics] - Linear Alge.. (14)
      • [Mathematics] - Statistical (2)
      • [ETC] - Python (3)
      • [ETC] - C++ (1)
      • [ETC] - Linux (1)
      • 논문 (5)
      • 회고록 (3)
      • 생산성 (1)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

    • 깃허브
  • 공지사항

  • 인기 글

  • 태그

    linearalgebra
    RNN
    algebra
    교차검증
    Python
    선형대수
    MachineLearning
    파이썬
    Convolution
    numpy
    머신러닝
    신경망
    Learning
    decomposition
    Linear
    디자인패턴
    Machine
    QGIS
    분해
    DEEPLEARNING
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
Bebsae
GridSearchCV
상단으로

티스토리툴바