実践手話認識 - モデル開発編4: RNN Encoder-Decoderを用いた連続指文字認識モデル

This image is generated with ChatGPT-4 Omni, and edited by the author.
作成日: 2024年09月13日(金) 00:00
最終更新日: 2024年10月08日(火) 14:25
カテゴリ: 手話言語処理
タグ:  実践手話認識 連続指文字認識 RNN Python

RNN Encoder-Decoderを用いた連続指文字認識モデルを紹介します.題材は連続指文字認識ですが,手法は連続手話単語認識や手話翻訳にも応用可能です.

こんにちは.高山です.
実践手話認識 - モデル開発編の第四回になります.

以前の記事で,Kaggle の Google American Sign Language Fingerspelling Recognition データセット (以下,GAFS データセット) について解説しました.
今回は,GAFS データセットを用いて連続指文字認識モデルを学習してみたいと思います.
具体的には,Recurrent neural network (RNN) ベースの Encoder-Decoder モデル [Bahdanau'15] の実装方法を紹介します.

下記に示すように,RNN については孤立手話単語認識を題材にして,複数回に分けて解説をしてきました.

RNN や Attention などの各計算処理についてはある程度カバーできていると思います.
今回は Encoder-Decoder の実装部分に注力して解説していきたいと思います.

全体を通して 1 記事にまとめるとかなり長くなってしまいますので,本記事ではモデルの実装と実験に注力します.
学習処理全体の実装は,Encoder-Decoder 解説記事をご参照ください.

今回解説するスクリプトはGitHub上に公開しています
ただし,説明で使用しているデータセットは,配布可能な容量にするために大部分のデータを削除しています.
デバッグには使用できますが,学習は安定しませんのでご注意ください.

  • [Bahdanau'15]: D. Bahdanau, et al., "Neural Machine Translation by Jointly Learning to Align and Translate," Proc. of the ICLR, available here, 2015.

更新履歴 (大きな変更のみ記載しています)

  • 2024/09/24: Embedding 層と出力層の初期化処理を加えて再実験を行い,結果を更新.
    学習初期の安定性が向上し,最終的な認識性能も向上しました.
  • 2024/09/18: カテゴリを変更しました
  • 2024/09/17-2:
    • タグを更新しました
    • 系列認識 (Sequential recognition) は別の手法を連想させる可能性があるので,時系列認識 (temporal recognition) という用語を用いるようにしました
  • 2024/09/17: 2024/09/14 の修正を反映したコードで実験結果を更新.
    途中で瞬間的に学習結果が悪くなる現象がありましたが,最終的には同程度以上の認識性能に落ち着きました.
  • 2024/09/14
    • 旧処理で最終ラベルの推論が行われないバグがあったっため,RNNCSLR の forward() メソッドを更新.
    • テスト時の最大ループ数 max_seqlen を 31 から 60 に変更.
      正しくは <eos> の分を含め 32 に設定すべきでした.
      また,本来はテストデータの最大ラベル長は未知なので,学習時よりも大きい値を設定すべきと判断して,現在の値に変更しました.

1. RNN Encoder-Decoderの処理

1.1 全体構成

RNN Encoder-Decoder を基にした連続指文字認識モデルの構成図を図1に示します.

Bahdanau RNN with Attentionの処理構成と処理の流れを説明するブロック図です.画像の後に説明があります.
Bahdanau RNN with Attention [Bahdanau'15]

上記のモデルは,Encoder と Decoder の双方を RNN ベースで組み,Decoder に Attention 層を持ちます.

Encoder では RNN を使用して入力系列 \(\boldsymbol{x}_t\) の特徴変換を行います.
変換後の特徴系列 \(\boldsymbol{H}^e = \{\boldsymbol{h}^e_t | 1 \leq t \leq T\}\) は Attention 層へ入力されます.

Decoder では RNN を使用して現在の入力ラベル \(\boldsymbol{y}_s\) から次のラベル予測確率 \(\hat{\boldsymbol{p}}_{s}\) を出力します.
(正確には確率化前の応答値を出力します)
Attention 層はこのとき,Encoder の特徴系列 \(\boldsymbol{H}^e\) と一つ前の Decoder-RNN の出力 \(\boldsymbol{h}^d_{s-1}\) を入力として,次のラベルを予測するために \(\boldsymbol{H}^e\) のどこに注目すべきかを示す重み \(\alpha_{st}\) を生成します.

\(\alpha_{st}\) は Attention 重み (Attention weight) と呼ばれます.
Attention 層はこの重みを用いて次のラベルを予測するための要約特徴量である,コンテキストベクトル \({c}_{s}\) を出力します.

コンテキストベクトル \({c}_{s}\) を現在の入力データ \(\boldsymbol{y}_s\) とともに RNN へ入力することで,適応的に \(\boldsymbol{H}^e\) の特徴を取り込んで次ラベルを予測することができます.

なお,"Emb" は Embedding 層を指し,離散数値をモデルが扱いやすい多次元特徴量に変換します.

  • [Bahdanau'15]: D. Bahdanau, et al., "Neural Machine Translation by Jointly Learning to Align and Translate," Proc. of the ICLR, available here, 2015.

1.2 Attentionの計算

Attention 層の構成図を図2に示します.

Attention層の処理構成,処理の流れ,および計算内容について説明する画像です.画像の後に説明があります.
Attentionの計算

Attention の処理を特徴付けるのは \(e_{st} = score()\) という計算です.
(Energy と呼ばれる場合もあります.物理的な話ではなくて応答値くらいの意味だと思いますが,正確なニュアンスは英語力不足で分かっていません(^^;))

Bahdanau attention の処理構成は下記のようになります.

  • Linear 層: Encoder と Decoder の RNN 出力をそれぞれ変換し特徴次元を揃える
  • tanh: tanh 関数で応答値のスケールを揃える
  • Linear 層: 1次元の系列に変換
  • Softmax: 確率値に変換

Dot-product attention [Loung'15] (入力間の類似度を評価) のように解釈がしやすい処理ではないですが,学習の結果として予測に重要な箇所に対して高い確率値を出力するようになります.

  • [Luong'15]: M.-T. Luong, et al., "Effective Approaches to Attention-based Neural Machine Translation," Proc. of the EMNLP, available here, 2015.

1.3 Attention の出力例

学習済みモデルによるAttention の出力例を図3に示します.
(見やすさを優先して 5 ラベル分だけ出力しています)

Attention層の出力例を示すグラフです.画像の前後に説明があります.
Attentionの出力例

ラベルの数値は \(59\) が "\(\text{<sos>}\)" を示し,他の数値はそれぞれ特定の指文字を示しています.
連続指文字認識では1個づつ指文字を順番に表出していきますので,重みのピークが後ろの時間にずれていっています.
重なっている箇所も多いので,個々の指文字を正確に見分けているわけではなく,系列全体の情報を使って推測しているようです.

このように,RNN Encoder-Decoder モデルでは入力の重要な部分に重み付けをしながら処理を行うことで,認識性能を向上させています.

1.4 実装上の注意点

細かな点ですが,Decoder RNN の隠れ状態の管理は注意が必要です.

Decoder の推論ではループ処理をしながら外部で RNNの隠れ状態を管理します.
まず,隠れ状態の管理を忘れると RNN の隠れ状態はループ毎に 0 初期化されるので推論が上手くいきません.
そのため,通常は隠れ状態の RNN を初期化後にループの内部では隠れ状態を保持する実装をします.

Decoder RNNの隠れ状態初期化は図4に示す処理がよく使われます.

Decoderの隠れ状態の適切な初期化が必要であることを示す画像です.画像の前後に説明があります.
Decoder hidden state の初期化

0 初期化をする場合にありがちなミスは,各サンプルの推論開始前に隠れ状態を毎回初期化することを忘れることです.
これは推論処理とモデルをアプリケーションに実装して,単発の推論を繰り返す場合によくやってしまいます.

2回目の推論から上手くいかなくなるような現象が発生するので,バグを見つけるのがかなり大変です (体験談です(^^;)) .

Encoder から隠れ状態を受け渡す場合は,Encoder と Decoder の RNN 構成に注意が必要です.
Encoder は入力系列を一度に受け取るので (オンライン認識の場合はこの限りではありません),Bidirectional RNN を用いることができます.
一方,テスト時に Decoder は過去のラベル系列しか利用できないので,Bidirectional RNN を用いることができません.

PyTorchのRNNクラスの出力は少しクセがあり,Encoder と Decoder のRNN 構成に差がある場合は,特徴量の整形を行う必要があるので注意してください.

本記事では,Encoder の隠れ状態を Decoder の初期隠れ状態として受け渡す実装を用いて実験をします.

2. 実験結果

次節以降では,実装の紹介をしていきます.
実装部分はかなり長いので結果を先にお見せしたいと思います.

説明で使用しているデータセットは容量をかなり削っていて学習が不安定ですので,ここでは全データセットを用いた結果だけ示します.

今回の実験では学習を安定させるために,ラベルスムージングと各種のデータ拡張処理を導入しています.
これらの処理については「手話認識入門」の記事で解説しているので,よろしければご一読ください.

学習エポック数は50,バッチ数は128に設定しています.

図5は,Validation Lossと単語誤り率の推移を示しています.
連続指文字認識のような時系列認識タスクの評価には,単語誤り率 (WER: Word error rate) がよく用いられます.
WER については WERに関する補足記事で説明していますので,併せてご一読いただたらうれしいです.

認識性能を図示したグラフです.画像前後の文章に詳細説明があります.
認識性能

横軸は学習・評価ループの繰り返し数 (Epoch) を示し,縦軸は評価指標を示します.

\(\text{WER}=30.0\%\) 程度でしたので,約 \(70.0\%\) 程度は正しい指文字が認識できていることになります.
今回のデータセットは学習に非常に時間がかかるので 50 エポックで打ち切りましたが,学習を続ければ性能はまだ伸びると思います.

標準的なモデルとしてはまずまずな性能だと思います.
(Kaggle の時よりも性能が良い気がしますので,テスト用に選んだ被験者の指文字が比較的認識しやすいデータなのかもしれません)

なお,今回の実験では話を簡単にするために,実験条件以外のパラメータは固定にし,乱数の制御もしていません.
必ずしも同様の結果になるわけではないので,ご了承ください.

3. 前準備

3.1 データセットのダウンロード

ここからは実装方法の説明をしていきます.
まずは,前準備として Google Colab にデータセットをアップロードします.

まず最初に,データセットの格納先からデータをダウンロードし,ご自分の Google drive へアップロードしてください.

次のコードで Google drive を Colab へマウントします.
Google Drive のマウント方法については,補足記事にも記載してあります.

1
2
3
from google.colab import drive

drive.mount("/content/drive")

ドライブ内のファイルを Colab へコピーします.
パスはアップロード先を設定する必要があります.

# Copy to local.
!cp ./drive/MyDrive/Datasets/gafs_dataset_very_small.zip gafs_dataset.zip

データセットは ZIP 形式になっているので unzip コマンドで解凍します.

!unzip -o gafs_dataset.zip
Archive:  gafs_dataset.zip
   creating: gafs_dataset_very_small/
  inflating: gafs_dataset_very_small/0.hdf5
  ...
  inflating: gafs_dataset_very_small/LICENSE.txt

成功すると gafs_dataset_very_small 以下にデータが解凍されます.
HDF5 ファイルはデータ本体で,手話者毎にファイルが別れています.
JSON ファイルは辞書ファイルで,TXT ファイルは本データセットのライセンスです.

!ls gafs_dataset_very_small
0.hdf5 135.hdf5 ... character_to_prediction_index.json LICENSE.txt

辞書には指文字名と数値の関係が 59 種類分定義されています.

!cat gafs_dataset_very_small/character_to_prediction_index.json
{
    " ":0,
    "!":1,
    ...
    "~":58
}

ライセンスはオリジナルと同様に,CC-BY 4.0 としています.

!cat gafs_dataset_very_small/LICENSE.txt
The dataset provided by Natsuki Takayama (Takayama Research and Development Office) is licensed under CC-BY 4.0.
Author: Copyright 2024 Natsuki Takayama
Title: GASF very small dataset
Original licenser: Google LLC
Modification
- Extract only 3 parquet file.
- Packaged into HDF5 format.

次のコードでサンプルを確認します.
サンプルは辞書型のようにキーバリュー形式で保存されており,下記のように階層化されています.

- サンプルID (トップ階層のKey)
  |- feature: 入力特徴量で `[C(=2), T, J(=543)]` 形状.C,T,Jは,それぞれ特徴次元,フレーム数,追跡点数です.
  |- token: 指文字ラベル系列で `[L]` 形状.0から58の数値です.

なお,データ量削減のため,入力特徴は \((x, y)\) 座標になっており \(z\) 座標は含んでいません.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
import h5py
with h5py.File("gafs_dataset_very_small/0.hdf5", "r") as fread:
    keys = list(fread.keys())
    print(keys[:10])
    group = fread[keys[0]]
    print(group.keys())
    feature = group["feature"][:]
    token = group["token"][:]
    print(feature.shape)
    print(token)
['1720198121', '1722303176', '1723157122', '1731934631', '1737624109', '1739256200', '1743069372', '1743412187', '1744795751', '1746320345']
<KeysViewHDF5 ['feature', 'token']>
(2, 271, 543)
[14 38 32 45 44 36 40 32 43 43 36 56 14 43 40 45 32 12 34 32 49 50 51 36
 45 50]

3.2 モジュールのダウンロード

次に,過去の記事で実装したコードをダウンロードします.
コードは Githubのsrc/modules_gislr にアップしてあります (今後の記事で使用するコードも含まれています).

まず,下記のコマンドでレポジトリをダウンロードします.
今後のアップデートを考慮してバージョン指定をしていますので注意してください.

!wget https://github.com/takayama-rado/trado_samples/archive/refs/tags/v0.3.4.zip -O master.zip
--2024-09-10 06:19:15--  https://github.com/takayama-rado/trado_samples/archive/refs/tags/v0.3.4.zip
...
2024-09-10 06:19:20 (17.8 MB/s) - ‘master.zip’ saved [80254068]

ダウンロードしたリポジトリを解凍します.

!unzip -o master.zip -d master
Archive:  master.zip
3406d5a0072e08879272e622ff8efdc1c7b78ee8
   creating: master/trado_samples-0.3.4/
   inflating: master/trado_samples-0.3.4/.gitignore
   ...

モジュールのディレクトリをカレントディレクトリに移動します.

!mv master/trado_samples-0.3.4/src/modules_gislr .

他のファイルは不要なので削除します.

!rm -rf master master.zip gafs_dataset_very_small.zip
!ls
drive  gafs_dataset_very_small  gafs_dataset.zip  modules_gislr  sample_data

3.3 モジュールのロード

主要な処理の実装に先立って,下記のコードでモジュールをロードします.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import copy
import json
import os
import sys
import time
from functools import partial
from pathlib import Path

# Third party's modules
import numpy as np

import torch

from nltk.metrics.distance import edit_distance

from torch import nn
from torch.nn import functional as F
from torch.utils.data import (
    DataLoader)

from torchvision.transforms import Compose

# Local modules
sys.path.append("modules_gislr")
from modules_gislr.dataset import (
    HDF5Dataset,
    merge)
from modules_gislr.defines import (
    get_fullbody_landmarks
)
from modules_gislr.layers import (
    RNNEncoder,
    apply_norm,
    create_norm
)
from modules_gislr.train_functions import (
    LabelSmoothingCrossEntropyLoss
)
from modules_gislr.transforms import (
    PartsBasedNormalization,
    ReplaceNan,
    SelectLandmarksAndFeature,
    ToTensor
)
from modules_gislr.utils import (
    select_reluwise_activation
)
【コード解説】
- 標準モジュール
  - copy: データコピーライブラリ.Macaron Netブロック内でEncoder層をコピーするために使用します.
  - json: JSONファイル制御ライブラリ.辞書ファイルのロードに使用します.
  - os: システム処理ライブラリ
  - sys: Pythonインタプリタの制御ライブラリ.
    今回はローカルモジュールに対してパスを通すために使用します.
  - time: 時刻処理ライブラリ.処理時間計測に使用します.
  - functools: 関数オブジェクトを操作するためのライブラリ.
    今回はDataLoaderクラスに渡すパディング関数に対して設定値をセットするために使用します.
  - pathlib.Path: オブジェクト指向のファイルシステム機能.
    主にファイルアクセスに使います.osモジュールを使っても同様の処理は可能です.
    高山の好みでこちらのモジュールを使っています(^^;).
- 3rdパーティモジュール
  - numpy: 行列演算ライブラリ
  - nltk: 自然言語処理ライブラリ.WERを計算するために使用します.
  - torch: ニューラルネットワークライブラリ
  - torchvision: PyTorchと親和性が高い画像処理ライブラリ.
    今回はDatasetクラスに与える前処理をパッケージするために用います.
- ローカルモジュール: sys.pathにパスを追加することでロード可能
  - dataset: データセット操作用モジュール
  - defines: 各部位の追跡点,追跡点間の接続関係,およびそれらへのアクセス処理を
    定義したモジュール
  - layers: ニューラルネットワークのモデルやレイヤモジュール
  - transforms: 入出力変換処理モジュール
  - train_functions: 学習・評価処理モジュール
  - utils: 汎用処理関数モジュール

4. 認識モデルの実装

4.1 Attention module

ここから先は,認識モデルを実装していきます.
まず最初に第1.2項で説明した Attention module を実装します.

下記のコードで Attention の Energy 項計算処理を実装します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class BahdanauAttentionEnergy(nn.Module):
    def __init__(self,
                 key_dim,
                 query_dim,
                 att_dim,
                 add_bias=False):
        super().__init__()

        self.w_key = nn.Linear(key_dim, att_dim, bias=add_bias)
        self.w_query = nn.Linear(query_dim, att_dim, bias=add_bias)
        self.w_out = nn.Linear(att_dim, 1, bias=add_bias)

    def forward(self, key, query):
        # key: `[N, key_len, key_dim]`
        # query: `[N, 1, query_dim]`
        key = self.w_key(key)
        query = self.w_query(query)
        # Adding with broadcasting.
        # key: `[N, key_len, key_dim]`
        # query: `[N, 1, query_dim]`
        # query should be broadcasted to `[N, key_len, query_dim]`
        temp = key + query
        # `[N, key_len, att_dim] -> [N, key_len, 1] -> [N, 1, key_len]`
        energy = self.w_out(torch.tanh(temp))
        energy = torch.permute(energy, [0, 2, 1])
        return energy
【コード解説】
- 引数:
  - key_dim: Key値の入力次元数.通常はEncoder側RNNの出力次元数と揃えます.
  - query_dim: Query値の入力次元数.通常はDecoder側RNNの出力次元数と揃えます.
  - att_dim: Attention内部の中間次元数.
  - add_bias: Trueの場合,Attention内部のLinear層にBias項を加えます.
- 7-11行目: 初期化処理
- 13-26行目: 推論処理
  出力は`[N, 1, key_len]` 形状にして返しています.

次にAttention 処理の全体を実装します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class SingleHeadAttention(nn.Module):
    def __init__(self,
                 key_dim,
                 query_dim,
                 att_dim,
                 add_bias):
        super().__init__()

        self.att_energy = BahdanauAttentionEnergy(
            key_dim=key_dim,
            query_dim=query_dim,
            att_dim=att_dim,
            add_bias=add_bias)

        self.neg_inf = None

    def forward(self,
                key,
                value,
                query,
                mask=None):
        if self.neg_inf is None:
            self.neg_inf = float(np.finfo(
                torch.tensor(0, dtype=key.dtype).numpy().dtype).min)

        batch, klen, kdim = key.shape
        _, qlen, qdim = query.shape
        energy = self.att_energy(key=key, query=query)
        assert energy.shape == (batch, qlen, klen)

        # Apply mask.
        if mask is not None:
            if len(mask.shape) == 2:
                # `[N, klen] -> [N, qlen(=1), klen]`
                mask = mask.unsqueeze(1)
            # Negative infinity should be 0 in softmax.
            energy = energy.masked_fill_(mask==0, self.neg_inf)

        # Compute attention mask.
        attw = torch.softmax(energy, dim=-1)
        # attw: `[N, qlen, klen]`
        # value: `[N, klen, kdim]`
        # bmm: `[N, qlen, klen] x [N, klen, kdim] -> [N, qlen, kdim]`
        cvec = torch.bmm(attw, value)
        return cvec, attw
【コード解説】
- 引数:
  - key_dim: Key値の入力次元数.通常はEncoder側RNNの出力次元数と揃えます.
  - query_dim: Query値の入力次元数.通常はDecoder側RNNの出力次元数と揃えます.
  - att_dim: Attention内部の中間次元数.
  - add_bias: Trueの場合,Attention内部のLinear層にBias項を加えます.
- 7-15行目: 初期化処理
- 17-45行目: 推論処理
  - 22-24行目: マスキング値の初期化.
    Softmax関数は通常のマスキング値 (0など) を有効な数値に変換してしまうので,
    計算後の重みが0になるように,負の最小値を設定します.
    入力の型をハードコーディングすると移植性が悪くなるので,最初の入力の時に型を
    判定するように実装しています.
  - 26-29行目: Energy項を算出.
  - 32-37行目: Paddin箇所などマスキングを行う部分に対して,`neg_inf` を代入しています.
    代入箇所はsoftmax()処理内で0に変換されます.
  - 40行目: Attention重みを算出.
  - 44行目: Context vectorを算出.
    torch.bmm はバッチ内のサンプル毎に行列積を算出する関数です.

Energy 項と異なるクラスに実装することで他の種類の Attention へ切り替え易くしています.
他の Attention 層については割愛させていただきますが,Githubのsrc/modules_gislr のコードには実装してあります.
機会があれば別記事で紹介させていただきます.

4.2 RNN Decoder

第1.1項で説明した RNN Decoder を次のコードで実装します.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
class BahdanauRNNDecoder(nn.Module):
    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 emb_channels,
                 att_dim,
                 att_add_bias,
                 rnn_type,
                 num_layers,
                 activation,
                 dropout,
                 padding_val,
                 proj_size=0):
        super().__init__()
        assert rnn_type in ["srnn", "lstm", "gru"]

        self.emb_layer = nn.Embedding(
            num_embeddings=out_channels,
            embedding_dim=emb_channels,
            padding_idx=padding_val)

        self.att_layer = SingleHeadAttention(
            key_dim=in_channels,
            query_dim=hidden_channels,
            att_dim=att_dim,
            add_bias=att_add_bias)

        if rnn_type == "srnn":
            self.rnn = nn.RNN(input_size=in_channels + emb_channels,
                              hidden_size=hidden_channels,
                              num_layers=num_layers,
                              nonlinearity=activation,
                              batch_first=True,
                              dropout=dropout,
                              bidirectional=False)
        elif rnn_type == "lstm":
            self.rnn = nn.LSTM(input_size=in_channels + emb_channels,
                               hidden_size=hidden_channels,
                               num_layers=num_layers,
                               batch_first=True,
                               dropout=dropout,
                               bidirectional=False,
                               proj_size=proj_size)
        elif rnn_type == "gru":
            self.rnn = nn.GRU(input_size=in_channels + emb_channels,
                              hidden_size=hidden_channels,
                              num_layers=num_layers,
                              batch_first=True,
                              dropout=dropout,
                              bidirectional=False)
        self.head = nn.Linear(hidden_channels, out_channels)

        self.num_layers = num_layers
        self.dec_hstate = None
        self.attw = None

        self.reset_parameters(emb_channels, padding_val)

    def reset_parameters(self, embedding_dim, padding_val):
        # Bellow initialization has strong effect to performance.
        # Please refer.
        # https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/transformer/transformer_base.py#L189
        nn.init.normal_(self.emb_layer.weight, mean=0, std=embedding_dim**-0.5)
        nn.init.constant_(self.emb_layer.weight[padding_val], 0)

        # Please refer.
        # https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/transformer/transformer_decoder.py
        nn.init.xavier_uniform_(self.head.weight)
        nn.init.constant_(self.head.bias, 0.0)

    def init_dec_hstate(self, enc_hstate, init_as_zero=False):
        if init_as_zero:
            dec_hstate = torch.zeros_like(enc_hstate)
        else:
            dec_hstate = enc_hstate
        # To avoid error at RNN layer.
        self.dec_hstate = dec_hstate.contiguous()

    def forward(self,
                dec_inputs,
                enc_seqs,
                enc_mask):
        assert self.dec_hstate is not None, f"dec_hstate has not been initialized."
        dec_hstate = self.dec_hstate

        # Attention layer requires hidden state of 2nd rnn layer.
        # as `[N, 1, C]`
        query = dec_hstate[-1].unsqueeze(1)
        cvec, self.attw = self.att_layer(
            key=enc_seqs,
            value=enc_seqs,
            query=query,
            mask=enc_mask)

        emb_out = self.emb_layer(dec_inputs)
        # `[N, C] -> [N, 1, C]`
        emb_out = emb_out.reshape([-1, 1, emb_out.shape[-1]])
        feature = torch.cat([cvec, emb_out], dim=-1)
        if isinstance(self.rnn, nn.LSTM):
            hidden_seqs, (last_hstate, last_cstate) = self.rnn(feature,
                                                               dec_hstate)
        else:
            hidden_seqs, last_hstate = self.rnn(feature,
                                                dec_hstate)
            last_cstate = None

        output_dec = self.head(hidden_seqs)
        self.dec_hstate = last_hstate
        return output_dec
【コード解説】
- 引数:
  - in_channels: 入力特徴量の次元数.通常はEncoderの出力次元数を揃えます.
  - hidden_channels: RNNの出力次元数.
  - out_channels: 出力次元数.通常はラベルの種類数に揃えます.
  - emb_channels: Embedding層の出力次元数.
  - att_dim: Attention内部の中間次元数.
  - att_add_bias: Trueの場合,Attention内部のLinear層にBias項を加えます.
  - rnn_type: RNN層の種別を指定 [srnn/lstm/gru].
  - num_layers: RNN層の数
  - activation: RNN層内の活性化関数.
    ["tanh"/"relu"]で指定します.
  - dropout: Dropoutレイヤの欠落率
  - padding_val: ラベル系列のパディング値.
  - proj_size: 0以上の場合,H. Sakらの拡張型LSTM層を使用します.
    https://arxiv.org/abs/1402.1128
- 15-58行目: 初期化処理
  RNNの入力次元数は `in_channels + emb_channels` となる点に注意してください.
  58行目で Embedding層と出力層の学習パラメータを初期化しています.
- 60-70行目: 学習パラメータ初期化処理
- 72-78行目: RNN 隠れ状態の初期化処理
  この処理は Decoder の推論開始前に毎回呼び出す必要があります.
- 80-110行目: 推論処理
  - 84行目: 隠れ状態が初期化されているかを確認.
  - 85-94行目: Attention層を適用.
  - 96行目: Embedding層を適用.
  - 98-99行目: RNN層への入力を整形.
  - 100-106行目: RNN層を適用.
  - 108行目: ラベルへの応答値に変換.
  - 109行目: 次のループに備えて隠れ状態を格納.

nn.Embedding 層 (と出力層) のパラメータ初期化を reset_parameter() で行っている点に注意してください.

4.3 認識モデル

第1.1項で説明した 認識モデル全体を次のコードで実装します.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
class RNNCSLR(nn.Module):
    def __init__(self,
                 enc_in_channels,
                 enc_hidden_channels,
                 enc_rnn_type,
                 enc_num_layers,
                 enc_activation,
                 enc_bidir,
                 enc_dropout,
                 enc_apply_mask,
                 enc_proj_size,
                 dec_in_channels,
                 dec_hidden_channels,
                 dec_out_channels,
                 dec_emb_channels,
                 dec_att_dim,
                 dec_att_add_bias,
                 dec_rnn_type,
                 dec_num_layers,
                 dec_activation,
                 dec_dropout,
                 dec_padding_val,
                 dec_proj_size):
        super().__init__()
        self.enc_bidir = enc_bidir

        self.linear = nn.Linear(enc_in_channels, enc_hidden_channels)
        self.enc_activation = nn.ReLU()

        self.encoder = RNNEncoder(
            in_channels=enc_hidden_channels,
            out_channels=enc_hidden_channels,
            rnn_type=enc_rnn_type,
            num_layers=enc_num_layers,
            activation=enc_activation,
            bidir=enc_bidir,
            dropout=enc_dropout,
            apply_mask=enc_apply_mask,
            proj_size=enc_proj_size)

        if enc_bidir:
            dec_in_channels *= 2
            dec_hidden_channels *= 2
            dec_att_dim *= 2

        self.decoder = BahdanauRNNDecoder(
            in_channels=dec_in_channels,
            hidden_channels=dec_hidden_channels,
            out_channels=dec_out_channels,
            emb_channels=dec_emb_channels,
            att_dim=dec_att_dim,
            att_add_bias=dec_att_add_bias,
            rnn_type=dec_rnn_type,
            num_layers=dec_num_layers,
            activation=dec_activation,
            dropout=dec_dropout,
            padding_val=dec_padding_val,
            proj_size=dec_proj_size)

    def _apply_encoder(self, feature, feature_pad_mask=None):
        # Feature extraction.
        # `[N, C, T, J] -> [N, T, C, J] -> [N, T, C*J] -> [N, T, C']`
        N, C, T, J = feature.shape
        feature = feature.permute([0, 2, 1, 3])
        feature = feature.reshape(N, T, -1)

        feature = self.linear(feature)
        feature = self.enc_activation(feature)

        # Apply encoder.
        enc_seqs, enc_hstate = self.encoder(feature, feature_pad_mask)[:2]

        # Basically, decoder should not be bidirectional.
        # So, we should concatenate backwarded feature.
        if self.enc_bidir:
            # `[2*layers, N, C] -> [layers, N, 2*C]`
            enc_hstate = torch.permute(enc_hstate, [1, 0, 2])
            enc_hstate = enc_hstate.reshape([enc_hstate.shape[0],
                                             enc_hstate.shape[1] // 2,
                                             -1])
            enc_hstate = torch.permute(enc_hstate, [1, 0, 2])
        return enc_seqs, enc_hstate

    def forward(self,
                feature, tokens,
                feature_pad_mask=None, tokens_pad_mask=None):
        """Forward computation for train.
        """
        enc_seqs, enc_hstate = self._apply_encoder(feature, feature_pad_mask)

        # Apply decoder.
        self.decoder.init_dec_hstate(enc_hstate)
        preds = None
        for t_index in range(0, tokens.shape[-1]):
            # Teacher forcing.
            dec_inputs = tokens[:, t_index].reshape([-1, 1])
            pred = self.decoder(
                dec_inputs=dec_inputs,
                enc_seqs=enc_seqs,
                enc_mask=feature_pad_mask)
            if preds is None:
                preds = pred
            else:
                # `[N, T, C]`
                preds = torch.cat([preds, pred], dim=1)
        return preds

    def inference(self,
                  feature,
                  start_id,
                  end_id,
                  feature_pad_mask=None,
                  max_seqlen=62):
        """Forward computation for test.
        """
        enc_seqs, enc_hstate = self._apply_encoder(feature, feature_pad_mask)

        # Apply decoder.
        self.decoder.init_dec_hstate(enc_hstate)
        dec_inputs = torch.tensor([start_id]).to(feature.device)
        # `[N, T]`
        dec_inputs = dec_inputs.reshape([1, 1])
        preds = None
        pred_ids = [start_id]
        for _ in range(max_seqlen):
            pred = self.decoder(
                dec_inputs=dec_inputs,
                enc_seqs=enc_seqs,
                enc_mask=feature_pad_mask)
            if preds is None:
                preds = pred
            else:
                # `[N, T, C]`
                preds = torch.cat([preds, pred], dim=1)

            pid = torch.argmax(pred, dim=-1)
            dec_inputs = pid

            pid = pid.reshape([1]).detach().cpu().numpy()[0]
            pred_ids.append(int(pid))
            if int(pid) == end_id:
                break

        # `[N, T]`
        pred_ids = np.array([pred_ids])
        return pred_ids, preds
【コード解説】
- 引数
  - enc_in_channels: Encoder側入力特徴量の次元数.
  - enc_hidden_channels: Encoder側RNNの出力次元数.
  - enc_rnn_type: Encoder側RNN層の種別を指定 [srnn/lstm/gru].
  - enc_num_layers: Encoder側RNN層の数.
  - enc_activation: Encoder側RNN層の活性化関数.
    ["tanh"/"relu"]で指定します.
  - enc_bidir: Trueの場合,Encoder側RNN でBidirectional RNNを使用.
    この場合,Encoderの出力次元数は `enc_hidden_channels * 2` になります.
  - enc_dropout: EncoderのDropoutレイヤの欠落率.
  - enc_apply_mask: Trueの場合,Encoder側RNNでマスキングを行います.
  - enc_proj_size: 0以上の場合,Encoder側RNNでH. Sakらの拡張型LSTM層を使用します.
    https://arxiv.org/abs/1402.1128
  - dec_in_channels: Decoder側入力特徴量の次元数.通常はEncoderの出力次元数を揃えます.
  - dec_hidden_channels: Decoder側RNNの出力次元数.
  - dec_out_channels: Decoderの出力次元数.通常はラベルの種類数に揃えます.
  - dec_emb_channels: Decoder Embedding層の出力次元数.
  - dec_att_dim: Decoder Attention内部の中間次元数.
  - dec_att_add_bias: Trueの場合,Decoder Attention内部のLinear層にBias項を加えます.
  - dec_rnn_type: Decoder側RNN層の種別を指定 [srnn/lstm/gru].
  - dec_num_layers: Decoder側RNN層の数.
  - dec_activation: Decoder側RNN層の活性化関数.
    ["tanh"/"relu"]で指定します.
  - dec_dropout: DecoderのDropoutレイヤの欠落率.
  - dec_padding_val: ラベル系列のパディング値.
  - dec_proj_size: 0以上の場合,Decoder側RNNでH. Sakらの拡張型LSTM層を使用します.
    https://arxiv.org/abs/1402.1128
- 24-58行目: 初期化処理.
  EncoderがBidirectional構成の場合は,41-44行目でDecoderの次元数を倍にしている点に
  注意してください.
- 60-82行目: Encoderの適用処理.
  EncoderがBidirectional構成の場合は,75-81行目で隠れ状態の形状をDecoderに合わせて
  変えている点に注意してください.
- 84-106行目: 学習時推論処理
  - 89行目: Encoderの適用.
  - 92行目: Decoderの隠れ状態を初期化.
    Encoderの隠れ状態をDecoderの初期状態としています.この処理を忘れた場合,
    Decoderの初期隠れ状態は前のサンプル処理時の最終隠れ状態になるため,
    推論が上手くできなくなります.
  - 94-105行目: 推論のメインループ
    - 96行目: 学習時は正解ラベルをDecoderの入力ラベルとします.
    - 105行目: 時間軸に沿って応答値を積み上げていきます.
- 108-146行目: テスト時推論処理
  - 116行目: Encoderの適用.
  - 119行目: Decoderの隠れ状態を初期化.
    Encoderの隠れ状態をDecoderの初期状態としています.この処理を忘れた場合,
    Decoderの初期隠れ状態は前のサンプル処理時の最終隠れ状態になるため,
    推論が上手くできなくなります.
  - 120-122行目: Decoderの入力ラベルを`<sos>`で初期化.
  - 125-142行目: 推論メインループ.
    `max_seqlen` でループを打ち切っている点に注意してください.変則的な入力が
    与えられた場合や,モデルの学習が十分でない場合は`<eos>`が出力されず,
    推論が中々終わらないケースがあります.
    - 134行目: 時間軸に沿って応答値を積み上げていきます.
    - 136-137行目: 応答値からラベル値へ変換し,次のDecoderへの入力としています.
    - 141行目: `<eos>`が出力された場合は,推論処理を終了します.

第1.4項で説明したとおり,Encoder が Bidirectional RNN を用いている場合は,インスタンス化処理で次元数の調整を行い,_apply_encoder() 内で特徴量の整形を行っています.
また,forward()inference() で Decoder 処理前に init_dec_hstate() を呼び出して,隠れ状態を初期化している点にも注意してください.

5. 学習処理の実装

GitHub 上のスクリプトでは,この後時系列認識向けの学習処理を実装しています.
これらの内容は Encoder-Decoder 解説記事で説明しております.

ここでは,学習処理から呼び出すラッパー関数の実装だけを示します.

1
2
3
4
5
6
7
8
9
def forward(model, feature, tokens, feature_pad_mask, tokens_pad_mask):
    if isinstance(model, RNNCSLR):
        preds = model(feature,
                      tokens,
                      feature_pad_mask=feature_pad_mask,
                      tokens_pad_mask=tokens_pad_mask)
    else:
        raise NotImplementedError(f"Unknown model type:{type(model)}.")
    return preds
1
2
3
4
5
6
7
8
9
def inference(model, feature, start_id, end_id, max_seqlen=62):
    if isinstance(model, RNNCSLR):
        pred_ids, _ = model.inference(feature,
                                      start_id,
                                      end_id,
                                      max_seqlen=max_seqlen)
    else:
        raise NotImplementedError(f"Unknown model type:{type(model)}.")
    return pred_ids

forward() は学習およびバリデーション用の処理で,inference() は推論用の処理です.
それぞれの学習ループで該当の関数を呼び出すように実装しています.
詳細は Encoder-Decoder 解説記事の第4.3項をご参照ください.

6. 動作チェック

処理の実装ができましたので,動作確認をしていきます.
次のコードでデータセットから HDF5 ファイルと JSON ファイルのパスを読み込みます.

1
2
3
4
5
6
7
8
# Access check.
dataset_dir = Path("gafs_dataset_very_small")
files = list(dataset_dir.iterdir())
dictionary = [fin for fin in files if ".json" in fin.name][0]
hdf5_files = [fin for fin in files if ".hdf5" in fin.name]

print(dictionary)
print(hdf5_files)
gafs_dataset_very_small/character_to_prediction_index.json
[PosixPath('gafs_dataset_very_small/10.hdf5'), ..., PosixPath('gafs_dataset_very_small/68.hdf5')]

次のコードで辞書ファイルをロードして,認識対象のラベル数を格納します.
元々の認識対象に加えて,<sos>, <eos>, <pad> を加えている点に注意してください.
<pad> はパディング信号を示します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
# Load dictionary.
with open(dictionary, "r") as fread:
    key2token = json.load(fread)

VOCAB = len(key2token)
# Add keywords.
key2token["<sos>"] = VOCAB
key2token["<eos>"] = VOCAB + 1
key2token["<pad>"] = VOCAB + 2
# Reset.
VOCAB = len(key2token)

次のコードで前処理を定義します.
Colab 上のスクリプトは動作の流れを示すのが主目的なので,入力側のデータ拡張は最低限にして,ラベル挿入処理を前処理に加えています.
ラベル挿入処理に関しては,Encoder-Decoder 解説記事の第4.1項をご参照ください.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
_, use_landmarks = get_fullbody_landmarks()
use_features = ["x", "y"]

trans_select_feature = SelectLandmarksAndFeature(landmarks=use_landmarks, features=use_features)
trans_repnan = ReplaceNan()
trans_norm = PartsBasedNormalization(align_mode="framewise", scale_mode="unique")
trans_insert_token = InsertTokensForS2S(sos_token=key2token["<sos>"], eos_token=key2token["<eos>"])

pre_transforms = Compose([
    trans_select_feature,
    trans_repnan,
    trans_insert_token,
    trans_norm
])

train_transforms = Compose([
    ToTensor()
])

val_transforms = Compose([
    ToTensor()
])

test_transforms = Compose([
    ToTensor()
])

次のコードで,前処理を適用した HDF5Dataset と DataLoader をインスタンス化し,データを取り出します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
batch_size = 2
feature_shape = (len(use_features), -1, len(use_landmarks))
token_shape = (-1,)
merge_fn = partial(merge_padded_batch,
                   feature_shape=feature_shape,
                   token_shape=token_shape,
                   feature_padding_val=0.0,
                   token_padding_val=key2token["<pad>"])

dataset = HDF5Dataset(hdf5_files, pre_transforms=pre_transforms, transforms=train_transforms)

dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=merge_fn)
try:
    data = next(iter(dataloader))
    feature_origin = data["feature"]
    tokens_origin = data["token"]

    print(feature_origin.shape)
    print(tokens_origin)
except Exception as inst:
    print(inst)
torch.Size([2, 2, 177, 130])
tensor([[59, 23, 17,  0, 46, 42, 43, 32, 54, 32, 39, 32,  0, 43, 32, 45, 36, 60,
         61, 61, 61, 61],
        [59, 16, 21, 21, 23,  0, 43, 32, 33, 52, 49, 45, 52, 44,  0, 50, 51, 49,
         36, 36, 51, 60]])

出力側のラベル系列に注目してください.
<sos>: 59 で始まり,指文字のラベルが続いた後 <eos>: 60 で終了します.
ラベル長が短いデータは,長いデータに合わせて <pad>: 61 で穴埋めされます.

次のコードでモデルをインスタンス化して,動作チェックをします.
追跡点抽出の結果,入力追跡点数は 130 で,各追跡点は XY 座標値を持っていますので,入力次元数は 260 になります.
出力次元数はラベルの種類数なので 62 です.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Define model.
# in_channels: J * C (130*2=260)
#   J: use_landmarks (130)
#   C: use_channels (2)
# out_channels: 10
in_channels = len(use_landmarks) * len(use_features)
out_channels = VOCAB
enc_hidden_channels = 64
dec_hidden_channels = 64
dec_emb_channels = 4
dec_att_dim = 64
model = RNNCSLR(
    enc_in_channels=in_channels,
    enc_hidden_channels=enc_hidden_channels,
    enc_rnn_type="gru",
    enc_num_layers=2,
    enc_activation="relu",
    enc_bidir=True,
    enc_dropout=0.1,
    enc_apply_mask=True,
    enc_proj_size=0,
    dec_in_channels=enc_hidden_channels,
    dec_hidden_channels=dec_hidden_channels,
    dec_out_channels=out_channels,
    dec_emb_channels=dec_emb_channels,
    dec_att_dim=dec_att_dim,
    dec_att_add_bias=True,
    dec_rnn_type="gru",
    dec_num_layers=2,
    dec_activation="relu",
    dec_dropout=0.1,
    dec_padding_val=key2token["<pad>"],
    dec_proj_size=0)

print(model)

# Sanity check.
sample = next(iter(dataloader))
logit = model(sample["feature"],
              tokens=sample["token"],
              feature_pad_mask=sample["feature_pad_mask"])
print(logit.shape)
RNNCSLR(
  (linear): Linear(in_features=260, out_features=64, bias=True)
  (enc_activation): ReLU()
  (encoder): RNNEncoder(
    (rnn): GRU(64, 64, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
  )
  (decoder): BahdanauRNNDecoder(
    (emb_layer): Embedding(62, 4, padding_idx=61)
    (att_layer): SingleHeadAttention(
      (att_energy): BahdanauAttentionEnergy(
        (w_key): Linear(in_features=128, out_features=128, bias=True)
        (w_query): Linear(in_features=128, out_features=128, bias=True)
        (w_out): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    (rnn): GRU(132, 128, num_layers=2, batch_first=True, dropout=0.1)
    (head): Linear(in_features=128, out_features=62, bias=True)
  )
)
torch.Size([2, 21, 62])

7. 学習と評価の実行

7.1 共通設定

では,実際に学習・評価を行います.
まずは,実験全体で共通して用いる設定値を次のコードで実装します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Set common parameters.
batch_size = 32
load_into_ram = True
test_pid = 0
num_workers = os.cpu_count()
print(f"Using {num_workers} cores for data loading.")
lr = 3e-4
label_smoothing = 0.1
sos_token = key2token["<sos>"]
eos_token = key2token["<eos>"]
pad_token = key2token["<pad>"]
max_seqlen = 60

epochs = 50
eval_every_n_epochs = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} for computation.")

train_hdf5files = [fin for fin in hdf5_files if str(test_pid) not in fin.name]
val_hdf5files = [fin for fin in hdf5_files if str(test_pid) in fin.name]
test_hdf5files = [fin for fin in hdf5_files if str(test_pid) in fin.name]

_, use_landmarks = get_fullbody_landmarks()
use_features = ["x", "y"]
Using 2 cores for data loading.
Using cuda for computation.

孤立手話単語認識の場合とほとんど同じですが,テスト推論時の最長ループ数 max_seqlen を設定しています.

次のコードで学習・バリデーション・評価処理それぞれのための DataLoader クラスを作成します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
# Build dataloaders.
train_dataset = HDF5Dataset(train_hdf5files,
    pre_transforms=pre_transforms, transforms=train_transforms, load_into_ram=load_into_ram)
val_dataset = HDF5Dataset(val_hdf5files,
    pre_transforms=pre_transforms, transforms=val_transforms, load_into_ram=load_into_ram)
test_dataset = HDF5Dataset(test_hdf5files,
    pre_transforms=pre_transforms, transforms=test_transforms, load_into_ram=load_into_ram)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=merge_fn, num_workers=num_workers, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=merge_fn, num_workers=num_workers, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=merge_fn, num_workers=num_workers, shuffle=False)

7.2 学習・評価の実行

次のコードでモデルをインスタンス化します.
今回は時系列認識のためにカスタムした,LabelSmoothingCrossEntropyLoss をロス関数として使用します.
LabelSmoothingCrossEntropyLoss については Loss計算に関する補足記事で説明していますので,併せてご一読いただけたらうれしいです.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
model_rnn = RNNCSLR(
    enc_in_channels=in_channels,
    enc_hidden_channels=enc_hidden_channels,
    enc_rnn_type="gru",
    enc_num_layers=2,
    enc_activation="relu",
    enc_bidir=True,
    enc_dropout=0.1,
    enc_apply_mask=True,
    enc_proj_size=0,
    dec_in_channels=enc_hidden_channels,
    dec_hidden_channels=dec_hidden_channels,
    dec_out_channels=out_channels,
    dec_emb_channels=dec_emb_channels,
    dec_att_dim=dec_att_dim,
    dec_att_add_bias=True,
    dec_rnn_type="gru",
    dec_num_layers=2,
    dec_activation="relu",
    dec_dropout=0.1,
    dec_padding_val=pad_token,
    dec_proj_size=0)

print(model_rnn)

loss_fn = LabelSmoothingCrossEntropyLoss(
    ignore_indices=pad_token, reduction="mean_temporal_prior",
    label_smoothing=label_smoothing)
optimizer = torch.optim.Adam(model_rnn.parameters(), lr=lr)
RNNCSLR(
  (linear): Linear(in_features=260, out_features=64, bias=True)
  (enc_activation): ReLU()
  (encoder): RNNEncoder(
    (rnn): GRU(64, 64, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
  )
  (decoder): BahdanauRNNDecoder(
    (emb_layer): Embedding(62, 4, padding_idx=61)
    (att_layer): SingleHeadAttention(
      (att_energy): BahdanauAttentionEnergy(
        (w_key): Linear(in_features=128, out_features=128, bias=True)
        (w_query): Linear(in_features=128, out_features=128, bias=True)
        (w_out): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    (rnn): GRU(132, 128, num_layers=2, batch_first=True, dropout=0.1)
    (head): Linear(in_features=128, out_features=62, bias=True)
  )
)

次のコードで学習・評価処理を行います.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Train, validation, and evaluation.
model_rnn.to(device)

train_losses = []
val_losses = []
test_wers = []
print("Start training.")
for epoch in range(epochs):
    print("-" * 80)
    print(f"Epoch {epoch+1}")

    train_loss, train_times = train_loop_csir_s2s(
        train_dataloader, model_rnn, loss_fn, optimizer, device,
        sos_token, eos_token,
        return_pred_times=True)
    val_loss, val_times = val_loop_csir_s2s(
        val_dataloader, model_rnn, loss_fn, device,
        sos_token, eos_token,
        return_pred_times=True)
    val_losses.append(val_loss)

    if (epoch+1) % eval_every_n_epochs == 0:
        wer, test_times = test_loop_csir_s2s(
            test_dataloader, model_rnn, device,
            sos_token, eos_token,
            return_pred_times=True,
            max_seqlen=max_seqlen)
        test_wers.append(wer)
train_losses_rnn = np.array(train_losses)
val_losses_rnn = np.array(val_losses)
test_wers_rnn = np.array(test_wers)

val_losses_rnn = np.array(val_losses_rnn)
test_wers_rnn = np.array(test_wers_rnn)
print(f"Minimum validation loss:{val_losses_rnn.min()} at {np.argmin(val_losses_rnn)+1} epoch.")
print(f"Minimum WER:{test_wers_rnn.min()} at {np.argmin(test_wers_rnn)*eval_every_n_epochs+1} epoch.")
Start training.
--------------------------------------------------------------------------------
Epoch 1
Start training.
loss:4.141368 [    0/ 2513]
Done. Time:17.759382293999977
Training performance:
 Avg loss:3.753373

Start validation.
Done. Time:1.5249640609999915
Validation performance:
 Avg loss:3.604697

Start test.
Done. Time:13.60654139999997
Test performance:
 Avg WER:91.1%
--------------------------------------------------------------------------------
...
--------------------------------------------------------------------------------
Epoch 50
Start training.
loss:2.537554 [    0/ 2513]
Done. Time:15.674790514000051
Training performance:
 Avg loss:2.567390

Start validation.
Done. Time:1.669974916999763
Validation performance: 
 Avg loss:2.669603

Start test.
Done. Time:23.861729981999815
Test performance: 
 Avg WER:87.8%

Minimum validation loss:2.6696032136678696 at 50 epoch.
Minimum WER:86.03277365443375 at 13 epoch.

冒頭で説明したとおり,Colab 上で説明用に使っているデータセットは容量が少なくて学習が安定しません.
性能評価は公式から全データセットをダウンロードして行うことをお勧めします.


今回は RNN Encoder-Decoder を用いた連続指文字認識モデルを紹介しましたが,如何でしたでしょうか?

今回の記事は正直なところ難産でした (^^;).
実験にかかる時間の見積もりを誤った上に,文量も増えてしまって最終的に 4 個の記事に分けて再構成するハメになりました.

苦労しましたが,結果的に実装上の細かい箇所もカバーできたと思うので,今後は記事が書きやすくなったと思います (そう願っています(^^;)).

今回紹介した話が,これから手話認識や深層学習を勉強してみようとお考えの方に何か参考になれば幸いです.