[機械学習のはじめ方] Part43: RNNの構造とBackpropagation Through Time

機械学習

過去の情報を記憶するニューラルネットワークの基本

皆さん、こんにちは!👋 Step 8へようこそ。ここでは、時間の流れに沿って変化するデータ、つまり時系列データを扱うための強力なツール、リカレントニューラルネットワーク(RNN: Recurrent Neural Network)について学んでいきます。株価の予測、文章の生成、音声認識など、私たちの身の回りには時系列データがたくさんあります。RNNは、このようなデータの「順番」や「文脈」を捉えるのが得意なんです。

今回は、RNNがどのようにして過去の情報を記憶し、次の予測に活かすのか、その基本的な構造と、RNNを学習させるための重要なアルゴリズムであるBackpropagation Through Time (BPTT)について解説します。

1. RNNの構造:情報をループさせる仕組み 🔄

RNNの最大の特徴は、ネットワーク内部にループ構造を持つことです。これにより、過去の情報を保持し、それを現在の計算に利用することができます。イメージとしては、「短期記憶」を持つネットワークのようなものです。

具体的に見ていきましょう。時刻 t におけるRNNの隠れ状態(内部の状態) ht は、その時刻の入力 xt と、一つ前の時刻の隠れ状態 ht-1 の両方を使って計算されます。

数式で表すと、以下のようになります(f は活性化関数、Wb は重みやバイアスです):

隠れ状態の計算: ht = f(Whhht-1 + Wxhxt + bh)

この式が示すように、現在の隠れ状態 ht は、過去の情報(ht-1)と現在の入力(xt)の影響を受けて決まります。この ht が次の時刻 t+1 の計算に使われることで、情報が時系列に沿って伝播していくのです。

そして、各時刻 t での出力 yt は、その時刻の隠れ状態 ht を使って計算されます。

出力の計算: yt = g(Whyht + by)
(g は出力層の活性化関数)

重要なポイントは、RNNでは隠れ状態を計算するための重み(Whh, Wxh)や出力を計算するための重み(Why)が、全ての時刻で共通して使われることです。これにより、異なる長さの時系列データに対しても、同じモデルで対応できるようになります。✨

2. Backpropagation Through Time (BPTT):時間を遡る学習 ⏳

RNNの学習は、通常のニューラルネットワークで使われる誤差逆伝播法(Backpropagation)を時間方向に拡張した、Backpropagation Through Time (BPTT) というアルゴリズムで行われます。

BPTTの基本的な考え方は以下の通りです。

  1. 順伝播 (Forward Pass): 時系列データを最初から最後までRNNに入力し、各時刻での出力と隠れ状態を計算します。
  2. 損失の計算: 各時刻での出力と実際の正解データとの誤差(損失)を計算します。多くの場合、全ての時刻の損失を合計したものが全体の損失となります。
  3. 逆伝播 (Backward Pass): 計算された全体の損失を、時間を遡るようにして各パラメータ(重みやバイアス)で微分し、勾配を求めます。この「時間を遡る」部分がBPTTの特徴です。ある時刻 t のパラメータの勾配は、その時刻だけでなく、それより未来の時刻の計算にも影響を与えているため、未来から過去へと勾配を伝播させる必要があります。
  4. パラメータの更新: 計算された勾配を使って、最適化手法(例: Adam, SGD)によりパラメータを更新します。

RNNのループ構造により、BPTTでは勾配計算の際に同じ重みが繰り返し掛け合わされることになります。これが、次に説明するRNN特有の課題につながります。

3. BPTTの課題:勾配消失と勾配爆発 💣

BPTTには、特に長い時系列データを扱う際に問題となる点があります。それは勾配消失 (Vanishing Gradient)勾配爆発 (Exploding Gradient) です。

📉 勾配消失問題

時間を遡って勾配を計算する際、活性化関数の微分値などが繰り返し掛け合わされます。この値が1より小さい場合、過去に遡るほど勾配が指数関数的に小さくなり、ほとんど0に近づいてしまいます。これにより、遠い過去の情報が現在のパラメータ更新にほとんど寄与しなくなり、長期的な依存関係(例: 文章の最初の方の内容が文末に影響するなど)を学習するのが困難になります。

📈 勾配爆発問題

勾配消失とは逆に、掛け合わされる値が1より大きい場合、勾配が指数関数的に増大し、非常に大きな値になってしまうことがあります。これにより、学習が不安定になり、パラメータが発散してしまう可能性があります。勾配爆発は、勾配の大きさに上限を設ける勾配クリッピング (Gradient Clipping) という手法で比較的対処しやすいです。

これらの問題を緩和するため、計算コスト削減も兼ねて、一定のステップ数までしか時間を遡らない Truncated BPTT (打ち切りBPTT) という手法もよく用いられます。

今回は、RNNの基本的な構造と、その学習アルゴリズムであるBPTTについて学びました。

  • RNNはループ構造を持ち、過去の情報を隠れ状態として保持することで時系列データを扱います。
  • RNNの学習には、時間方向に誤差逆伝播を行うBPTTが用いられます。
  • BPTTには勾配消失勾配爆発といった課題があり、特に長期依存性の学習が難しい場合があります。

この基本的なRNNの課題、特に勾配消失問題を克服するために、より複雑な構造を持つLSTM (Long Short-Term Memory)GRU (Gated Recurrent Unit) といったモデルが開発されました。

次の記事では、これらの改良されたRNNモデルについて詳しく見ていきます。お楽しみに!😊

コメント

タイトルとURLをコピーしました