본문 바로가기
[독파하기] 핸즈온 머신러닝/12장 - 텐서플로를 사용한 사용자 정의 모델과 훈련

12.3 사용자 정의 모델과 훈련 알고리즘 - (1)

by Bebsae 2022. 3. 26.

Prerequisite

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

housing = fetch_california_housing()
X_train_full, X_test, y_train_full, y_test = train_test_split(housing.data, housing.target.reshape(-1, 1), random_state=42)
X_train, X_valid, y_train, y_valid = train_test_split(X_train_full, y_train_full, random_state=42)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_valid_scaled = scaler.transform(X_valid)
X_test_scaled = scaler.transform(X_test)

데이터를 불러온 후에 StandardScaler를 통해 데이터를 가우시안 분포를 따르도록 스케일링을 취한다. StandardScaler에 대한 자세한 내용은 해당 포스트를 참조한다.

 

import tensorflow.keras as keras

input_shape = X_train.shape[1:]

model = keras.models.Sequential([
    keras.layers.Dense(30, activation='selu', kernel_initializer='lecun_normal', input_shape=input_shape),
    keras.layers.Dense(1)                       
])

위와 같은 임의의 MLP가 존재한다고 가정한다.

 

사용자 정의 손실 함수

def huber_fn(y_true, y_pred):
    error = y_true - y_pred
    is_small_error = tf.abs(error) < 1
    squared_loss = tf.square(error) / 2
    linear_loss = tf.abs(error) - 0.5
    return tf.where(is_small_error, squared_loss, linear_loss)

데이터를 학습하는 과정에서 손실 함수를 결정할 때, MSE(Mean Sqaured Error)를 적용하는 경우 오차에 대한 제곱의 패널티가 적용되므로 원치 않는 학습 결과가 나올 수 있다. 그렇다고 MAE(Mean Absolute Error)를 적용하는 경우 예상보다 에러에 너무 둔감한 경우가 발생할 수 있다. 이러한 경우, 두 손실 함수의 중간에 해당하는 함수를 구현하고 싶을 것이다. 위 코드는 적합한 손실 함수를 직접 정의하는 과정이다.

 

model.compile(loss=huber_fn, optimizer='nadam', metrics=['mae'])

model.fit(X_train_scaled, y_train, epochs=2, validation_data=(X_valid_scaled, y_valid))
model.save('my_model_with_a_custom_loss.h5')

모델 컴파일 시 위에서 정의한 사용자 정의 함수를 loss 파라미터에 전달하면 된다. 하지만, 이렇게 저장한 모델을 불러오는 과정은 조금 복잡하다.

 

사용자 정의 요소를 가진 모델을 저장하고 로드하기

model = keras.models.load_model('my_model_with_a_custom_loss.h5', custom_objects={'huber_fn': huber_fn})

모델을 로드할 때 단순히 저장한 모델명 뿐만 아니라, custom_objects 파라미터에 사용자 정의 요소 이름사용자 정의 요소 객체를 딕셔너리 형태로 매핑해서 전달한다.

 

def create_huber(threshold=1.0):
    def huber_fn(y_true, y_pred):
        error = y_true - y_pred
        is_small_error = tf.abs(error) < threshold
        squared_loss = tf.square(error) / 2
        linear_loss = threshold * (tf.abs(error) - threshold/2)
        return tf.where(is_small_error, squared_loss, linear_loss)
    return huber_fn
    
model.compile(loss=create_huber(2.0), optimizer='nadam', metrics=['mae'])

model.save('my_model_a_custom_loss_threshold_2.h5')

model = keras.models.load_model('my_model_with_a_custom_loss_threshold_2.h5', custom_objects={'huber_fn': create_huber(2.0)})

다른 방법으로 후버 손실 함수를 구현한 예제를 보자. 앞서 구현한 후버 손실 함수는 임계치가 1을 기준으로 구현되어 있다. 이 임계치를 다른 값을 주고 싶은 경우, 위 코드처럼 클로저(Closure)를 반환하는 형태로 구현하면 된다. 모델 컴파일이나 로드할 시, create_huber 함수를 호출한 반환값은 huber_fn 함수(정확히는 클로저)이므로, custom_objects의 키도 huber_fn이 된다. 하지만, 매번 일일이 임계치를 2.0으로 하드코딩하기를 원치 않을 것이다. 그렇다면 다음 코드를 보자.

 

class HuberLoss(keras.losses.Loss):
    def __init__(self, threshold=1.0, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold
        
    def call(self, y_true, y_pred):
        error = y_true - y_pred
        is_small_error = tf.abs(error) < self.threshold
        squared_error = tf.square(error) / 2
        linear_loss = self.threshold * (tf.abs(error) - self.threshold/2)
        return tf.where(is_small_error, squared_error, linear_loss)
    
    def get_config(self):
        base_config = super().get_config()
        return {**base_config, "threshold": self.threshold}
        
model.compile(loss=HuberLoss(2.), optimizer='nadam', metrics=['mae'])

model.save('my_model_with_a_custom_loss_class.h5')

model = keras.models.load_model('my_model_with_a_custom_loss_class.h5', custom_objects={'HuberLoss': HuberLoss})
model.loss.threshold # 2.0

keras.losses.Loss 클래스를 상속하여, get_config 메소드를 오버라이딩 하면 이러한 문제를 해결할 수 있다. 추측이지만, 모델을 저장할 때에 사용자 정의 손실 함수 객체의 정보를 함께 저장하여, 모델을 로드할 때에 저장된 정보(위 코드에서는 self.threshold)를 불러와 기본 정보(base_config)외에도 같이 리턴하는 듯 하다.

댓글