欠損値の線形補間処理・PyTorch編

This image is generated with ChatGPT-4, and edited by the author.
作成日:2023年10月05日(木) 00:00
最終更新日:2024年09月17日(火) 20:10
カテゴリ:コンピュータビジョン
タグ:  線形補間 PyTorch 動作解析 骨格追跡 Python

PyTorch の行列計算を用いて,骨格追跡点の欠損値を線形補間する方法を紹介します.

こんにちは.高山です.
以前の記事で,Numpyの行列計算を用いて線形補間を行う方法を紹介しました.
今回は,同様の処理をPyTorchを用いて実装する方法を紹介します.
今回解説するスクリプトはGitHub上に公開しています

本記事の実装方法は,Brent M. Spell氏のブログに記載の方法を参考にしています.
この度,本記事向けに拡張したコードを記載してよいかお伺いしたところ,快く許可してくださいました.
この場を借りて感謝申し上げます.

更新履歴 (大きな変更のみ記載しています)

  • 2024/09/17: タグとサマリを更新しました
  • 2023/10/30
    • 処理時間の計測方法を更新しました.
    • GPUを利用できるように更新しました.
    • JITの最適化を抑制する処理を加えました

1. モジュールのロード

まず最初に,下記のコードでモジュールをロードします.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
# Standard modules.
import gc
import sys
import time
from functools import partial

# CV/ML.
import numpy as np

import torch
import torch.nn.functional as F
【コード解説】
- 標準モジュール
  - gc: ガベージコレクション用ライブラリ
    処理時間計測クラスの内部処理で用います.
  - sys: Pythonシステムを扱うライブラリ
    今回はバージョン情報を取得するために使用しています.
  - time: 時刻データを取り扱うためのライブラリ
    今回は処理時間を計測するために使用しています.
  - functools: 関数オブジェクトを操作するためのライブラリ
    処理時間計測クラスに渡す関数オブジェクト作成に使用しています.
- 画像処理・機械学習向けモジュール
  - numpy: 行列演算ライブラリ
    今回はデータのロードに用います.
  - torch: ニューラルネットワークライブラリ
    今回は行列演算機能を用います.

記事執筆時点のPythonと主要モジュールのバージョンは下記のとおりです.

1
2
3
print(f"Python:{sys.version}")
print(f"Numpy:{np.__version__}")
print(f"PyTorch:{torch.__version__}")
Python:3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]
Numpy:1.23.5
PyTorch:2.1.0+cu118

2. 入力データのロード

次に,下記の処理で補間前と補間後の追跡点データをダウンロードします.
補間後の追跡点データは以前の記事で紹介した処理で生成したデータです.
このデータは今回の処理で生成したデータとの比較に用います.

!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy

ls コマンドでデータがダウンロードされているか確認します.

!ls
finger_far0_non_static_interp.npy  finger_far0_non_static.npy  sample_data

3. 実験用処理の実装

実験に先立って,処理時間計測用の関数と実験用定数を定義します.
次のコードは処理時間の値から,適切なSI接頭辞を設定し,文字列にして返します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
def get_perf_str(val):
    token_si = ["", "m", "µ", "n", "p"]
    exp_si = [1, 1e3, 1e6, 1e9, 1e12]
    perf_str = f"{val:3g}s"
    si = ""
    sval = val
    for token, exp in zip(token_si, exp_si):
        if val * exp > 1.0:
            si = token
            sval = val * exp
            break
    perf_str = f"{sval:3g}{si}s"
    return perf_str
- 引数
  - val: 処理時間を示す値 (秒)
- 2-6行目: 変数の初期化
- 7-11行目: 接頭辞を選択して,表示値を調整
- 12行目: 表示文字列を作成

