再帰的な潜在変数モデルの論文を読んだので紹介
NIPS 2015にアクセプトされたA Recurrent Latent Variable Model for Sequential Dataを紹介します。 2015と少し古いですが、基礎を疎かにするとなんたるやと言うように、系列データの深層生成モデルの基礎になる論文だと思います。他の論文の関連研究で結構出てくるVRNNとはこの論文のことを指します。
この論文は、深層生成モデルの定番VAEを音声などの系列データを生成するモデルに拡張しました。ざっくりいうとRNN +VAEといったところです。
お気持ちとしては、RNN(recurrent neural network)では、内部状態が決定論的であり、DBNs (dynamic Bayesian networks, 例:Hidden Markov models (HMMs) , Kalman filters)は時間遷移が確率論的で、その両方のいいとこ取りをするという論文です。
RNNとDBNsの共通点と相違点
共通点は:(1)内部の隠れ状態の発展を決める遷移関数、および(2)状態から出力へのマッピングがあること。
相違点は:(1)DBNsは内部状態の遷移が線形などシンプル、RNNは非線形で複雑であること。(2)内部状態が確率論的と決定論的。
相違点の(2)にこの論文は、着目しています。 自然な音声などに含まれる変動を確率的に捉えたいが、既存のRNNのモデルだとランダム性は、出力の条件付き確率分布の分散で捉えることになるが、それでは捉えきれない。そこで、内部状態に高次元の潜在変数を使い、自然な音声に含まれる変動をモデル化することが狙いです。
ノーテーション
系列データは、と表します。また、時刻に依存する隠れ状態をとするとき、RNNは再帰的非線形関数として表せます。さらに系列データ の潜在変数は、と定義します。
Variational Recurrent Neural Network
同時確率から各分布の構成を説明していき、最後に推論時の変分事後確率を説明します。 以下の式が同時確率分布になります。
尤度関数は、前時刻の隠れ状態と現時刻の潜在変数の非線形変換 に基づいてデータは生成されると仮定しています。
事前確率は、各時刻ごとにガウス分布に従いますが、標準ガウス分布ではなく一時刻前の隠れ状態 に依存して平均と分散が決まります。
, はNNです。 は特徴抽出器にあたるNNです。普通のRNNは、特徴抽出器と全結合NNの生成器だけを用いています。ちなみにVAEは、, を使用したモデルに対応します。それに比べ、VRNNは、, , , と多いですね。また、VRNNは通常のRNNと違い、再帰非線形関数 に潜在変数の特徴 が加わり、再帰非線形関数はで定義されます。 なぜ潜在変数そのまま使わず、潜在変数の特徴 にするのか分かりません。RNNに入力するのに、次元を揃える必要があるからなんでしょうかね。ちなみにに含まれる情報は、が込められています。
Inference
ここで、潜在変数の変分事後分布は、以下のガウス分布で定義します。
この潜在変数の事後分布は、に基づいているので、が与えられたもとでの事後分布ということになります。
時系列になると添字を注意深く見ないと依存関係が分かりづらいですが、論文中の以下の図を見ると依存関係が整理しやすいですね。Learning
パラメータの学習はELBOの最大化でおこないます。
実験
音声モデリング(Speech modelling)と手書き文字生成(Handwriting generation)で実験をしていました。音声モデリング(Speech modelling)の方だけ紹介します。音声モデリングのデータセットは4つ(Blizzard、TIMIT、Onomatopoeia、Accent)です。例えば、TIMITは、生の音声シグナル(話者:630, 英語の文章:6300 )のデータセットになります。他については興味があれば論文の方を見てください。そのデータを訓練とテストで分け、テストデータに対する対数周辺尤度の平均で生成のよさを定量的に評価しています。
また、Blizzardの訓練データに対するモデルの当てはまりの良さを、学習させたモデルから2秒間音声波形を生成させて定性的に評価しています。マクロとミクロ両方ともに比較手法に比べ、提案手法は高周波なノイズが少なく、データセットの信号に近い波形を生成できていることが分かります。
まとめ
自然な音声生成を応用に見据えつつ、RNNの再帰関数に時刻に依存する潜在変数も加えた深層生成モデルの紹介でした。事前確率がすでに、学習中の隠れ状態に依存させるのは、不思議な感覚です。この論文は、その後色々な論文に引用され変化を遂げています。強化学習の方ともつながりがあるので、いずれブログに書きたいです。