본문 바로가기
[AI] - Neural Network

# 6. Long Short-Term Memory (LSTM) - Code

by Bebsae 2022. 1. 3.

지난 포스트에서는 LSTM의 이론에 대해 다루었다. 이번 포스트에서는 LSTM을 코드로 구현해보는데 주의해야할 점이 있다. 기존의 RNN을 구현할 때에는 은닉 상태에 해당하는 변수 hidden 하나만 다음 시퀀스의 메모리 셀로 전달하는 구조였다. 하지만, LSTM은 은닉 상태 이외에도 셀의 상태에 해당하는 변수인 cell도 같이 고려해야 한다.

 

Pytorch를 통한 구현

"""
두 번째 단어를 입력으로 세 번째 단어가 무엇이 나올지 예측
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

sentences = ['i like dog',
             'i love coffee',
             'i hate milk',
             'you like cat',
             'you love milk',
             'you hate coffee']
dtype = torch.float

# sentences에 출현하는 단어 목록
word_list = list(set(' '.join(sentences).split()))
# 각 단어에 해당하는 인덱스 매핑
word_dict = {w: i for i, w in enumerate(word_list)}
# 각 인덱스에 해당하는 단어 매핑
number_dict = {i: w for i, w in enumerate(word_list)}

# 모든 단어의 종류 수 : 9 (d)
n_class = len(word_dict)

# 문장의 수 (샘플의 수) : 6 (N)
batch_size = len(sentences)

# 은닉층 사이즈 (Dh)
n_hidden = 5


def make_batch(sentences):
    input_batch = []
    target_batch = []

    for sentence in sentences:
    	# 문장을 단어로 토크나이즈
        words = sentence.split()
        # 각 문장의 2번째 까지의 단어의 인덱스
        input_ = [word_dict[word] for word in words[:-1]]
        # 각 문장의 마지막 단어의 인덱스
        target_ = word_dict[words[-1]]

        """
        np.eye(n_class)[[7, 0]] -> [[0. 0. 0. 0. 0. 0. 0. 1. 0.] 
                                    [1. 0. 0. 0. 0. 0. 0. 0. 0.]]
        """
        input_batch.append(np.eye(n_class)[input_])  # One-Hot Encoding
        target_batch.append(target_)

    return input_batch, target_batch


# 텐서로 변환
input_batch, target_batch = make_batch(sentences)
input_batch = torch.tensor(input_batch, dtype=torch.float32, requires_grad=True)
target_batch = torch.tensor(target_batch, dtype=torch.int64)

print(input_batch.shape)  # N x T x D : (6, 2, 9)
print(target_batch.shape)  # (6,)


class TextLSTM(nn.Module):
    def __init__(self):
        super(TextLSTM, self).__init__()

        self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, dropout=0.3)
        # 최종 출력 : yt = f(Wy*ht + b)
        # ht : (6, 5) / Wx : (5, 9) / b : (9,)
        self.W = nn.Parameter(torch.randn([n_hidden, n_class]).type(dtype))
        self.b = nn.Parameter(torch.randn([n_class]).type(dtype))  # (9,)

    def forward(self, X, hidden_and_cell):
        # (sequence의 길이, batch 크기, input vector 사이즈) 로 input_data의 shape를 변경하여,
        # mini-batch 단위의 학습이 진행 될 수 있도록 한다.
        # switch dim 0 and 1 : (6, 2, 9) -> (2, 6, 9)
        X = X.transpose(0, 1)

        # X : (2, 6, 9), hidden : (1, 6, 5)
        outputs, hidden = self.lstm(X, hidden_and_cell)

        # outputs : (2, 6, 5) (출력 갯수 x batch_size x n_hidden)
        # outputs : (6, 5) (맨 마지막 노드의 출력)
        # model : (6, 9)
        outputs = outputs[-1]
        model = torch.mm(outputs, self.W) + self.b

        return model


model = TextLSTM()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(500):
    # hidden : 초기 은닉 상태 (Dh x 1) -> 그림에서 설명한 Dh x 1 은 사실 배치 사이즈가 1이라고 가정하고 1 x Dh x 1 이다.
    # 코드상으로는 배치사이즈가 6이므로 6 x 5 x 1 이지만, 양방향성 여부를 표현하면 1 x 6 x 5 x 1이 된다.
    # 양방향일시 zeros(2, ..)
    hidden = torch.zeros(1, batch_size, n_hidden, requires_grad=True)
    cell = torch.zeros(1, batch_size, n_hidden, requires_grad=True)
    output = model(input_batch, (hidden, cell))     # input_batch : (6, 2, 9), hidden : (1, 6, 5)
    loss = criterion(output, target_batch)  # output : (6, 9), target_batch : (6,)

    if (epoch + 1) % 100 == 0:
        print('Epoch : ', '%04d' % (epoch + 1), 'Cost : ', '{:.6f}'.format(loss))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

hidden = torch.zeros(1, batch_size, n_hidden, requires_grad=True)
cell = torch.zeros(1, batch_size, n_hidden, requires_grad=True)
predict = model(input_batch, (hidden, cell)).data  # (6, 9)
predict = predict.max(axis=1, keepdim=True)[1]  # (6, 1)
print([number_dict[n.item()] for n in predict.squeeze()])

 

참고

https://justkode.kr/deep-learning/pytorch-rnn

 

Pytorch로 RNN, LSTM 구현하기

Pytorch 에서는 CNN과 마찬가지로, RNN과 관련 된 API를 제공합니다. 이를 이용해 손쉽게 RNN 네트워크를 구축 할 수 있습니다. Recurrent Neural Network RNN (Recurrent Neural Network)를 위한 API는 torch

justkode.kr

 

댓글