次のコードは,処理時間を格納した配列から統計量を求めて表示します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
def print_perf_time(intervals, top_k=None):
    if top_k is not None:
        intervals = np.sort(intervals)[:top_k]
    min = intervals.min()
    max = intervals.max()
    mean = intervals.mean()
    std = intervals.std()

    smin = get_perf_str(min)
    smax = get_perf_str(max)
    mean = get_perf_str(mean)
    std = get_perf_str(std)
    if top_k:
        print(f"Top {top_k} summary: Max {smax}, Min {smin}, Mean +/- Std {mean} +/- {std}")
    else:
        print(f"Overall summary: Max {smax}, Min {smin}, Mean +/- Std {mean} +/- {std}")
- 引数:
  - intervals: 処理時間のNumpy配列
  - top_k: intervalsから,処理時間が短いサンプルをk個取り出して統計量を算出
- 2-3行目: Top K個のサンプル抽出
- 4-7行目: 統計量算出
- 9-16行目: 表示文字列を作成して表示

次のクラスは,入力した関数を複数回呼び出して,処理時間の統計量を表示します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
class PerfMeasure():
    def __init__(self,
                 trials=100,
                 top_k=10):
        self.trials = trials
        self.top_k = top_k

    def __call__(self, func):
        gc.collect()
        gc.disable()
        intervals = []
        for _ in range(self.trials):
            start = time.perf_counter()
            func()
            end = time.perf_counter()
            intervals.append(end - start)
        intervals = np.array(intervals)
        print_perf_time(intervals)
        if self.top_k:
            print_perf_time(intervals, self.top_k)
        gc.enable()
        gc.collect()
- 引数:
  - trials: 入力関数の実行回数
  - top_k: 値が定義されている場合,全体の処理時間配列から処理時間が短いサンプルをk個取り出して統計量を算出
- 9-10行目: 計測中はガベージコレクションをしないように設定
- 11-20行目: 処理時間計測処理
- 21-22行目: ガベージコレクションの設定を元に戻す

最後に,次のコードで実験用の定数を定義して処理時間計測クラスをインスタンス化しています.

1
2
3
4
5
6
TRIALS = 100
TOPK = 10
pmeasure = PerfMeasure(TRIALS, TOPK)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Target device is {DEVICE}.")
JIT_OPT = False

1-3行目は処理時間計測用の処理です.
実際の所,Colabの様な複雑な処理環境で他のプロセスの影響を排除して純粋な処理時間を計測するのはかなり難しいです.
そこでここでは,処理の繰り返し数を \(100\) とし全体の統計量に加えて,処理時間の短い代表値 \(10\) 試行からも統計量を表示するようにしています.

4-5行目ではランタイプを種類に応じて,CPU/GPUを切り替えています.

6行目はJITコンパイル時の最適化を抑制する設定です.
最適化を抑制しない場合に処理の再コンパイルが発生する現象が報告されています.
公式ドキュメントに記載が無いので何とも言えないのですが,本記事では最適化を抑制して実験を行います.

4. 線形補間の実装・実行

ここから先は補間処理の実装・実行を行います.
Tensorflowと同じくPyTorchも,プログラムの実行時に処理をコンパイル (計算グラフの作成) して実行します.
現在のPyTorchには,下記に示すとおり2種類の計算グラフ作成方法があります.

  • Define-by-Run: データを入力した際に処理のコンパイルと実行を同時に行う.
    Pythonの実行環境と親和性が高くインタラクティブな実行環境で利用しやすくなっています.
  • Define-and-Run (TorchScript): (C言語のように) 処理のコンパイルと実行が明確に別れて行われる.
    処理の最適化性能が高いため高速に動作します.

今回は,補間処理をそれぞれの計算グラフ作成方法で実行して処理時間を比較してみたいと思います.

4.1 Define-by-Runを用いた場合

本項ではまず,Define-by-Run を用いた場合の実装を紹介します.
PyTorchでは標準でDefine-by-Runに基づいて計算グラフを構築するようになっていますので,特に難しいところはありません.

補間処理の実装

