実践手話認識 - モデル開発編3: Encoder-Decoder を用いた時系列認識の処理

Thumbnail image
This image is generated with ChatGPT-4 Omni, and edited by the author.
作成日: 2024年09月12日(木) 00:00
最終更新日: 2024年09月18日(水) 09:47
カテゴリ: 手話言語処理
タグ:  実践手話認識 深層学習 時系列認識

Encoder-Decoder を用いた時系列認識の処理について説明します.今回はモデルではなく学習ループなどの全体的な処理の実装例を示します.

こんにちは.高山です.
実践手話認識 - モデル開発編の第三回になります.

以前の記事で,Kaggle の Google American Sign Language Fingerspelling Recognition データセットについて解説しましたので,そろそろ時系列認識も扱える状況になってきました.
そこで今回は,Encoder-Decoder を用いた時系列認識の処理について説明したいと思います.

本記事の内容は,当初は Recurrent neural network (RNN) ベースの Encoder-Decoder モデル [Bahdanau'15] の記事内で説明する予定でした.
しかし実際に書いてみると,内容がかなり多岐に渡っており文量も多くなってしまったので,分けることにしました.

認識モデル自体は別記事で紹介し,本記事では学習ループやその他の周辺的な処理について説明したいと思います.
本記事でコード例と解説は示しますが,実際に使用した実験例は認識モデルの記事で紹介しますのでご了承ください.

  • [Bahdanau'15]: D. Bahdanau, et al., "Neural Machine Translation by Jointly Learning to Align and Translate," Proc. of the ICLR, available here, 2015.

更新履歴 (大きな変更のみ記載しています)

  • 2024/09/18: カテゴリを変更しました
  • 2024/09/17
    • タイトル,タグ,サマリを更新しました
    • 系列認識 (Sequential recognition) は別の手法を連想させる可能性があるので,時系列認識 (temporal recognition) という用語を用いるようにしました

1. Encoder-Decoderの基本構成

図1に,Encoder-Decoder の基本構成を示します.

図1: Encoder-Decoderの基本構成
Encoder-Decoderの基本構成

Encoder-Decoder では,Encoder と Decoder の2種類のモデルを併用します.
手話認識における Encoder-Decoder では,Encoder は追跡点系列などの特徴系列を入力して,特徴変換を行います.
Decoder は離散数値で表現したラベル系列と Encoder の出力系列を入力して,次に出力するラベルの予測確率系列を出力します.

モデルの設計は様々ですが,全体の構成は概ね共通しています.

2 Decoderの動作

Encoder の動作は RNNを用いた孤立手話単語の記事で説明しました.
本項では Decoder の動作を主に説明します.

2.1 学習時の推論処理

図2に,Decoder の動作概略図を示します.

図2: Decoderの動作
Decoderの動作

時系列認識では,Decoder に1個のラベルを入力して,次のラベルを予測させます.
では,最初のラベルは何を入力すれば良いでしょうか?

図2に示すとおり,Decoder にはラベル系列の始まりを示すキーワードラベルを入力します.
(ここでは"\(\text{<sos>}\): start of sequence" としています)
Decoder の出力 \(\hat{\boldsymbol{p}_1}\) は次のラベルの予測確率を示しています.

同様に,"\(\text{<eos>}\): end of sequence" はラベル系列の終了を示すキーワードラベルです.
終了キーワードを含めて学習させることで,Decoder は文単位でラベル系列を学習することができます.
なお,Decoder の出力は次ラベルの予測確率なので,\(y_L = \text{<eos>}\) を入力した際の出力 \(\hat{\boldsymbol{p}_L}\) は学習には使いません.

2.2 テスト時の推論処理

図2の右側に示しているテスト時の動作について説明します.

ここでのポイントは,1番目の入力 \(y_1 = \text{<sos>}\) は決め事なので Decoder に入力できますが,2番目以降の入力は与えられていないという点です.
ではどうするかというと,Decoder が予測したラベルを 2番目以降の入力に用います.

Decoder の出力 \(\hat{\boldsymbol{p}}_l\) は次のラベルの予測確率で,内部には各ラベルへの応答値が並んでいます.
\(argmax(\cdot)\) は最大値を与えるインデクスを返す関数です.
この処理を \(\hat{\boldsymbol{p}}_l\) に適用すると,最大の応答値を与えるラベル (ラベルの数値とインデクスを揃えるので) \(\hat{y}_l\) が得られます.
得られた結果を 2番目以降の入力では用います.

なお,テスト時は Decoder が終了キーワード "\(\text{<eos>}\)" を出力したら,推論を終了します (そのように処理を組みます).

2.3 Loss計算について

今まで取り扱ってきた孤立手話単語認識に比べると,時系列認識の処理はかなり複雑です.
特に Loss の計算については注意点がいくつかあります.
本記事で述べるには少し細かな内容ですので,Loss 計算に関する補足記事に説明を載せています.
併せてご一読いただければうれしいです.

3. 学習・評価処理の構成

ここまでの説明を踏まえて学習および評価の処理フローを構成すると,図3のようになります.

図3: 学習・評価処理フロー
学習・評価処理フロー

図3に示すとおり,学習・評価処理には学習ループ,バリデーションループ,テストループの3種類の処理があります.
大まかな処理構成はシンプルな孤立手話単語モデルの記事で紹介した内容 (第7.1項をご参照ください) と同様ですが,下記の点で異なっています.

  • 推論処理が学習・バリデーションとテストで異なる.
  • 評価には単語誤り率 (WER: Word error rate) を用いる.

第2節で説明したように,Encoder-Decoder モデルでは学習時とテスト時で推論処理が異なります.
そのため,各ループでそれぞれ異なる処理を実装する必要があります.

時系列認識の評価指標は色々とありますが,音声認識や連続手話単語認識では WER がよく用いられます.
WER については WER に関する補足記事で簡単な説明をしておりますので,併せてご一読いただければうれしいです.

4. 時系列認識向け処理の実装

それではここまでの説明を踏まえて,時系列認識向け処理を実装していきます.
今まで孤立手話単語認識で実装してきた処理を活用すると,下記の点を更新すればよさそうです.

  • データ拡張 (ラベル系列へのキーワード挿入)
  • パディング処理
  • 学習処理

第2.1項で説明したとおり,Encoder-Decoder ではキーワードが挿入されたラベル系列を使用します.
データセットを直接編集しても良いのですが,データロード時に動的に挿入するようにしておくと他の形式の処理と切り替えやすくなります.
ここでは,データ拡張処理の一つとして実装したいと思います.

孤立手話単語認識ではラベル系列の長さが 1 で固定されていましたが,時系列認識ではサンプル毎に長さが異なります.
ラベル系列に対してもパディング処理を行うように処理を更新する必要があります.

学習処理に関しては第3節で説明したとおりで,時系列認識に向けた学習・評価処理を新たに実装する必要があります.

4.1 ラベル系列へのキーワード挿入処理

ラベル系列へのキーワード挿入処理は下記のようになります.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class InsertTokensForS2S():
    def __init__(self,
                 sos_token,
                 eos_token,
                 error_at_exist=False):
        self.sos_token = sos_token
        self.eos_token = eos_token
        self.error_at_exist = error_at_exist

    def check_format(self, tokens):
        insert_sos = False
        if tokens[0] != self.sos_token:
            insert_sos = True
        elif self.error_at_exist:
            message = f"The sos_token:{self.sos_token} is exist in {tokens}." \
                + "Please check the format."
            raise ValueError(message)
        insert_eos = False
        if tokens[-1] != self.eos_token:
            insert_eos = True
        elif self.error_at_exist:
            message = f"The eos_token:{self.eos_token} is exist in {tokens}." \
                + "Please check the format."
            raise ValueError(message)
        return insert_sos, insert_eos

    def __call__(self, data):

        tokens = data["token"]
        dtype = tokens.dtype

        insert_sos, insert_eos = self.check_format(tokens)
        # Insert.
        new_tokens = []
        if insert_sos:
            new_tokens.append(self.sos_token)
        new_tokens += tokens.tolist()
        if insert_eos:
            new_tokens.append(self.eos_token)
        new_tokens = np.array(new_tokens, dtype=dtype)
        data["token"] = new_tokens
        return data
【コード解説】
- 引数
  - sos_token: `<sos>`のラベル値.
  - eos_token: `<eos>`のラベル値.
  - error_at_exist: Trueの場合,データセットオリジナルのラベル値に`<sos>`か
    `<eos>`が含まれていたら例外を投げます.
    Falseの場合は,ラベル値の挿入をスキップして処理を継続します.
- 10-25行目: データフォーマットのチェック処理.
- 29-42行目: メイン処理.ラベル系列の先頭と末尾に,それぞれ`<sos>`と`<eos>`を
  挿入します.

4.2 パディング処理

パディング処理は DataLoader に与える collate_fn に実装します.
具体的には下記のようになります.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def merge_padded_batch(batch,
                       feature_shape,
                       token_shape,
                       feature_padding_val=0,
                       token_padding_val=0):
    feature_batch = [sample["feature"] for sample in batch]
    token_batch = [sample["token"] for sample in batch]

    # ==========================================================
    # Merge feature.
    # ==========================================================
    # `[B, C, T, J]`
    merged_shape = [len(batch), *feature_shape]
    # Use maximum frame length in a batch as padded length.
    if merged_shape[2] == -1:
        tlen = max([feature.shape[1] for feature in feature_batch])
        merged_shape[2] = tlen
    merged_feature = merge(feature_batch, merged_shape, padding_val=feature_padding_val)

    # ==========================================================
    # Merge token.
    # ==========================================================
    # `[B, L]`
    merged_shape = [len(batch), *token_shape]
    # Use maximum token length in a batch as padded length.
    if merged_shape[1] == -1:
        tlen = max([token.shape[0] for token in token_batch])
        merged_shape[1] = tlen
    merged_token = merge(token_batch, merged_shape, padding_val=token_padding_val)

    # Generate padding mask.
    # Pad: 0, Signal: 1
    # The frames which all channels and landmarks are equals to padding value
    # should be padded.
    feature_pad_mask = merged_feature == feature_padding_val
    feature_pad_mask = torch.all(feature_pad_mask, dim=1)
    feature_pad_mask = torch.all(feature_pad_mask, dim=-1)
    feature_pad_mask = torch.logical_not(feature_pad_mask)
    token_pad_mask = torch.logical_not(merged_token == token_padding_val)

    retval = {
        "feature": merged_feature,
        "token": merged_token,
        "feature_pad_mask": feature_pad_mask,
        "token_pad_mask": token_pad_mask}
    return retval

24-29行目でラベル系列に対してもパディング処理を行うように処理を更新しています.
その他に関しては,データセットのアクセス処理で説明した内容 (第5.3項をご参照ください) と同様です.

4.3 学習処理

最後に,Encoder-Decoder モデル用の学習処理を実装します.
第3節で説明したように,Encoder-Decoder では学習と推論で異なる処理,インタフェースになる場合があります.
これはモデルが異なる場合も同じです.
今後モデルが増える度に処理を大きく変更するのは大変なので,下記に示すようなラッパー関数を用意しておきます.

1
2
def forward(model, feature, tokens, feature_pad_mask, tokens_pad_mask):
    pass
1
2
def inference(model, feature, start_id, end_id, max_seqlen=62):
    pass

forward() は学習およびバリデーション用の処理で,inference() は推論用の処理です.
それぞれの学習ループで該当の関数を呼び出します.
モデルが変更または追加された場合はこれらのラッパー関数を更新すればよく,学習処理は変更せずに済みます.

なお,具体的な処理は認識モデルの紹介時に実装します.

次にラベル系列が意図通りになっているかチェックする関数を実装します.
必須では無いですが,データフォーマットが原因で学習結果がおかしい場合は見つけるのに苦労する場合が多いです.
やっておいた方が良さそうなチェックはなるべく実装するようにしています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
def check_tokens_format(tokens, tokens_pad_mask, start_id, end_id):
    # Check token's format.
    end_indices0 = np.arange(len(tokens))
    end_indices1 = tokens_pad_mask.sum(dim=-1).detach().cpu().numpy() - 1
    message = "The start and/or end ids are not included in tokens. " \
        f"Please check data format. start_id:{start_id}, " \
        f"end_id:{end_id}, enc_indices:{end_indices1}, tokens:{tokens}"
    ref_tokens = tokens.detach().cpu().numpy()
    assert (ref_tokens[:, 0] == start_id).all(), message
    assert (ref_tokens[end_indices0, end_indices1] == end_id).all(), message

最後に,学習,バリデーション,テストループを実装していきます.

学習ループ

まずは,次のコードで学習ループを実装します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def train_loop_csir_s2s(dataloader,
                        model,
                        loss_fn,
                        optimizer,
                        device,
                        start_id,
                        end_id,
                        return_pred_times=False):
    num_batches = len(dataloader)
    train_loss = 0
    size = len(dataloader.dataset)

    # Collect prediction time.
    pred_times = []

    # Switch to training mode.
    model.train()
    # Main loop.
    print("Start training.")
    start = time.perf_counter()
    for batch_idx, batch_sample in enumerate(dataloader):
        feature = batch_sample["feature"]
        feature_pad_mask = batch_sample["feature_pad_mask"]
        tokens = batch_sample["token"]
        tokens_pad_mask = batch_sample["token_pad_mask"]

        check_tokens_format(tokens, tokens_pad_mask, start_id, end_id)

        feature = feature.to(device)
        feature_pad_mask = feature_pad_mask.to(device)
        tokens = tokens.to(device)
        tokens_pad_mask = tokens_pad_mask.to(device)

        frames = feature.shape[-2]

        # Predict.
        pred_start = time.perf_counter()
        preds = forward(model, feature, tokens,
                        feature_pad_mask, tokens_pad_mask)
        pred_end = time.perf_counter()
        pred_times.append([frames, pred_end - pred_start])

        # Compute loss.
        # Preds do not include <start>, so skip that of tokens.
        loss = 0
        if isinstance(loss_fn, nn.CrossEntropyLoss):
            for t_index in range(1, tokens.shape[-1]):
                pred = preds[:, t_index-1, :]
                token = tokens[:, t_index]
                loss += loss_fn(pred, token)
            loss /= tokens.shape[-1]
        # LabelSmoothingCrossEntropyLoss
        else:
            # `[N, T, C] -> [N, C, T]`
            preds = preds.permute([0, 2, 1])
            # Remove prediction after the last token.
            if preds.shape[-1] == tokens.shape[-1]:
                preds = preds[:, :, :-1]
            loss = loss_fn(preds, tokens[:, 1:])

        # Back propagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Print current loss per 100 steps.
        if batch_idx % 100 == 0:
            loss = loss.item()
            steps = batch_idx * len(feature)
            print(f"loss:{loss:>7f} [{steps:>5d}/{size:>5d}]")
    print(f"Done. Time:{time.perf_counter()-start}")
    # Average loss.
    train_loss /= num_batches
    print("Training performance: \n",
          f"Avg loss:{train_loss:>8f}\n")
    pred_times = np.array(pred_times)
    retval = (train_loss, pred_times) if return_pred_times else train_loss
    return retval
【コード解説】
- 引数
  - dataloader: DataLoaderクラスのインスタンス
  - model: 認識モデルのインスタンス
  - loss_fn: Loss関数のインスタンス
  - optimizer: モデルのパラメータ制御クラスのインスタンス
  - device: 計算処理を行うデバイスを示す文字列 ("cpu"や"cuda"など)
  - start_id: `<sos>`のラベル値
  - end_id: `<eos>`のラベル値
  - return_pred_times: Trueの場合,推論処理時間を計測して返します.
- 9行目: データ数を取得.この値は学習の進捗を表示するために使用します.
- 17行目: モデルを学習モードに切り替え
- 21-72行目: 学習ループ
  - 22-32行目: バッチデータをロードしてデバイス (CPUやGPU) に転送.
    27行目で`tokens`のフォーマットチェックを行っています.
  - 37-41行目: 推論処理.
  - 45-59行目: Loss値の算出.
    `loss_fn`に与える`pred`と`token`間で,インデクスがズレている点に注意してください.
    `loss_fn`がnn.CrossEntropyLossの場合は,ループで和を求めて最後にバッチの
    系列長で割ります.
    これは,LabelSmoothingCrossEntropyLossで`reduction=mean_batch_prior` とした
    場合と同等です.
  - 62-64行目: モデルのパラメータを更新
  - 69-72行目: 学習の進捗状況を表示
  - 73-79行目: 各種の情報を表示と返り値の整形.

バリデーションループ

次のコードでバリデーションループを実装します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def val_loop_csir_s2s(dataloader,
                      model,
                      loss_fn,
                      device,
                      start_id,
                      end_id,
                      return_pred_times=False):
    num_batches = len(dataloader)
    val_loss = 0

    # Collect prediction time.
    pred_times = []

    # Switch to evaluation mode.
    model.eval()
    # Main loop.
    print("Start validation.")
    start = time.perf_counter()
    with torch.no_grad():
        for batch_idx, batch_sample in enumerate(dataloader):
            feature = batch_sample["feature"]
            feature_pad_mask = batch_sample["feature_pad_mask"]
            tokens = batch_sample["token"]
            tokens_pad_mask = batch_sample["token_pad_mask"]

            check_tokens_format(tokens, tokens_pad_mask, start_id, end_id)

            feature = feature.to(device)
            feature_pad_mask = feature_pad_mask.to(device)
            tokens = tokens.to(device)
            tokens_pad_mask = tokens_pad_mask.to(device)

            frames = feature.shape[-2]

            # Predict.
            pred_start = time.perf_counter()
            preds = forward(model, feature, tokens,
                            feature_pad_mask, tokens_pad_mask)
            pred_end = time.perf_counter()
            pred_times.append([frames, pred_end - pred_start])

            # Compute loss.
            # Preds do not include <start>, so skip that of tokens.
            loss = 0
            if isinstance(loss_fn, nn.CrossEntropyLoss):
                for t_index in range(1, tokens.shape[-1]):
                    pred = preds[:, t_index-1, :]
                    token = tokens[:, t_index]
                    loss += loss_fn(pred, token)
                loss /= tokens.shape[-1]
            # LabelSmoothingCrossEntropyLoss
            else:
                # `[N, T, C] -> [N, C, T]`
                preds = preds.permute([0, 2, 1])
                # Remove prediction after the last token.
                if preds.shape[-1] == tokens.shape[-1]:
                    preds = preds[:, :, :-1]
                loss = loss_fn(preds, tokens[:, 1:])

            val_loss += loss.item()
    print(f"Done. Time:{time.perf_counter()-start}")

    # Average loss.
    val_loss /= num_batches
    print("Validation performance: \n",
          f"Avg loss:{val_loss:>8f}\n")
    pred_times = np.array(pred_times)
    retval = (val_loss, pred_times) if return_pred_times else val_loss
    return retval

モデルのパラメータ更新が無いので,

  • optimizerを渡していない (なので更新処理がない).
  • 15行目で model.eval() としてモデルを評価モードにしている.
  • 19行目で with torch.no_grad() として勾配値を無効にしている.

以外は学習ループとほぼ同じなので,細かな説明は割愛させていただきます.

テストループ

次のコードでテストループを実装します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def test_loop_csir_s2s(dataloader,
                       model,
                       device,
                       start_id,
                       end_id,
                       return_pred_times=False,
                       max_seqlen=62):
    size = len(dataloader.dataset)
    total_wer = 0

    # Collect prediction time.
    pred_times = []

    # Switch to evaluation mode.
    model.eval()
    # Main loop.
    print("Start test.")
    start = time.perf_counter()
    with torch.no_grad():
        for batch_idx, batch_sample in enumerate(dataloader):
            feature = batch_sample["feature"]
            tokens = batch_sample["token"]
            tokens_pad_mask = batch_sample["token_pad_mask"]

            check_tokens_format(tokens, tokens_pad_mask, start_id, end_id)

            feature = feature.to(device)
            tokens = tokens.to(device)
            tokens_pad_mask = tokens_pad_mask.to(device)

            frames = feature.shape[-2]

            # Predict.
            pred_start = time.perf_counter()
            pred_ids = inference(model, feature, start_id, end_id, max_seqlen=max_seqlen)
            pred_end = time.perf_counter()
            pred_times.append([frames, pred_end - pred_start])

            # Compute WER.
            # <sos> and <eos> should be removed because they may boost performance.
            # print(tokens)
            # print(pred_ids)
            tokens = tokens[0, 1:-1]
            # pred_ids = pred_ids[0, 1:-1]
            pred_ids = [pid for pid in pred_ids[0] if pid not in [start_id, end_id]]
            ref_length = len(tokens)
            wer = edit_distance(tokens, pred_ids)
            wer /= ref_length
            total_wer += wer
    print(f"Done. Time:{time.perf_counter()-start}")

    # Average WER.
    awer = total_wer / size * 100
    print("Test performance: \n",
          f"Avg WER:{awer:>0.1f}%\n")
    pred_times = np.array(pred_times)
    retval = (awer, pred_times) if return_pred_times else awer
    return retval

こちらも大枠はバリデーションループと似ているので細かな説明は割愛させていただきます.
下記の点に注意してください.

  • 36行目: テスト用推論処理 inference() を呼び出しています.
  • 44-53行目: WER を計算していますが,<sos><eos>は評価に含めない点に注意してください. 例えば,<sos>は必ず正解なので評価に含めた場合は実際よりも良い認識性能が出てしまいます.

今回は Encoder-Decoder を用いた時系列認識の処理について説明しましたが,如何でしたでしょうか?
と言っても今回は実際に動くものをお見せできていないのでピンと来ないかもしれませんね.
認識モデルの記事をなるべく早く出すようにしたいと思います.

Encoder-Decoder は孤立手話単語認識に比べると,実装が複雑で細かな注意点も多いです.
ハマりそうな箇所は補足記事などで随時カバーしていきますので,そちらもご一読いただければと思います.

今回紹介した話が,これから手話認識や深層学習を勉強してみようとお考えの方に何か参考になれば幸いです.