【コード解説・PyTorch】手話認識入門16 - 様々な改善手法7: ノイズ付加によるデータ拡張

著者: Natsuki Takayama
作成日: 2024年06月27日(木) 00:00
最終更新日: 2024年06月27日(木) 14:17
カテゴリ: コンピュータビジョン

こんにちは.高山です.
先日の記事で告知しました手話入門記事の第十六回になります.

今回は手話動画から抽出した追跡点系列に対して,データ拡張を施すことで認識性能を改善する手法を紹介します.
前回と同様に,今回も劣化させたデータを学習することで認識モデルの頑健性を向上させるアプローチです.
具体的には,手話中の追跡点系列をランダムにマスキングして,頑健性を向上させます.

図1にマスキングの適用例を示します.

(a): マスキング前
(b): 身体部位
(c): 個々の追跡点
(d): フレーム
(e): 画像空間
マスキングの適用例

それぞれのマスキング処理を簡単に説明しますと,下記のようになります.

  • (b): ランダムに部位をマスキング (例では身体)
  • (c): ランダムに個々の追跡点をマスキング
  • (d): ランダムにフレームをマスキング (例では冒頭)
  • (e): ランダムに画像の特定範囲をマスキング (例では唇付近)

「こんなことをしてしまって大丈夫なの?」と感じるかもしれませんが,画像処理関連の深層学習ではマスキングはよく用いられます [DeVries'17, Singh'17, Zhong'20, Chen'20, Li'20].

この処理は,入力データの一部を欠落させることでモデルが特定の箇所に依存し過ぎないようにすることが目的です.
手話認識の文脈では,特定の部位やフレームへの依存性を緩和することを狙います.

正直なところ,手話認識の分野ではマスキングを用いたデータ拡張に対して,あまり検討が進んでいないように思えます (データ拡張全般としてもそうですが...).
今回は高山が今まで見たことがあるマスキング処理を実装してみました.
KaggleのGoogle Isolated Sign Language Recognition データセット に適用してみて,効果を検証してみたいと思います.

今回解説するスクリプトはGitHub上に公開しています
複数の実験を行っている都合で,CPUで動かした場合は結構時間がかるのでご注意ください.

  • [DeVries'17]: T. DeVries, et al., "Improved Regularization of Convolutional Neural Networks with Cutout," arXiV: 1708.04552, available here, 2017.
  • [Singh'17]: K. K. Singh, et al., "Hide-and-Seek: Forcing a Network to be Meticulous for Weakly-supervised Object and Action Localization," Proc. of the ICCV, available here, 2017.
  • [Zhong'20]: Z. Zhong, et al., "Random Erasing Data Augmentation," Proc. of the AAAI, available here, 2020.
  • [Chen'20]: P. Chen, et al., "GridMask Data Augmentation," arXiV: 2001.04086, available here, 2020.
  • [Li'20]: P. Li, et al., "FenceMask: A Data Augmentation Approach for Pre-extracted Image Features," arXiV: 2006.07877, available here, 2020.

1. 概要

1.1 今回説明する内容

実装の詳細に先立って,今回紹介する内容の概要を説明したいと思います.
図2は,先日の記事で説明した機械学習モデル構築のワークフローの何処が今回の説明箇所に該当するかを示しています.

図2: 学習モデル構築のワークフローと紹介箇所
学習モデル構築のワークフローと紹介箇所

今回説明するデータ拡張は,学習用データセットからデータを取り出す際に行う,前処理に該当します.
特徴量エンジニアリングとデータ拡張の関係については,第一回の記事 (第1.2項と第1.3項) または第十一回の記事 (第1.1項) をご参照ください.

  • [Amershi'19]: S. Amershi, et al., "Software Engineering for Machine Learning: A Case Study," IEEE/ACM ICSE-SEIP 2019.

1.2 マスキングの処理工程

図3に,追跡点系列のマスキング処理工程を示します.

図3: マスキング処理工程
マスキング処理工程

処理構成はパラメータ算出とマスキングの適用でシンプルですが,実装形態はマスキングの方法によって異なります.
図3のマスキング後の特徴マップに示すように,マスキングされた箇所の特徴量はゼロクリアされます.

記事の冒頭で少し紹介しましたが,改めて今回実装するマスキング処理を図4に示します.

図4: 様々なマスキング処理
様々なマスキング処理

身体部位毎にマスキング

図4(a) は部位単位でランダムにマスキングする処理です.
特徴マップ上では水平方向にマスキングする処理に相当します.
この例では身体追跡点をマスキングしています.

個々の追跡点をマスキング

図4(b) は個々の追跡点をランダムにマスキングする処理です.
特徴マップに示すとおり,この処理では虫食いのような特徴マップが生成されます.
なお,この例では顔の追跡点はマスキングされないようにしています.

時間フレームでマスキング

図4(c) は時間フレームをランダムにマスキングする処理です.
特徴マップ上では鉛直方向にマスキングする処理に相当します.

座標空間上でマスキング

図4(a) から (c) は特徴マップ上 (時間-追跡点空間) のマスキング処理です.
それに対して,図4(d) は座標空間上のマスキング処理になります.
このタイプのマスキング処理では空間上にマスキング領域を設定して,該当領域内 (または領域外) にある追跡点をマスキングします.
図4(d) では CutOut [DeVries'17] や Random Erasing [Zhong'20] と同様に,ランダムな矩形領域をマスキング領域として設定しています.

今回は上記4種類のマスキング処理をデータ拡張として実装して,性能を検証していきたいと思います.

  • [DeVries'17]: T. DeVries, et al., "Improved Regularization of Convolutional Neural Networks with Cutout," arXiV: 1708.04552, available here, 2017.
  • [Zhong'20]: Z. Zhong, et al., "Random Erasing Data Augmentation," Proc. of the AAAI, available here, 2020.

1.3 先に結果

第2節以降では,いつも通り実装の紹介をしながら実験結果をお見せします.
コード紹介記事の方針として記事単体で全処理が分かるように書いており,少し長いので結果を先にお見せしたいと思います.

頻度の多い10単語を学習させた結果

図5は,マスキング種別毎のValidation Lossと認識率の推移を示しています.

図5: 認識性能比較結果 (Top10)
認識性能比較結果 (Top10)

横軸は学習・評価ループの繰り返し数 (Epoch) を示します.
縦軸はそれぞれの評価指標を示します.

各線の色と実験条件の関係は次のとおりです.

  • 青線 (Default): Pre-LN構成のTransformer
  • 橙線 (+ D-Parts): 部位毎のマスキング
  • 緑線 (+ D-Joints): 個々の追跡点のマスキング
  • 赤線 (+ D-Temporal): フレームのマスキング
  • 紫線 (+ D-Spatial-O): 座標空間での矩形マスキング
  • 茶線 (+ D-Spatial-W): D-Spatial-O の反転マスキング

デフォルトのモデルには,第九回の記事で紹介した,Pre-LN構成のTransformerモデルを用います.

(比較対象が多いので見辛くて申し訳ないですが) ロスの挙動から過学習は抑制できていそうですが,やや不安定な感じです.
認識性能はあまり差がないように見えますね.

250単語を学習させた結果

データが少なくて学習が不安定になっている可能性がありますので,全データ (250単語) を学習させた場合の挙動を図6に示します.
なお,こちらの実験はメモリや処理時間の都合でColab上では実行が難しいので,ローカル環境で行いました.

データの分割方法やパラメータは10単語のときと同じです.
ただし,学習時間を短縮するためにバッチ数は256に設定しています.
(本来はバッチ数を変えた場合は学習率も調整した方が良いのですが,今回はママで実験を行っています)

図6: 認識性能比較結果 (Full)
認識性能比較結果 (Full)

全データを学習させた結果では,部位毎のマスキングを用いた場合 (D-Parts) は認識性能が改善しています.

D-Joints, D-Temporal, D-Spatial-O に関しては,認識性能がやや悪化していますが,微妙な差ですのでパラメータ次第では結果が変わるかもしれません.

一方,D-Spatial-Wでは大きく認識性能が悪化するという結果になりました.
D-Spatial-Wでは,座標空間の外縁部がマスキングされやすい傾向があります.
10単語の場合は性能が保てていたことを考えると,追加した単語の中に外縁部のマスキングが影響されやすい単語が多く含まれていた可能性があります.

なお,今回の実験では話を簡単にするために,実験条件以外のパラメータは固定にし,乱数の制御もしていません.
必ずしも同様の結果になるわけではないので,ご了承ください.


2. 前準備

2.1 データセットのダウンロード

ここからは実装方法の説明をしていきます.
まずは,前準備としてGoogle Colabにデータセットをアップロードします. ここの工程はこれまでの記事と同じですので,既に行ったことのある方は第2.3項まで飛ばしていただいて構いません.

まず最初に,データセットの格納先からデータをダウンロードし,ご自分のGoogle driveへアップロードしてください.

次のコードでGoogle driveをColabへマウントします.
Google Driveのマウント方法については,補足記事にも記載してあります.

1
2
3
from google.colab import drive

drive.mount("/content/drive")

ドライブ内のファイルをColabへコピーします.
パスはアップロード先を設定する必要があります.

# Copy to local.
!cp [path_to_dataset]/gislr_dataset_top10.zip gislr_top10.zip

データセットはZIP形式になっているので unzip コマンドで解凍します.

!unzip gislr_top10.zip
Archive:  gislr_top10.zip
   creating: dataset_top10/
  inflating: dataset_top10/16069.hdf5
  ...
  inflating: dataset_top10/sign_to_prediction_index_map.json

成功すると dataset_top10 以下にデータが解凍されます.
HDF5ファイルはデータ本体で,手話者毎にファイルが別れています.
JSONファイルは辞書ファイルで,TXTファイルは本データセットのライセンスです.

!ls dataset_top10
16069.hdf5  25571.hdf5  29302.hdf5  36257.hdf5  49445.hdf5  62590.hdf5
18796.hdf5  26734.hdf5  30680.hdf5  37055.hdf5  53618.hdf5  LICENSE.txt
2044.hdf5   27610.hdf5  32319.hdf5  37779.hdf5  55372.hdf5  sign_to_prediction_index_map.json
22343.hdf5  28656.hdf5  34503.hdf5  4718.hdf5   61333.hdf5

単語辞書には単語名と数値の関係が10単語分定義されています.

!cat dataset_top10/sign_to_prediction_index_map.json
{
    "listen": 0,
    "look": 1,
    "shhh": 2,
    "donkey": 3,
    "mouse": 4,
    "duck": 5,
    "uncle": 6,
    "hear": 7,
    "pretend": 8,
    "cow": 9
}

ライセンスはオリジナルと同様に,CC-BY 4.0 としています.

!cat dataset_top10/LICENSE.txt
The dataset provided by Natsuki Takayama (Takayama Research and Development Office) is licensed under CC-BY 4.0.
Author: Copyright 2024 Natsuki Takayama
Title: GISLR Top 10 dataset
Original licenser: Deaf Professional Arts Network and the Georgia Institute of Technology
Modification
- Extract 10 most frequent words.
- Packaged into HDF5 format.

次のコードでサンプルを確認します.
サンプルは辞書型のようにキーバリュー形式で保存されており,下記のように階層化されています.

- サンプルID (トップ階層のKey)
  |- feature: 入力特徴量で `[C(=3), T, J(=543)]` 形状.C,T,Jは,それぞれ特徴次元,フレーム数,追跡点数です.
  |- token: 単語ラベル値で `[1]` 形状.0から9の数値です.
1
2
3
4
5
6
7
8
9
with h5py.File("dataset_top10/16069.hdf5", "r") as fread:
    keys = list(fread.keys())
    print(keys)
    group = fread[keys[0]]
    print(group.keys())
    feature = group["feature"][:]
    token = group["token"][:]
    print(feature.shape)
    print(token)
['1109479272', '11121526', ..., '976754415']
<KeysViewHDF5 ['feature', 'token']>
(3, 23, 543)
[1]

2.2 モジュールのダウンロード

次に,過去の記事で実装したコードをダウンロードします.
本項は前回までに紹介した内容と同じですので,飛ばしていただいても構いません. コードはGithubのsrc/modules_gislrにアップしてあります (今後の記事で使用するコードも含まれています).

まず,下記のコマンドでレポジトリをダウンロードします.
(目的のディレクトリだけダウンロードする方法はまだ調査中です(^^;))

!wget https://github.com/takayama-rado/trado_samples/archive/master.zip
--2024-01-21 11:01:47--  https://github.com/takayama-rado/trado_samples/archive/master.zip
Resolving github.com (github.com)... 140.82.112.3
...
2024-01-21 11:01:51 (19.4 MB/s) - ‘master.zip’ saved [75710869]

ダウンロードしたリポジトリを解凍します.

!unzip -o master.zip -d master
Archive:  master.zip
641b06a0ca7f5430a945a53b4825e22b5f3b8eb6
   creating: master/trado_samples-main/
  inflating: master/trado_samples-main/.gitignore
  ...

モジュールのディレクトリをカレントディレクトリに移動します.

!mv master/trado_samples-main/src/modules_gislr .

他のファイルは不要なので削除します.

!rm -rf master master.zip gislr_top10.zip
!ls
dataset_top10 drive modules_gislr  sample_data

2.3 モジュールのロード

主要な処理の実装に先立って,下記のコードでモジュールをロードします.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import copy
import json
import math
import os
import random
import sys
from functools import partial
from inspect import signature
from pathlib import Path
from typing import (
    Any,
    Dict,
    List
)

# Third party's modules
import cv2

import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import (
    DataLoader)

from torchvision.transforms import Compose

# Local modules
sys.path.append("modules_gislr")
from modules_gislr.dataset import (
    HDF5Dataset,
    merge_padded_batch)
from modules_gislr.defines import (
    get_fullbody_landmarks
)
from modules_gislr.layers import (
    Identity,
    GPoolRecognitionHead,
    TransformerEnISLR
)
from modules_gislr.train_functions import (
    test_loop,
    val_loop,
    train_loop
)
from modules_gislr.transforms import (
    PartsBasedNormalization,
    ReplaceNan,
    SelectLandmarksAndFeature,
    ToTensor
)
【コード解説】
- 標準モジュール
  - copy: データコピーライブラリ.Transformerブロック内でEncoder層をコピーするために使用します.
  - json: JSONファイル制御ライブラリ.辞書ファイルのロードに使用します.
  - math: 数学計算処理ライブラリ
  - os: システム処理ライブラリ
  - random: ランダム値生成ライブラリ
  - sys: Pythonインタプリタの制御ライブラリ.
    今回はローカルモジュールに対してパスを通すために使用します.
  - functools: 関数オブジェクトを操作するためのライブラリ.
    今回はDataLoaderクラスに渡すパディング関数に対して設定値をセットするために使用します.
  - inspect.signature: オブジェクトの情報取得ライブラリ.
  - pathlib.Path: オブジェクト指向のファイルシステム機能.
    主にファイルアクセスに使います.osモジュールを使っても同様の処理は可能です.
    高山の好みでこちらのモジュールを使っています(^^;).
  - typing: 関数などに型アノテーションを行う機能.
    ここでは型を忘れやすい関数に付けていますが,本来は全てアノテーションをした方が良いでしょう(^^;).
- 3rdパーティモジュール
  - cv2: 画像処理ライブラリ.
    今回はマスキング範囲を時間方向に拡張する処理に用います.
  - numpy: 行列演算ライブラリ
  - torch: ニューラルネットワークライブラリ
  - torchvision: PyTorchと親和性が高い画像処理ライブラリ.
    今回はDatasetクラスに与える前処理をパッケージするために用います.
- ローカルモジュール: sys.pathにパスを追加することでロード可能
  - dataset: データセット操作用モジュール
  - defines: 各部位の追跡点,追跡点間の接続関係,およびそれらへのアクセス処理を
    定義したモジュール
  - layers: ニューラルネットワークのモデルやレイヤモジュール
  - transforms: 入出力変換処理モジュール
  - train_functions: 学習・評価処理モジュール

3. マスキング処理の実装

マスキング処理は,処理方法毎に実装方法が異なります.
今回は,処理方法毎にクラスを定義して個別に呼び出せるように実装したいと思います.

部位毎のマスキング

身体部位毎にマスキングする処理の実装は下記のようになります.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class RandomDropParts():
    def __init__(self,
                 apply_ratio,
                 face_head=0,
                 face_num=len(USE_FACE),
                 lhand_head=len(USE_FACE),
                 lhand_num=len(USE_LHAND),
                 pose_head=len(USE_FACE)+len(USE_LHAND),
                 pose_num=len(USE_POSE),
                 rhand_head=len(USE_FACE)+len(USE_LHAND)+len(USE_POSE),
                 rhand_num=len(USE_RHAND),
                 relative_drop_freq=None):
        targets = []
        if face_head is not None:
            targets.append("face")
        if lhand_head is not None:
            targets.append("lhand")
        if pose_head is not None:
            targets.append("pose")
        if rhand_head is not None:
            targets.append("rhand")

        if relative_drop_freq is not None:
            message = f"relative_drop_freq:{relative_drop_freq}, targets:{targets}"
            assert len(relative_drop_freq) == len(targets), message
            temp = np.array(relative_drop_freq)
            temp = np.cumsum(temp / temp.sum())
            self.relative_drop_ratio = temp * apply_ratio
        else:
            self.relative_drop_ratio = np.array([0.25]*len(targets))

        self.apply_ratio = apply_ratio
        self.face_joints = np.arange(face_head, face_head + face_num) if "face" in targets \
            else None
        self.lhand_joints = np.arange(lhand_head, lhand_head + lhand_num) if "lhand" in targets \
            else None
        self.pose_joints = np.arange(pose_head, pose_head + pose_num) if "pose" in targets \
            else None
        self.rhand_joints = np.arange(rhand_head, rhand_head + rhand_num) if "rhand" in targets \
            else None
        self.targets = targets

    def __call__(self,
                 data: Dict[str, Any]) -> Dict[str, Any]:
        rval = random.random()
        if rval > self.apply_ratio:
            return data

        lower_bound = 0
        feature = data["feature"]
        mask = (feature == 0).all(axis=0)
        mask = np.bitwise_not(mask)
        mask = np.expand_dims(mask, 0)
        for target, ratio in zip(self.targets, self.relative_drop_ratio):
            if target == "face":
                target_joints = self.face_joints
            elif target == "lhand":
                target_joints = self.lhand_joints
            elif target == "pose":
                target_joints = self.pose_joints
            elif target == "rhand":
                target_joints = self.rhand_joints
            if rval >= lower_bound and rval < ratio:
                mask[:, :, target_joints] = 0.0
                break
            else:
                lower_bound = ratio
        if not (mask == 0.0).all():
            feature *= mask

        data["feature"] = feature
        return data
【コード解説】
- 引数
  - apply_ratio: データ拡張の適用確率.
  - face_head: 顔追跡点の先頭インデクス.
    `None` の場合はマスキング対象外になります.
  - face_num: 顔追跡点数.
  - lhand_head: 左手の先頭インデクス.
  - lhand_num: 左手追跡点点数.
    `None` の場合はマスキング対象外になります.
  - pose_head: 身体追跡点の先頭インデクス.
  - pose_num: 身体追跡点追跡点数.
    `None` の場合はマスキング対象外になります.
  - rhand_head: 右手追跡点の先頭インデクス.
  - rhand_num: 右手追跡点数.
    `None` の場合はマスキング対象外になります.
  - relative_drop_freq: マスキング部位の相対頻度.
    例えば,`[1, 2, 3, 4]` の場合は,右手が顔の4倍マスキングされやすくなります.
    `None` の場合は,マスキング対象の部位を等確率でマスキングします.
- 13-41行目: 初期化処理
  - 13-21行目: マスキング対象の部位を決定.
  - 23-30行目: 部位間のマスキング頻度を算出.
  - 33-40行目: マスキング用の追跡点インデクスを部位毎に算出.
- 43-72行目: マスキング処理
  - 45-47行目: 乱数を生成し,`apply_ratio` 以上だった場合は何もせずに値を返す.
  - 51-53行目: マスキング配列を生成し,追跡失敗フレームのマスク値をセット.
  - 54-67行目: マスキング対象の部位をランダムに選択肢,マスキング.
    `relative_drop_ratio` を考慮しながら,部位毎にマスキングをする確率範囲を
    定義することで分岐処理を実装しています.
  - 68-69行目: マスキングの結果,特徴量を除去しすぎた場合は処理をスキップ.

個々の追跡点のマスキング

個々の追跡点をマスキングする処理の実装は下記のようになります.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class RandomDropJoints():
    def __init__(self,
                 apply_ratio,
                 drop_ratio,
                 undrop_joints=None,
                 drop_tsize=1):
        self.apply_ratio = apply_ratio
        self.drop_ratio = drop_ratio
        self.undrop_joints = undrop_joints
        assert drop_tsize % 2 == 1
        self.kernel = np.array([[1] * drop_tsize], dtype=np.uint8).T

    def __call__(self,
                 data: Dict[str, Any]) -> Dict[str, Any]:

        if random.random() > self.apply_ratio:
            return data

        feature = data["feature"]
        mask = (feature == 0).all(axis=0)
        mask = np.bitwise_not(mask)
        undrop = np.random.random(mask.shape)
        undrop[undrop >= self.drop_ratio] = 1.0
        undrop[undrop < self.drop_ratio] = 0.0

        temp = undrop.astype(np.uint8)
        temp = cv2.erode(temp, self.kernel)
        undrop = np.expand_dims(temp, axis=0)

        if self.undrop_joints is not None:
            undrop[:, :, self.undrop_joints] = 1.0
        mask = np.bitwise_and(mask, undrop)
        if not (mask == 0.0).all():
            feature *= mask

        data["feature"] = feature
        return data
【コード解説】
- 引数
  - apply_ratio: データ拡張の適用確率.
  - drop_ratio: マスキングの比率 [0, 1].
  - undrop_joints: マスキング対象外の追跡点インデクス.
  - drop_tsize: マスキング箇所の時間長.
    `drop_tsize > 1` の場合,マスキング箇所を時間軸方向に拡張します.
    値は奇数で指定する必要があります.
- 7-11行目: 初期化処理
- 13-37行目: マスキング処理
  - 16-17行目: 乱数を生成し,`apply_ratio` 以上だった場合は何もせずに値を返す.
  - 20-21行目: マスキング配列を生成し,追跡失敗フレームのマスク値をセット.
  - 22-24行目: マスキング配列と同形状のランダム配列を生成し,`drop_ratio`
    未満の箇所のマスキング箇所としてセット.
  - 26-28行目: 収縮処理 (0値 の膨張) を用いて,マスキング箇所を時間軸方向に拡張.
  - 30-31行目: マスキング対象外の追跡点に 1 をセット.
  - 32行目: 追跡失敗フレームのマスキング配列と積をとり,マスキング配列を生成.
  - 33-34行目: マスキングの適用.マスキングの結果,特徴量を除去しすぎた場合は
    処理をスキップします.

時間フレームのマスキング

時間軸に沿ってフレームをマスキングする処理の実装は下記のようになります.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class RandomDropTemporal():
    def __init__(self,
                 apply_ratio=1.0,
                 size=(0.1, 0.5)):
        self.apply_ratio = apply_ratio
        self.size = size

    def __call__(self,
                 data: Dict[str, Any]) -> Dict[str, Any]:
        if random.random() > self.apply_ratio:
            return data

        feature = data["feature"]
        mask = (feature == 0).all(axis=0)
        mask = np.bitwise_not(mask)

        # Calculate drop range.
        tlength = feature.shape[1]
        size = np.random.random() * (self.size[1] - self.size[0]) + self.size[0]
        start = np.random.random()
        end = start + size
        start = int(tlength * start)
        end = min(int(tlength * end), tlength)

        # Masking.
        mask[None, start: end, :] = 0.0

        # Avoid to drop all signals.
        if not (mask == 0.0).all():
            feature *= mask

        data["feature"] = feature
        return data
【コード解説】
- 引数
  - apply_ratio: データ拡張の適用確率.
  - size: 欠落フレーム数の範囲.
    (min, max) 形式で指定します.
- 5-6行目: 初期化処理
- 8-33行目: マスキング処理
  - 10-11行目: 乱数を生成し,`apply_ratio` 以上だった場合は何もせずに値を返す.
  - 14-15行目: マスキング配列を生成し,追跡失敗フレームのマスク値をセット.
  - 18-23行目: マスキングする時間範囲を算出.
  - 29-30行目: マスキングの適用.マスキングの結果,特徴量を除去しすぎた場合は
    処理をスキップします.

座標空間上のマスキング

座標空間上でマスキングする処理の実装は下記のようになります.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class RandomDropSpatial():
    def __init__(self,
                 apply_ratio=1.0,
                 size=(0.2, 0.4),
                 offsets=None,
                 smask_as="obstacle"):
        self.apply_ratio = apply_ratio
        self.size = size
        self.offsets = offsets
        assert smask_as in ["obstacle", "window"]
        self.smask_as = smask_as

    def __call__(self,
                 data: Dict[str, Any]) -> Dict[str, Any]:
        if random.random() > self.apply_ratio:
            return data

        feature = copy.deepcopy(data["feature"])
        mask = (feature == 0).all(axis=0)
        mask = np.bitwise_not(mask)

        # Maximum spatial range.
        minimums = np.min(feature, axis=(1, 2))
        maximums = np.max(feature, axis=(1, 2))

        min_x = minimums[0]
        min_y = minimums[1]
        max_x = maximums[0]
        max_y = maximums[1]

        # Calculate drop rectangle.
        if self.offsets is None:
            dr_offset_x = np.random.random() * (max_x - min_x) + min_x
            dr_offset_y = np.random.random() * (max_y - min_y) + min_y
        else:
            dr_offset_x = self.offsets[0] * (max_x - min_x) + min_x
            dr_offset_y = self.offsets[1] * (max_y - min_y) + min_y

        dr_size = np.random.random() * (self.size[1] - self.size[0]) + self.size[0]
        dr_size_x = (max_x - min_x) * dr_size
        dr_size_y = (max_y - min_y) * dr_size

        # Undrop in the window.
        udr_x = (dr_offset_x <= feature[0, :, :])
        udr_x = np.bitwise_and(udr_x, (feature[0, :, :] <= (dr_offset_x + dr_size_x)))
        udr_y = (dr_offset_y <= feature[1, :, :])
        udr_y = np.bitwise_and(udr_y, (feature[1, :, :] <= (dr_offset_y + dr_size_y)))
        undrop = np.bitwise_and(udr_x, udr_y)

        if minimums.shape[0] == 3:
            min_z = minimums[2]
            max_z = minimums[2]
            if self.offsets is None:
                dr_offset_z = np.random.random() * (max_z - min_z) + min_z
            else:
                dr_offset_z = self.offsets[2] * (max_z - min_z) + min_z
            dr_size_z = (max_z - min_z) * dr_size
            # Drop in the window.
            udr_z = (dr_offset_z <= feature[2, :, :])
            udr_z = np.bitwise_and(udr_z, (feature[2, :, :] <= (dr_offset_z + dr_size_z)))
            undrop = np.bitwise_and(undrop, udr_z)

        # Inverse undrop mask to interpret spatial mask as window.
        if self.smask_as == "obstacle":
            undrop = np.bitwise_not(undrop)

        mask = np.bitwise_and(mask, undrop)
        if not (mask == 0.0).all():
            feature *= mask

        return data
【コード解説】
- 引数
  - apply_ratio: データ拡張の適用確率.
  - size: マスキング領域のサイズ.
    (min, max) 形式で指定します.
  - offsets: マスキング領域の左上座標.
    (x, y) 形式で [0, 1] の範囲で指定します.
    `None` の場合はランダム値を用います.
  - smask_as: マスキング領域の適法方法を指定します.
    - obstacle: マスキング領域内の座標値を除去します.
      つまり,マスキング領域を,追跡点が見えなくなる障害物と仮定して処理します.
    - window: マスキング領域外の座標値を除去します.
      つまり,マスキング領域を,追跡点が見える窓と仮定して処理します.
- 7-11行目: 初期化処理
- 13-71行目: マスキング処理
  - 15-16行目: 乱数を生成し,`apply_ratio` 以上だった場合は何もせずに値を返す.
  - 19-20行目: マスキング配列を生成し,追跡失敗フレームのマスク値をセット.
  - 23-29行目: 内部処理用に座標空間の範囲を算出.
  - 32-37行目: マスキング領域の左上座標を算出.
  - 39-41行目: マスキング領域を算出.
  - 44-48行目: マスキング箇所を算出.
  - 50-61行目: 入力座標が (x, y, z) 形式だった場合は z値に対する処理を行う.
  - 64-65行目: `smask_as == obstacle` の場合は,マスキング箇所を反転.
  - 68-69行目: マスキングの適用.マスキングの結果,特徴量を除去しすぎた場合は
    処理をスキップします.

4. 認識モデルの動作確認

今回は,第九回の記事で紹介した,Pre-LN構成のTransformerモデルをそのまま用いて実験を行います.
ここではモデルの推論動作が正常に動くかだけ確かめます.

次のコードでデータセットからHDF5ファイルとJSONファイルのパスを読み込みます.

1
2
3
4
5
6
7
8
# Access check.
dataset_dir = Path("dataset_top10")
files = list(dataset_dir.iterdir())
dictionary = [fin for fin in files if ".json" in fin.name][0]
hdf5_files = [fin for fin in files if ".hdf5" in fin.name]

print(dictionary)
print(hdf5_files)
dataset_top10/sign_to_prediction_index_map.json
[PosixPath('dataset_top10/2044.hdf5'), PosixPath('dataset_top10/32319.hdf5'), PosixPath('dataset_top10/18796.hdf5'), PosixPath('dataset_top10/36257.hdf5'), PosixPath('dataset_top10/62590.hdf5'), PosixPath('dataset_top10/16069.hdf5'), PosixPath('dataset_top10/29302.hdf5'), PosixPath('dataset_top10/34503.hdf5'), PosixPath('dataset_top10/37055.hdf5'), PosixPath('dataset_top10/37779.hdf5'), PosixPath('dataset_top10/27610.hdf5'), PosixPath('dataset_top10/53618.hdf5'), PosixPath('dataset_top10/49445.hdf5'), PosixPath('dataset_top10/30680.hdf5'), PosixPath('dataset_top10/22343.hdf5'), PosixPath('dataset_top10/55372.hdf5'), PosixPath('dataset_top10/26734.hdf5'), PosixPath('dataset_top10/28656.hdf5'), PosixPath('dataset_top10/61333.hdf5'), PosixPath('dataset_top10/4718.hdf5'), PosixPath('dataset_top10/25571.hdf5')]

次のコードで辞書ファイルをロードして,認識対象の単語数を格納します.

1
2
3
4
5
# Load dictionary.
with open(dictionary, "r") as fread:
    key2token = json.load(fread)

VOCAB = len(key2token)

次のコードで前処理を定義します.
固定の前処理には,以前に説明した追跡点の選定と,追跡点の正規化を適用して実験を行います.

データ拡張処理は動的な前処理として,transforms_drop_parts (20-23行目,部位毎),transforms_drop_joints (25-30行目,個々の追跡点),transforms_drop_temporal (32-36行目,時間フレーム),transforms_drop_spatial_o (38-44行目,座標空間),transforms_drop_spatial_w (46-52行目,座標空間,マスキング領域を反転) に,それぞれ定義しています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
_, use_landmarks = get_fullbody_landmarks()
use_features = ["x", "y"]
trans_select_feature = SelectLandmarksAndFeature(landmarks=use_landmarks, features=use_features)
trans_repnan = ReplaceNan()
trans_norm = PartsBasedNormalization(align_mode="framewise", scale_mode="unique")

apply_ratio = 0.5
drop_ratio = 0.5
drop_tsize = 3
drop_srange_o = (0.1, 0.3)
drop_srange_w = (0.7, 0.9)
drop_trange = (0.1, 0.5)

pre_transforms = Compose([trans_select_feature,
                          trans_repnan,
                          trans_norm])

transforms_default = Compose([ToTensor()])

transforms_drop_parts = Compose([
    RandomDropParts(
        apply_ratio=apply_ratio),
    ToTensor()])

transforms_drop_joints = Compose([
    RandomDropJoints(
        apply_ratio=apply_ratio,
        drop_ratio=drop_ratio,
        drop_tsize=drop_tsize),
    ToTensor()])

transforms_drop_temporal = Compose([
    RandomDropTemporal(
        apply_ratio=apply_ratio,
        size=drop_trange),
    ToTensor()])

transforms_drop_spatial_o = Compose([
    RandomDropSpatial(
        apply_ratio=apply_ratio,
        size=drop_srange_o,
        offsets=None,
        smask_as="obstacle"),
    ToTensor()])

transforms_drop_spatial_w = Compose([
    RandomDropSpatial(
        apply_ratio=apply_ratio,
        size=drop_srange_w,
        offsets=None,
        smask_as="window"),
    ToTensor()])

次のコードで,前処理を適用したHDF5DatasetとDataLoaderをインスタンス化し,データを取り出します.
HDF5Dataset をインスタンス化する際に,pre_transformstransforms 引数に変数を渡してデータ拡張を有効にしています (14行目).

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
batch_size = 2
feature_shape = (len(use_features), -1, len(use_landmarks))
token_shape = (1,)
merge_fn = partial(merge_padded_batch,
                   feature_shape=feature_shape,
                   token_shape=token_shape,
                   feature_padding_val=0.0,
                   token_padding_val=0)

for trans in [transforms_default,
              transforms_drop_parts, transforms_drop_joints,
              transforms_drop_temporal,
              transforms_drop_spatial_o, transforms_drop_spatial_w]:
    dataset = HDF5Dataset(hdf5_files, pre_transforms=pre_transforms, transforms=trans)
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=merge_fn)
    try:
        data = next(iter(dataloader))
        feature_origin = data["feature"]

        print(feature_origin.shape)
    except Exception as inst:
        print(inst)
torch.Size([2, 2, 28, 130])
torch.Size([2, 2, 28, 130])
torch.Size([2, 2, 28, 130])
torch.Size([2, 2, 28, 130])
torch.Size([2, 2, 28, 130])
torch.Size([2, 2, 28, 130])

次のコードでモデルをインスタンス化して,動作チェックをします.
追跡点抽出の結果,入力追跡点数は130で,各追跡点はXY座標値を持っていますので,入力次元数は260になります.
出力次元数は単語数なので10になります.
また,Transformer層の入力次元数は64に設定し,PFFN内部の拡張次元数は256に設定しています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Define model.
# in_channels: J * C (130*2=260)
#   J: use_landmarks (130)
#   C: use_channels (2)
# out_channels: 10
in_channels = len(use_landmarks) * len(use_features)
inter_channels = 64
out_channels = VOCAB
activation = "relu"
tren_num_layers = 2
tren_num_heads = 2
tren_dim_ffw = 256
tren_dropout_pe = 0.1
tren_dropout = 0.1
tren_layer_norm_eps = 1e-5
tren_norm_first = True
tren_add_bias = True
tren_add_tailnorm = True

model = TransformerEnISLR(in_channels=in_channels,
                          inter_channels=inter_channels,
                          out_channels=out_channels,
                          activation=activation,
                          tren_num_layers=tren_num_layers,
                          tren_num_heads=tren_num_heads,
                          tren_dim_ffw=tren_dim_ffw,
                          tren_dropout_pe=tren_dropout_pe,
                          tren_dropout=tren_dropout,
                          tren_layer_norm_eps=tren_layer_norm_eps,
                          tren_norm_first=tren_norm_first,
                          tren_add_bias=tren_add_bias,
                          tren_add_tailnorm=tren_add_tailnorm)
print(model)

# Sanity check.
logit = model(feature_origin)
print(logit.shape)
attw0 = model.tr_encoder.layers[0].attw.detach().cpu().numpy()
attw1 = model.tr_encoder.layers[0].attw.detach().cpu().numpy()
print(attw0.shape, attw1.shape)
TransformerEnISLR(
  (linear): Linear(in_features=260, out_features=64, bias=True)
  (activation): ReLU()
  (tr_encoder): TransformerEncoder(
    (pos_encoder): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (w_key): Linear(in_features=64, out_features=64, bias=True)
          (w_value): Linear(in_features=64, out_features=64, bias=True)
          (w_query): Linear(in_features=64, out_features=64, bias=True)
          (w_out): Linear(in_features=64, out_features=64, bias=True)
          (dropout_attn): Dropout(p=0.1, inplace=False)
        )
        (ffw): PositionwiseFeedForward(
          (w_1): Linear(in_features=64, out_features=256, bias=True)
          (w_2): Linear(in_features=256, out_features=64, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (activation): ReLU()
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      )
    )
    (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (head): GPoolRecognitionHead(
    (head): Linear(in_features=64, out_features=10, bias=True)
  )
)
torch.Size([2, 10])
(2, 2, 28, 28) (2, 2, 28, 28)

5. 学習と評価の実行

5.1 共通設定

では,実際に学習・評価を行います.
まずは,実験全体で共通して用いる設定値を次のコードで実装します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
# Set common parameters.
batch_size = 32
load_into_ram = True
test_pid = 16069
num_workers = os.cpu_count()
print(f"Using {num_workers} cores for data loading.")
lr = 3e-4

epochs = 50
eval_every_n_epochs = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} for computation.")

train_hdf5files = [fin for fin in hdf5_files if str(test_pid) not in fin.name]
val_hdf5files = [fin for fin in hdf5_files if str(test_pid) in fin.name]
test_hdf5files = [fin for fin in hdf5_files if str(test_pid) in fin.name]

_, use_landmarks = get_fullbody_landmarks()
use_features = ["x", "y"]
Using 2 cores for data loading.
Using cuda for computation.

5.2 学習・評価の実行

次のコードで学習・バリデーション・評価処理それぞれのためのDataLoaderクラスを作成します.
今回は,データ拡張処理の有無および種類による認識性能の違いを見たいので,実験毎にデータセットクラスをインスタンス化します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
# Build dataloaders.
train_dataset = HDF5Dataset(
    train_hdf5files,
    pre_transforms=pre_transforms_w_norm,
    transforms=transforms_default,
    load_into_ram=load_into_ram)
val_dataset = HDF5Dataset(
    val_hdf5files,
    pre_transforms=pre_transforms_w_norm,
    transforms=transforms_default,
    load_into_ram=load_into_ram)
test_dataset = HDF5Dataset(
    test_hdf5files,
    pre_transforms=pre_transforms_w_norm,
    transforms=transforms_default,
    load_into_ram=load_into_ram)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=merge_fn, num_workers=num_workers, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=merge_fn, num_workers=num_workers, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=merge_fn, num_workers=num_workers, shuffle=False)

次のコードでモデルをインスタンス化します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
tren_norm_first = True
tren_add_tailnorm = True

model_default = TransformerEnISLR(
    in_channels=in_channels,
    inter_channels=inter_channels,
    out_channels=out_channels,
    activation=activation,
    tren_num_layers=tren_num_layers,
    tren_num_heads=tren_num_heads,
    tren_dim_ffw=tren_dim_ffw,
    tren_dropout_pe=tren_dropout_pe,
    tren_dropout=tren_dropout,
    tren_layer_norm_eps=tren_layer_norm_eps,
    tren_norm_first=tren_norm_first,
    tren_add_bias=tren_add_bias,
    tren_add_tailnorm=tren_add_tailnorm)
print(model_default)

loss_fn = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.Adam(model_default.parameters(), lr=lr)
TransformerEnISLR(
  (linear): Linear(in_features=260, out_features=64, bias=True)
  (activation): ReLU()
  (tr_encoder): TransformerEncoder(
    (pos_encoder): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (w_key): Linear(in_features=64, out_features=64, bias=True)
          (w_value): Linear(in_features=64, out_features=64, bias=True)
          (w_query): Linear(in_features=64, out_features=64, bias=True)
          (w_out): Linear(in_features=64, out_features=64, bias=True)
          (dropout_attn): Dropout(p=0.1, inplace=False)
        )
        (ffw): PositionwiseFeedForward(
          (w_1): Linear(in_features=64, out_features=256, bias=True)
          (w_2): Linear(in_features=256, out_features=64, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (activation): ReLU()
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      )
    )
    (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (head): GPoolRecognitionHead(
    (head): Linear(in_features=64, out_features=10, bias=True)
  )
)

次のコードで学習・評価処理を行います.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Train, validation, and evaluation.
model_default.to(device)

train_losses = []
val_losses = []
test_accs = []
print("Start training.")
for epoch in range(epochs):
    print("-" * 80)
    print(f"Epoch {epoch+1}")

    train_losses = train_loop(train_dataloader, model_default, loss_fn, optimizer, device)
    val_loss = val_loop(val_dataloader, model_default, loss_fn, device)
    val_losses.append(val_loss)

    if (epoch+1) % eval_every_n_epochs == 0:
        acc = test_loop(test_dataloader, model_default, device)
        test_accs.append(acc)
train_losses_default = np.array(train_losses)
val_losses_default = np.array(val_losses)
test_accs_default = np.array(test_accs)

print(f"Minimum validation loss:{val_losses_default.min()} at {np.argmin(val_losses_default)+1} epoch.")
print(f"Maximum accuracy:{test_accs_default.max()} at {np.argmax(test_accs_default)*eval_every_n_epochs+1} epoch.")
Start training.
--------------------------------------------------------------------------------
Epoch 1
Start training.
loss:3.446782 [    0/ 3881]
loss:2.071606 [ 3200/ 3881]
Done. Time:4.922797406000001
Training performance:
 Avg loss:2.194979

Start validation.
Done. Time:0.4103152080000143
Validation performance:
 Avg loss:1.961810

Start evaluation.
Done. Time:1.802768350000008
Test performance:
 Accuracy:28.0%
--------------------------------------------------------------------------------
...
--------------------------------------------------------------------------------
Epoch 50
Start training.
loss:0.223539 [    0/ 3881]
loss:0.172815 [ 3200/ 3881]
Done. Time:2.817799521999973
Training performance:
 Avg loss:0.185778

Start validation.
Done. Time:0.25698553999995966
Validation performance: 
 Avg loss:0.779501

Start evaluation.
Done. Time:1.450860308000017
Test performance: 
 Accuracy:79.5%
Minimum validation loss:0.5978050615106311 at 17 epoch.
Maximum accuracy:82.5 at 42 epoch.

以後,同様の処理を設定毎に繰り返します.
コード構成は同じですので,ここでは説明を割愛させていただきます. また,この後グラフ等の描画も行っておりますが,本記事の主要点ではないため説明を割愛させていただきます.


今回は追跡点系列に対してマスキングを適用することで,データ拡張を行う手法を紹介しましたが,如何でしたでしょうか?
マスキング処理はデータ拡張 [DeVries'17, Singh'17, Zhong'20, Chen'20, Li'20] だけでなく,自己教師あり学習 [Pathak'16, Devlin'19, Bao'22, Xie'22, Balestriero'23] などでも用いられており,注目を集めています.
実装もシンプルな手法が多いので,改善に悩んでいる方はご検討してみては如何でしょうか.

全データを使うと結果がハッキリ出て良いですね.
(その分時間はかかってしまいますが...)
これまでやってきた実験もデータ数が少ないことの影響は少なからずありそうです.
折を見て追加実験をして,記事を更新していこうかなと思っています.

今回紹介した話が,これから手話認識を勉強してみようとお考えの方に何か参考になれば幸いです.

  • [DeVries'17]: T. DeVries, et al., "Improved Regularization of Convolutional Neural Networks with Cutout," arXiV: 1708.04552, available here, 2017.
  • [Singh'17]: K. K. Singh, et al., "Hide-and-Seek: Forcing a Network to be Meticulous for Weakly-supervised Object and Action Localization," Proc. of the ICCV, available here, 2017.
  • [Zhong'20]: Z. Zhong, et al., "Random Erasing Data Augmentation," Proc. of the AAAI, available here, 2020.
  • [Chen'20]: P. Chen, et al., "GridMask Data Augmentation," arXiV: 2001.04086, available here, 2020.
  • [Li'20]: P. Li, et al., "FenceMask: A Data Augmentation Approach for Pre-extracted Image Features," arXiV: 2006.07877, available here, 2020.
  • [Pathak'16]: D. Pathak, et al., "Context Encoders: Feature Learning by Inpainting," Proc. of the CVPR, available here, 2016.
  • [Devlin'19]: J. Devlin, et al., "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding," Proc. of the NAACL, available here, 2019.
  • [Bao'22]: H. Bao, et al., "BEiT: BERT Pre-Training of Image Transformers," Proc of the ICLR, available here, 2022.
  • [Xie'22]: Z. Xie, et al., "SimMIM: A Simple Framework for Masked Image Modeling," Proc of the CVPR, available here, 2022.
  • [Balestriero'23]: R. Balestriero, et al., "A Cookbook of Self-Supervised Learning," arXiV: 2304.12210, available here, 2023.