次のコードでは,こちらの記事の第2節で説明した行列計算を用いた線形補間処理を実装しています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def matrix_interp_torch(track):
    orig_shape = track.shape
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    valid = mask.sum()
    if valid == tlength:
        return track

    xs = torch.where(mask != 0)[0]
    ys = track.reshape([tlength, -1])[xs, :]
    x = torch.arange(tlength, device=xs.device)

    # ========================================================================
    # Interpolation.
    # ========================================================================
    xs = xs.to(ys.dtype)
    x = x.to(ys.dtype)
    # Pad control points for extrapolation.
    # Unexpectedly, torch.finfo(torch.float64).min returns -inf.
    # So we use torch.finfo(torch.float32).min alternatively.
    xs = torch.cat([torch.tensor([torch.finfo(torch.float32).min], device=xs.device),
                    xs,
                    torch.tensor([torch.finfo(torch.float32).max], device=xs.device)], dim=0)
    ys = torch.cat([ys[:1], ys, ys[-1:]], dim=0)

    # Compute slopes, pad at the edges to flatten.
    sloops = (ys[1:] - ys[:-1]) / torch.unsqueeze((xs[1:] - xs[:-1]), dim=-1)
    sloops = F.pad(sloops[:-1], (0, 0, 1, 1))

    # Solve for intercepts.
    intercepts = ys - sloops * torch.unsqueeze(xs, dim=-1)

    # Search for the line parameters at each input data point.
    # Create a grid of the inputs and piece breakpoints for thresholding.
    # Rely on argmax stopping on the first true when there are duplicates,
    # which gives us an index into the parameter vectors.
    mask_bk_indicator = torch.unsqueeze(xs, dim=-2) > torch.unsqueeze(x, dim=-1)
    idx = torch.argmax(mask_bk_indicator.to(torch.int32), dim=-1)
    sloop = sloops[idx]
    intercept = intercepts[idx]

    # Apply the linear mapping at each input data point.
    y = sloop * torch.unsqueeze(x, dim=-1) + intercept
    y = y.to(ys.dtype)
    y = y.reshape(orig_shape)
    return y
【コード解説】
- 引数
  - track: `[T, J, C]` 形状の追跡点配列.欠損値はゼロ埋めされている必要があります.
    また,この関数は部位毎に追跡点配列が入力される (全追跡点,特徴量で欠損フレームが共通) ことを想定しています.
    - T: 動画フレームインデクス
    - J: 追跡点インデクス
    - C: 特徴量インデクス.今回は $(x, y, z, c)$ の4次元特徴量を用いています.
- 2-3行目: 内部処理で用いるため,`track` の形状情報を取得
- 4行目: 第0番目の追跡点の信頼度に基づいて,追跡成功フレームを示す `mask` を生成
  この関数は全追跡点,特徴量で欠損フレームが共通であることを前提としています.
- 5-7行目: 5行目で追跡成功フレーム数を取得し,線形補間が必要かを判定
- 9-11行目: 追跡成功フレームのデータ点と,補間箇所を含む全フレームのインデクス配列を生成
- 16-17行目: 計算処理のために型を `ys` に合わせて変換
- 21-22行目: `xs` と `ys` の先頭と末尾にダミー値を追加
- 25-26行目: 直線の傾き行列を算出し,先頭にダミー値を追加
- 29行目: 直線の切片行列を算出
- 35-38行目: ルックアップテーブルを作成して,全フレームのインデクスに対する直線パラメータを取得
- 41-43行目: 線形補間を実行し,型と形状を入力に合わせて返す

全部位の補間処理

全部位の追跡処理は次の関数で実装してます.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
def partsbased_interp_torch(trackdata, device="cpu"):
    trackdata = torch.from_numpy(trackdata).to(device)
    pose = trackdata[:, :33]
    lhand = trackdata[:, 33:33+21]
    rhand = trackdata[:, 33+21:33+21+21]
    face = trackdata[:, 33+21+21:]

    pose = matrix_interp_torch(pose)
    lhand = matrix_interp_torch(lhand)
    rhand = matrix_interp_torch(rhand)
    face = matrix_interp_torch(face)
    return torch.cat([pose, lhand, rhand, face], dim=1)
