目次
こんにちは.高山です.
今回は,Transformer を用いた連続指文字認識モデルの補足になります.
実践手話認識 - モデル開発編の第五回では,Transformer を用いた連続指文字認識を紹介しました.
本編の記事では実装と実験に注力して説明しましたので,本記事では Transformer Decoder で使われている Multi-head cross attention (MHCA) の動作について Step by step で説明したいと思います.
MHCA の基本的な処理構成は Multi-head self-attention (MHSA) と同じ (第3節をご参照ください) ですが,Encoder の出力特徴量とラベル系列が入力になる点に注意してください.
1. 入力へ線形変換を適用
MHCA ではまず最初に,入力に対して,それぞれ異なる線形変換 \(\boldsymbol{W}^Q, \boldsymbol{W}^K, \boldsymbol{W}^V\) を適用します.
ラベル系列 \(\boldsymbol{Y}^{'}\) には\(\boldsymbol{W}^Q\) を適用し,Encoder の出力特徴量 \(\boldsymbol{Z}\) には \(\boldsymbol{W}^K, \boldsymbol{W}^V\) を,それぞれ適用します.
図中の赤四角に示すように,入力の行,および重みの列の積和演算結果が出力行列の要素になります.
2. 特徴量を次元毎に分割して各ヘッドに振り分け
MHSA と同様に MHCA もAttention の計算を,特徴次元に沿って小分けにした特徴量に対してそれぞれ行います.
ここで,小分けにした特徴量に対して行われる計算処理郡はヘッドと呼ばれます.
各ヘッドでそれぞれ Attention を計算することで,入力系列内の複雑な関係性を捉えることが可能になります.
図2は特徴量を次元軸に沿って分割し,各ヘッドに振り分けている様子を示しています.
特徴次元に沿って均等に分割されるため,入力特徴量の次元数が \(C^A\),ヘッド数が \(H\) だった場合は,各ヘッドに入力される特徴量の次元数は \(C^A/H\) になります.
この処理は \(Q, K, V\) 全てに対して行われます.
3. Scaled dot-product attention の適用
次に Scaled dot-procuct attention を用いて Attention 重みを計算します.
MHSA と異なり,\(\boldsymbol{\alpha}_h\) は \(L \times T^{'}\) 形状である点に注意してください.
MHCA の \(\boldsymbol{\alpha}_h\) では,各行は次ラベルを推論する際の Encoder 出力特徴量系列に対する重みを示しています (RNN Encoder-Decoder と同じです).
4. アテンションの適用
重みが計算できたら,重みと Value の行列積を計算します.
\(\boldsymbol{V}_h\) は Encoder の出力特徴量が基になっていますので,ここの処理で指文字動作の情報を取り込んでいることが分かります.
行列演算の結果 \(T^{'}\) は消えて出力系列 \(\boldsymbol{O}_h\) の長さは \(L\) になる点に注意してください.
5. 特徴量の結合
Attention 適用後は特徴次元に沿って小分けに分割された特徴量を結合します.
6. 内部特徴量へ線形変換を適用
最後に,結合後の特徴量を線形変換して次元数を調整します.
基本的には,入力と同じ次元数になるように設定します.
今回は Multi-head cross-attention の動作について細かく説明しましたが,如何でしょうか?
MHCA と MHSA は動作が同じで入力だけが異なるため,役割や内部計算の流れを混同しがちです.
実装時や資料を作成するときは注意したいですね (体験談です(^^;)).
今回紹介した話が,これから手話認識や深層学習を勉強してみようとお考えの方に何か参考になれば幸いです.