1
2
3
4
5
6
7
8
9
【コード解説】
- 引数:
  - trackdata: `[T, J, C]` 形状の追跡点配列.欠損値はゼロ埋めされている必要があります.
    こちらの追跡点には全部位のデータが含まれていることを想定しています.
  - device: torch.Tensorを取り扱うデバイスを指定 [cpu/cuda].
- 3行目: `trackdata` を `numpy.ndarray` から `torch.Tensor` に変換
- 4-7行目: 追跡点を部位毎に分割
- 9-12行目: 部位毎に線形補間を実行
- 13行目: 部位毎の追跡点を結合して返す

補間処理の実行

線形補間に必要な処理が実装できましたので,上でダウンロードしたデータを用いて処理を実行します.

次のコードでは,追跡点をロードして,関数の仕様に合わせて形状を変更しています.
追跡点データは[P, T, J, C]形状 (Pは人物インデクス) の配列ですが,人物は1名だけですのでP軸は除去しています.

1
2
3
4
5
trackdata = np.load("finger_far0_non_static.npy")
reftrack = np.load("finger_far0_non_static_interp.npy")
# Remove person axis.
trackdata = trackdata[0]
reftrack = reftrack[0]

次のコードは先程実装したpartsbased_interp_torchtrackdataを入力して,補間後の追跡点newtrackを得ています.
1回目の処理呼び出しでは,処理のコンパイルが行われるため個別に時間を計測し,参照用の補間後追跡点との誤差を計測しています.
その後,処理時間計測クラスを用いて平均処理時間を表示しています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# Torch.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = partsbased_interp_torch(trackdata, device=DEVICE)
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack - newtrack.detach().cpu().numpy()).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(partsbased_interp_torch, trackdata=trackdata, device=DEVICE)
pmeasure(target_fn)

print 処理の結果を示します.
1回目の処理時間は147.7ミリ秒程度で,誤差はほぼありませんでした. 2回目以降の処理では,(他のプロセスの影響が少ない場合は) 1.4ミリ秒程度で処理が完了しています.

Time of first call.
Overall summary: Max 147.673ms, Min 147.673ms, Mean +/- Std 147.673ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 20.5691ms, Min 1.33844ms, Mean +/- Std 2.20689ms +/- 2.63837ms
Top 10 summary: Max 1.40347ms, Min 1.33844ms, Mean +/- Std 1.37999ms +/- 22.2161µs

入力データ形状が変わった場合の挙動

Tensorflowの記事と同じく,今回も入力データの形状が変わった場合の挙動を調査してみたいと思います.
次のコードは先程の線形補間の実行処理とほぼ同じですが,入力データの時間長が1フレーム分短くなっています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# Torch.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = partsbased_interp_torch(trackdata[:-1], device=DEVICE)
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack[:-1] - newtrack.detach().cpu().numpy()).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(partsbased_interp_torch, trackdata=trackdata[:-1], device=DEVICE)
pmeasure(target_fn)

print 処理の結果を示します.
1回目の処理では10.8ミリ秒程度,2回目以降の処理では1.2ミリ秒程度で処理が完了しています.
Tensorflowと異なり,PyTorchの公式ドキュメントには再トレーシングに関する情報が記載されていません.
そのため確定的ではないですが,1回目と2回目以降の処理時間を見る限りでは処理の再トレーシングは起きていないようです.

Time of first call.
Overall summary: Max 10.7508ms, Min 10.7508ms, Mean +/- Std 10.7508ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 6.57484ms, Min 1.06022ms, Mean +/- Std 1.59762ms +/- 723.317µs
Top 10 summary: Max 1.30698ms, Min 1.06022ms, Mean +/- Std 1.23992ms +/- 67.8562µs

4.2 Define-and-Runモード (TorchScript) を用いた場合

補間処理の実装

ここから先は,Define-and-Run に基づいて計算グラフを作成した場合の挙動について見ていきます.

PyTorchでは @torch.jit.script デコレータで修飾した関数は TorchScript と呼ぶ中間表現に事前変換されます.
今回は Define-and-Run で動作するように線形補間処理を実装しているので,コードの変更はほとんどありません.

次のコードでは,先程紹介した線形補間処理関数を Define-and-Run で動作するようにしています.
中身のコードは完全に同じなのでここでは説明を省かさせていただきます.

1
2
3
@torch.jit.script
def matrix_interp_torch_jit(track):
    ...

補間処理の実行

次のコードで補間処理を実行しています.
先程と同じく,1回目の処理呼び出しでは,処理のコンパイルが行われるため個別に時間を計測し,2回目以降に処理時間計測クラスを用いて平均時間を表示しています.
なお,次のコードではwith 文でJITの最適化を抑制する設定をしています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
# Torch.
with torch.jit.optimized_execution(JIT_OPT):
    # The 1st call may be slow because of the computation graph construction.
    print(f"Time of first call.")
    start = time.time()
    newtrack = partsbased_interp_torch_jit(trackdata, device=DEVICE)
    interval = time.time() - start
    print_perf_time(np.array([interval]))

    diff = (reftrack - newtrack.detach().cpu().numpy()).sum()
    print(f"Sum of error:{diff}")

    print("Time after second call.")
    target_fn = partial(partsbased_interp_torch_jit, trackdata=trackdata, device=DEVICE)
    pmeasure(target_fn)

print 処理の結果を示します.
ここの結果は少し興味深いです.
1回目の処理では処理のコンパイルが走るはずなので処理時間が長くなると予想されるのですが,実際には6.39ミリ秒程度で完了しています.
これはJITの最適化を抑制した効果と思われます.
JITの最適化を行う場合は121.7ミリ秒程度時間がかかりました.
2回目以降の処理では1.22ミリ秒程度で処理が完了しています.

Time of first call.
Overall summary: Max 6.38986ms, Min 6.38986ms, Mean +/- Std 6.38986ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 23.9719ms, Min 1.15678ms, Mean +/- Std 3.9527ms +/- 4.79447ms
Top 10 summary: Max 1.25467ms, Min 1.15678ms, Mean +/- Std 1.21958ms +/- 30.8053µs

入力データ形状が変わった場合の挙動

次に,先程と同じく入力データの形状が変わった場合の挙動を示します.
次のコードは先程の線形補間の実行処理とほぼ同じですが,入力データの時間長が1フレーム分短くなっています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
# Torch.
with torch.jit.optimized_execution(JIT_OPT):
    # The 1st call may be slow because of the computation graph construction.
    print(f"Time of first call.")
    start = time.time()
    newtrack = partsbased_interp_torch_jit(trackdata[:-1], device=DEVICE)
    interval = time.time() - start
    print_perf_time(np.array([interval]))

    diff = (reftrack[:-1] - newtrack.detach().cpu().numpy()).sum()
    print(f"Sum of error:{diff}")

    print("Time after second call.")
    target_fn = partial(partsbased_interp_torch_jit, trackdata=trackdata[:-1], device=DEVICE)
    pmeasure(target_fn)

print 処理の結果を示します.
1回目の処理では2.17ミリ秒程度,2回目以降の処理では1.1ミリ秒程度で処理が完了しています.
1回目および2回目以降の処理時間は大きくは変わらず,PyTorchでは処理の再トレーシングは起きていないと予想されます.

Time of first call.
Overall summary: Max 2.16985ms, Min 2.16985ms, Mean +/- Std 2.16985ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 22.0589ms, Min 1.09035ms, Mean +/- Std 2.68458ms +/- 3.44788ms
Top 10 summary: Max 1.16749ms, Min 1.09035ms, Mean +/- Std 1.14482ms +/- 22.0932µs

今回は線形補間処理をPyTorchの行列計算で実装する方法を紹介しましたが,如何でしょうか?
私もTorchScriptを触ったのは今回が初めてで,まだ色々と勉強中です.
勘違いや誤りがあった場合はご指摘いただけましたら助かります.

今回紹介した話が,補間処理などでお悩みの方に何か参考になれば幸いです.