欠損値の線形補間処理・JAX編 (高速化はできませんでした)

This image is generated with ChatGPT-4, and edited by the author.
作成日:2023年10月06日(金) 00:00
最終更新日:2024年09月17日(火) 20:09
カテゴリ:コンピュータビジョン
タグ:  線形補間 JAX 動作解析 骨格追跡 Python

欠損値の線形補間処理をJAXの行列計算で実装してみましたが,高速化が上手くできませんでした.

こんにちは.高山です.
以前の記事で,Numpyの行列計算を用いて線形補間を行う方法を紹介しました.
今回は,同様の処理をJAXを用いて実装してみたいと思います.
JAXは,Autograd (自動微分) や XLA (線形代数高速化のためのコンパイラ) 機能を備えた,高性能な計算処理のために設計されたライブラリです.
Numpyに似たインタフェースで作られており,以前の記事で作成した手法を簡単に移植できるのでは?と思い,試してみることにしました.
今回解説するスクリプトはGitHub上に公開しています

本記事の実装方法は,Brent M. Spell氏のブログに記載の方法を参考にしています.
この度,本記事向けに拡張したコードを記載してよいかお伺いしたところ,快く許可してくださいました.
この場を借りて感謝申し上げます.

更新履歴 (大きな変更のみ記載しています)

  • 2024/09/17: タイトルとタグを更新しました
  • 2023/10/30: 処理時間の計測方法を更新しました

1. 先に結論から

実装紹介に先立って結論から述べますと,

  1. 移植は簡単にできた
  2. 入力データの形状が変わる場合への対応が上手くできなく,高速化には失敗

という結果でした.
Tensorflowの記事で,計算グラフの再トレーシングという問題について紹介しました.
上記の問題はJAXにも存在しており,(今回の実装方法を踏襲する限りでは) 上手い回避方法が実装できませんでした.

というわけで,ここから先は実質的には失敗報告になることをご了承ください(^^;).

2. モジュールのロード

まず最初に,下記のコードでモジュールをロードします.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
# Standard modules.
import gc
import sys
import time
from functools import partial

# CV/ML.
import numpy as np

import jax
import jax.numpy as jnp
from jax import jit

# Enable float64.
jax.config.update("jax_enable_x64", True)
【コード解説】
- 標準モジュール
  - gc: ガベージコレクション用ライブラリ
    処理時間計測クラスの内部処理で用います.
  - sys: Pythonシステムを扱うライブラリ
    今回はバージョン情報を取得するために使用しています.
  - time: 時刻データを取り扱うためのライブラリ
    今回は処理時間を計測するために使用しています.
  - functools: 関数オブジェクトを操作するためのライブラリ
    処理時間計測クラスに渡す関数オブジェクト作成に使用しています.
- 計算処理用モジュール
  - numpy: 行列演算ライブラリ.今回はデータをロードするために使用します.
  - jax: JAXライブラリのルート
  - jax.numpy: JAXベースの行列演算モジュール.Numpyと同様の関数群を持ちます.
  - jax.jit: Just in Time コンパイル用のモジュール.
    実装した処理をコンパイルして Define-and-Run ベースで動かすために使用します.
なお,JAXのfloat型は標準では 32bit になっているので,最後の行で64bitが使用されるように設定しています.

記事執筆時点のPythonと主要モジュールのバージョンは下記のとおりです.

1
2
3
print(f"Python:{sys.version}")
print(f"Numpy:{np.__version__}")
print(f"JAX:{jax.__version__}")
Python:3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]
Numpy:1.23.5
JAX:0.4.16

3. 入力データのロード

次に,下記の処理で補間前と補間後の追跡点データをダウンロードします.
補間後の追跡点データは以前の記事で紹介した処理で生成したデータです.
このデータは今回の処理で生成したデータとの比較に用います.

!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static.npy
!wget https://github.com/takayama-rado/trado_samples/raw/main/test_data/finger_far0_non_static_interp.npy

ls コマンドでデータがダウンロードされているか確認します.

!ls
finger_far0_non_static_interp.npy  finger_far0_non_static.npy  sample_data

4. 実験用処理の実装

実験に先立って,処理時間計測用の関数と実験用定数を定義します.
次のコードは処理時間の値から,適切なSI接頭辞を設定し,文字列にして返します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
def get_perf_str(val):
    token_si = ["", "m", "µ", "n", "p"]
    exp_si = [1, 1e3, 1e6, 1e9, 1e12]
    perf_str = f"{val:3g}s"
    si = ""
    sval = val
    for token, exp in zip(token_si, exp_si):
        if val * exp > 1.0:
            si = token
            sval = val * exp
            break
    perf_str = f"{sval:3g}{si}s"
    return perf_str
- 引数
  - val: 処理時間を示す値 (秒)
- 2-6行目: 変数の初期化
- 7-11行目: 接頭辞を選択して,表示値を調整
- 12行目: 表示文字列を作成

次のコードは,処理時間を格納した配列から統計量を求めて表示します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
def print_perf_time(intervals, top_k=None):
    if top_k is not None:
        intervals = np.sort(intervals)[:top_k]
    min = intervals.min()
    max = intervals.max()
    mean = intervals.mean()
    std = intervals.std()

    smin = get_perf_str(min)
    smax = get_perf_str(max)
    mean = get_perf_str(mean)
    std = get_perf_str(std)
    if top_k:
        print(f"Top {top_k} summary: Max {smax}, Min {smin}, Mean +/- Std {mean} +/- {std}")
    else:
        print(f"Overall summary: Max {smax}, Min {smin}, Mean +/- Std {mean} +/- {std}")
- 引数:
  - intervals: 処理時間のNumpy配列
  - top_k: intervalsから,処理時間が短いサンプルをk個取り出して統計量を算出
- 2-3行目: Top K個のサンプル抽出
- 4-7行目: 統計量算出
- 9-16行目: 表示文字列を作成して表示

次のクラスは,入力した関数を複数回呼び出して,処理時間の統計量を表示します.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
class PerfMeasure():
    def __init__(self,
                 trials=100,
                 top_k=10):
        self.trials = trials
        self.top_k = top_k

    def __call__(self, func):
        gc.collect()
        gc.disable()
        intervals = []
        for _ in range(self.trials):
            start = time.perf_counter()
            func()
            end = time.perf_counter()
            intervals.append(end - start)
        intervals = np.array(intervals)
        print_perf_time(intervals)
        if self.top_k:
            print_perf_time(intervals, self.top_k)
        gc.enable()
        gc.collect()
- 引数:
  - trials: 入力関数の実行回数
  - top_k: 値が定義されている場合,全体の処理時間配列から処理時間が短いサンプルをk個取り出して統計量を算出
- 9-10行目: 計測中はガベージコレクションをしないように設定
- 11-20行目: 処理時間計測処理
- 21-22行目: ガベージコレクションの設定を元に戻す

最後に,次のコードで実験用の定数を定義して処理時間計測クラスをインスタンス化しています.

1
2
3
TRIALS = 100
TOPK = 10
pmeasure = PerfMeasure(TRIALS, TOPK)

実際の所,Colabの様な複雑な処理環境で他のプロセスの影響を排除して純粋な処理時間を計測するのはかなり難しいです.
そこでここでは,処理の繰り返し数を \(100\) とし全体の統計量に加えて,処理時間の短い代表値 \(10\) 試行からも統計量を表示するようにしています.

5. 線形補間の実装・実行

ここから先は補間処理の実装・実行を行います.
TensorflowやPyTorchと同じく,JAXもプログラムの実行時に処理をコンパイル (計算グラフの作成) して実行します.
現在のJAXには,下記に示すとおり2種類の計算グラフ作成方法があります.

  • Define-by-Run: データを入力した際に処理のコンパイルと実行を同時に行う.
    Pythonの実行環境と親和性が高くインタラクティブな実行環境で利用しやすくなっています.
  • Define-and-Run (JITコンパイル): (C言語のように) 処理のコンパイルと実行が明確に別れて行われる.
    処理の最適化性能が高いため高速に動作します.

今回は,補間処理をそれぞれの計算グラフ作成方法で実行して処理時間を比較してみたいと思います.

5.1 Define-by-Runを用いた場合

本項ではまず,Define-by-Run を用いた場合の実装を紹介します.
JAXでは標準でDefine-by-Runに基づいて計算グラフを構築するようになっていますので,特に難しいところはありません.

補間関数を用いた実装

次のコードでは,こちらの記事と同じく,Numpy (今回はjax.numpy)の線形補間関数を用いた処理を実装しています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
def simple_interp_jax(trackdata):
    tlength, num_joints, _ = trackdata.shape
    newtrack = jnp.zeros_like(trackdata)
    for i in range(num_joints):
        temp = trackdata[:, i, :]
        mask = temp[:, -1] != 0
        valid = mask.sum()
        if valid == tlength:
            newtrack = newtrack.at[:, i].set(temp)
            continue
        xs = jnp.where(mask != 0, size=valid)[0]
        # ys = temp[xs, :] <- can't be compiled.
        ys = jnp.take(temp, xs, axis=0)
        newys = jnp.zeros_like(temp)
        for j in range(temp.shape[-1]):
            newy = jnp.interp(jnp.arange(tlength), xs, ys[:, j])
            newys = newys.at[:, j].set(newy)
        newtrack = newtrack.at[:, i].set(newys)
    return newtrack
【コード解説】
- 引数
  - trackdata: `[T, J, C]` 形状の追跡点配列.欠損値はゼロ埋めされている必要があります.
    - T: 動画フレームインデクス
    - J: 追跡点インデクス
    - C: 特徴量インデクス.今回は $(x, y, z, c)$ の4次元特徴量を用いています.
- 2行目: 処理内部で用いるために,入力配列の時間長と追跡点数を取り出す
- 3行目: 補間後の配列格納用変数を初期化
- 4-17行目: 補間処理のループ
  - 5行目: $i$ 番目の追跡点配列を取り出す.取り出した後の配列は`[T, C]`形状になります.
  - 6行目: 末尾の特徴量 $(c)$ が $0$ "でない場合"に`True`,$0$ の場合に`False`となる2値マスク配列を生成.
    $c$ は追跡信頼度を示していますので,$c=0$ は追跡失敗を示します.
  - 7行目: 追跡成功フレーム数を算出
  - 8-10行目: 追跡成功フレーム数が動画フレーム数と同じ場合は補間の必要が無いので,そのまま配列を格納して次のループへ移行
  - 11-13行目: 追跡成功フレームの時間インデクス `xs` と 特徴量 `ys` を取り出す.
  - 14行目: 補間後の配列格納用変数 (一時変数) を初期化
  - 15-18行目: 特徴次元に沿ってループを回し線形補間,その後配列に値を格納

上のコードはJITコンパイル利用時も動作しますが,実はかなり強引なことをしています.
まず,引数 trackdata には jax.numpy.array 型ではなく,numpy.array 型をそのまま渡しています.
現在のJAXはデータの形状や中身に依存して処理ロジックを切り替えることが難しいです.
そのためjax.numpy.array 型を用いた場合は,8行目や11行目などの valid の値に依存した処理でコンパイルエラーとなります.

これを回避するために,入力を numpy.array 型で渡して maskvalid などの処理ロジックに関わる変数はJAXの取り扱い外としています.

行列計算を用いた実装

次のコードでは,こちらの記事と同じく,行列計算を用いた処理を実装しています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def matrix_interp_jax(track):
    orig_shape = track.shape
    tlength = orig_shape[0]
    mask = track[:, 0, -1] != 0
    valid = mask.sum()
    if valid == tlength:
        return track

    xs = jnp.where(mask != 0, size=valid)[0]
    # ys = track.reshape([tlength, -1])[xs, :] <- can't be compiled
    ys = jnp.take(track.reshape([tlength, -1]), xs, axis=0)
    x = jnp.arange(tlength)

    # ========================================================================
    # Interpolation.
    # ========================================================================
    xs = xs.astype(ys.dtype)
    x = x.astype(ys.dtype)
    # Pad control points for extrapolation.
    xs = jnp.concatenate([jnp.array([jnp.finfo(xs.dtype).min]), xs, jnp.array([jnp.finfo(xs.dtype).max])], axis=0)
    ys = jnp.concatenate([ys[:1], ys, ys[-1:]], axis=0)

    # Compute slopes, pad at the edges to flatten.
    sloops = (ys[1:] - ys[:-1]) / jnp.expand_dims((xs[1:] - xs[:-1]), axis=-1)
    sloops = jnp.pad(sloops[:-1], [(1, 1), (0, 0)])

    # Solve for intercepts.
    intercepts = ys - sloops * jnp.expand_dims(xs, axis=-1)

    # Search for the line parameters at each input data point.
    # Create a grid of the inputs and piece breakpoints for thresholding.
    # Rely on argmax stopping on the first true when there are duplicates,
    # which gives us an index into the parameter vectors.
    idx = jnp.argmax(jnp.expand_dims(xs, axis=-2) > jnp.expand_dims(x, axis=-1), axis=-1)
    sloop = sloops[idx]
    intercept = intercepts[idx]

    # Apply the linear mapping at each input data point.
    y = sloop * jnp.expand_dims(x, axis=-1) + intercept
    y = y.astype(ys.dtype)
    y = y.reshape(orig_shape)
    return y
【コード解説】
- 引数
  - track: `[T, J, C]` 形状の追跡点配列.欠損値は $0$ 埋めされている必要があります.
    また,この関数は部位毎に追跡点配列が入力される (全追跡点,特徴量で欠損フレームが共通) ことを想定しています.
    - T: 動画フレームインデクス
    - J: 追跡点インデクス
    - C: 特徴量インデクス.今回は $(x, y, z, c)$ の4次元特徴量を用いています.
- 2-3行目: 内部処理で用いるため,`track` の形状情報を取得
- 4行目: 第0番目の追跡点の信頼度に基づいて,追跡成功フレームを示す `mask` を生成
  この関数は全追跡点,特徴量で欠損フレームが共通であることを前提としています.
- 5-7行目: 5行目で追跡成功フレーム数を取得し,線形補間が必要かを判定
- 9-12行目: 追跡成功フレームのデータ点と,補間箇所を含む全フレームのインデクス配列を生成
- 17-18行目: 計算処理のために型を `ys` に合わせて変換
- 20-21行目: `xs` と `ys` の先頭と末尾にダミー値を追加
- 24-25行目: 直線の傾き行列を算出し,先頭にダミー値を追加
- 28行目: 直線の切片行列を算出
- 34-36行目: ルックアップテーブルを作成して,全フレームのインデクスに対する直線パラメータを取得
- 39-41行目: 線形補間を実行し,型と形状を入力に合わせて返す

全部位の補間処理

全部位の追跡処理は次の関数で実装してます.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
def partsbased_interp_jax(trackdata):
    pose = trackdata[:, :33]
    lhand = trackdata[:, 33:33+21]
    rhand = trackdata[:, 33+21:33+21+21]
    face = trackdata[:, 33+21+21:]

    pose = matrix_interp_jax(pose)
    lhand = matrix_interp_jax(lhand)
    rhand = matrix_interp_jax(rhand)
    face = matrix_interp_jax(face)
    return jnp.concatenate([pose, lhand, rhand, face], axis=1)
【コード解説】
- 引数:
  - trackdata: `[T, J, C]` 形状の追跡点配列.欠損値は $0$ 埋めされている必要があります.
    こちらの追跡点には全部位のデータが含まれていることを想定しています.
- 2-5行目: 追跡点を部位毎に分割
- 7-10行目: 部位毎に線形補間を実行
- 11行目: 部位毎の追跡点を結合して返す

補間関数を用いた補間処理の実行

線形補間に必要な処理が実装できましたので,上でダウンロードしたデータを用いて処理を実行します.

次のコードでは,追跡点をロードして,関数の仕様に合わせて形状を変更しています.
追跡点データは[P, T, J, C]形状 (Pは人物インデクス) の配列ですが,人物は1名だけですのでP軸は除去しています.

1
2
3
4
5
trackdata = np.load("finger_far0_non_static.npy")
reftrack = np.load("finger_far0_non_static_interp.npy")
# Remove person axis.
trackdata = trackdata[0]
reftrack = reftrack[0]

次のコードは先程実装したsimple_interp_jaxtrackdataを入力して,補間後の追跡点newtrackを得ています.
1回目の処理呼び出しでは,処理のコンパイルが行われるため個別に時間を計測し,参照用の補間後追跡点との誤差を計測しています.
その後,処理時間計測クラスを用いて平均処理時間を表示しています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# JNP function-based.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = simple_interp_jax(trackdata)
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack - newtrack).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(simple_interp_jax, trackdata=trackdata)
pmeasure(target_fn)

print 処理の結果を示します.
1回目の処理では3.2秒程度かかっていますが,2回目以降の処理では894.8ミリ秒程度で処理が完了しています.
この処理は2回目以降もかなり時間がかかっており,現在の (強引な) 実装はJAXの想定状況にマッチしていないようです(^^;).

Time of first call.
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Overall summary: Max 3.16399s, Min 3.16399s, Mean +/- Std 3.16399s +/-   0s
Sum of error:-6.195044477408373e-13
Time after second call.
Overall summary: Max 3.83771s, Min 884.631ms, Mean +/- Std 1.11187s +/- 372.653ms
Top 10 summary: Max 900.855ms, Min 884.631ms, Mean +/- Std 894.798ms +/- 5.34435ms

次に,入力データの形状が変わった場合の挙動を見てみます.
次のコードは上の線形補間の実行処理とほぼ同じですが,入力データの時間長が1フレーム分短くなっています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# JNP function-based.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = simple_interp_jax(trackdata[:-1])
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack[:-1] - newtrack).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(simple_interp_jax, trackdata=trackdata[:-1])
pmeasure(target_fn)

print 処理の結果を示します.
1回目は1.97秒程度,2回目以降は900.9ミリ秒程度となり,JAXでは処理の再トレーシングが起きているようです.

Time of first call.
Overall summary: Max 1.97251s, Min 1.97251s, Mean +/- Std 1.97251s +/-   0s
Sum of error:-6.195044477408373e-13
Time after second call.
Overall summary: Max 1.35904s, Min 882.56ms, Mean +/- Std 993.649ms +/- 128.017ms
Top 10 summary: Max 906.912ms, Min 882.56ms, Mean +/- Std 900.938ms +/- 6.57869ms

行列計算を用いた補間処理の実行

次のコードは先程実装したpartsbased_interp_jaxtrackdataを入力して,補間後の追跡点newtrackを得ています.
補間関数を用いた場合と同様に,1回目の処理呼び出しと2回目以降の処理呼び出しを分けて処理時間を算出しています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# Matrix-based.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = partsbased_interp_jax(trackdata)
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack - newtrack).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(partsbased_interp_jax, trackdata=trackdata)
pmeasure(target_fn)

print 処理の結果を示します.
1回目の処理では660.3ミリ秒程度かかっていますが,2回目以降の処理では9.8ミリ秒程度で処理が完了しています.

Time of first call.
Overall summary: Max 660.329ms, Min 660.329ms, Mean +/- Std 660.329ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 18.4818ms, Min 9.62976ms, Mean +/- Std 11.2587ms +/- 1.82495ms
Top 10 summary: Max 9.92951ms, Min 9.62976ms, Mean +/- Std 9.77754ms +/- 98.2471µs

補間関数を用いた場合と同様に,データ形状が変わった場合の挙動を見てみます.
次のコードは上の線形補間の実行処理とほぼ同じですが,入力データの時間長が1フレーム分短くなっています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# Matrix-based.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = partsbased_interp_jax(trackdata[:-1])
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack[:-1] - newtrack).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(partsbased_interp_jax, trackdata=trackdata[:-1])
pmeasure(target_fn)

print 処理の結果を示します.
1回目は596.7ミリ秒程度,2回目以降は9.6ミリ秒程度となっています.
JAXでは処理の再トレーシングが起きていることが改めて観測できました.

Time of first call.
Overall summary: Max 596.675ms, Min 596.675ms, Mean +/- Std 596.675ms +/-   0s
Sum of error:-6.935119145623503e-12
Time after second call.
Overall summary: Max 15.9955ms, Min 9.28973ms, Mean +/- Std 10.7838ms +/- 1.23746ms
Top 10 summary: Max 9.70015ms, Min 9.28973ms, Mean +/- Std 9.60985ms +/- 124.862µs

5.2 Define-and-Run (JIT compile) を用いた場合

ここから先は,Define-and-Run に基づいて計算グラフを作成した場合の挙動について見ていきます.

JAXでは,関数を @jax.jit デコレータで修飾することで Define-and-Run に基づいて計算グラフを作成することができます.

ここでいくつか注意点があります.
まず,基本的にはJITコンパイル可能な関数の入力は固定長でなければなりません.
今回のように入力データが可変長で,かつ,データの値に応じて内部で生成する配列形状が変わるようなケースでは @partial(jax.jit, static_argnums=(0,)) として入力 (ここでは0番目の入力) を例外的に取り扱うように指定する必要があります.

また,static_argnums で指定した入力は "Hashable" でなければいけないという制約があります.
Pythonにおいてオブジェクトが "Hashable" であるとは,オブジェクト内部に Hash値 と呼ぶ識別値を持っており,この値を基に同一オブジェクトかどうかを判定できることを言います.
numpy.ndarrayjax.numpy.Array はこの制約を満たしておらず,JITコンパイルされた関数に入力する場合は "Hashable" なラッパークラスを用いて回避する必要があります.

jax.jit についての詳細は公式のAPIドキュメントをご参照ください.
また,"Hashable"についての詳細は公式の用語説明をご参照ください.

ラッパークラスの実装

では,実装を行っていきます.
まず,次のコードで Define-and-Run で動かすために必要なモジュールを追加でインポートしています.
GenericTypeVar は "Hashable" なラッパークラスの実装で使用します.
partial は上で述べたように,jax.jit デコレータに static_argument を設定するために用います.

1
2
from typing import Generic, TypeVar
from functools import partial

次のコードでは,任意のオブジェクトを "Hashable" にするためのラッパークラスを実装しています.
こちらのコードは公式GitHubのIssueから拝借しました.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
T = TypeVar('T')      # Declare type variable

# Workaround to avoid unhashable error.
# https://github.com/google/jax/issues/4572
class HashableArrayWrapper(Generic[T]):
    def __init__(self, val: T):
        self.val = val

    def __getattribute__(self, prop):
        if prop == 'val' or prop == "__hash__" or prop == "__eq__":
            return super(HashableArrayWrapper, self).__getattribute__(prop)
        return getattr(self.val, prop)

    def __getitem__(self, key):
        return self.val[key]

    def __setitem__(self, key, val):
        self.val[key] = val

    def __hash__(self):
        return hash(self.val.tobytes())

    def __eq__(self, other):
        if isinstance(other, HashableArrayWrapper):
            return self.__hash__() == other.__hash__()

        f = getattr(self.val, "__eq__")
        return f(self, other)

補間関数を用いた実装

次のコードでは,simple_interp_jax を Define-and-Run で動作するようにしています.
中身のコードは完全に同じなのでここでは説明を省かさせていただきます.

1
2
3
@partial(jit, static_argnums=(0,))
def simple_interp_jax_jit(trackdata):
    ...

行列計算を用いた実装

次のコードでは,patsbased_interp_jax と Define-and-Run で動作するようにしています.
中身のコードは完全に同じなのでここでは説明を省かさせていただきます.

1
2
3
4
5
6
def matrix_interp_jax_jit(track):
    ...

@partial(jit, static_argnums=(0,))
def partsbased_interp_jax_jit(trackdata):
    ...

補間関数を用いた補間処理の実行

では,まずは補間関数を用いた処理を実行していきます.
処理内容は今までとほぼ同じですが,trackdataHashableArrayWrapper でラップしている点に注意してください.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# JNP function-based.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = simple_interp_jax_jit(HashableArrayWrapper(trackdata))
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack - newtrack).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(simple_interp_jax_jit, trackdata=HashableArrayWrapper(trackdata))
pmeasure(target_fn)

print 処理の結果を示します.
1回目の処理では (なんと!) 31.1秒程度かかっていますが,2回目以降の処理では2.94ミリ秒程度で処理が完了しています.
ただしこの処理は2回目以降も27.2秒程度かかるケースが計測されており,やはりJAXの想定状況にマッチしていないようです.

Time of first call.
Overall summary: Max 31.0618s, Min 31.0618s, Mean +/- Std 31.0618s +/-   0s
Sum of error:-6.195044477408373e-13
Time after second call.
Overall summary: Max 27.1741s, Min 2.90357ms, Mean +/- Std 274.897ms +/- 2.70347s
Top 10 summary: Max 2.98753ms, Min 2.90357ms, Mean +/- Std 2.94133ms +/- 32.6188µs

次に,データ形状が変わった場合の挙動を見てみます.
次のコードは上の線形補間の実行処理とほぼ同じですが,入力データの時間長が1フレーム分短くなっています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# JNP function-based.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = simple_interp_jax_jit(HashableArrayWrapper(trackdata[:-1]))
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack[:-1] - newtrack).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(simple_interp_jax_jit, trackdata=HashableArrayWrapper(trackdata[:-1]))
pmeasure(target_fn)

print 処理の結果を示します.
やはり再トレーシングが起きてしまうようで,1回目は 26.0秒程度,2回目以降は3.0ミリ秒程度という結果になりました.

Time of first call.
Overall summary: Max 26.039s, Min 26.039s, Mean +/- Std 26.039s +/-   0s
Sum of error:-6.195044477408373e-13
Time after second call.
Overall summary: Max 25.4215s, Min 2.93153ms, Mean +/- Std 257.416ms +/- 2.52909s
Top 10 summary: Max 3.00317ms, Min 2.93153ms, Mean +/- Std 2.97888ms +/- 21.8667µs

行列計算を用いた補間処理の実行

次に,行列計算を用いた補間処理を実行します.
補間関数を用いた場合と同様に,1回目の処理呼び出しと2回目以降の処理呼び出しを分けて処理時間を算出しています.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# Matrix-based.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = partsbased_interp_jax_jit(HashableArrayWrapper(trackdata))
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack - newtrack).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(partsbased_interp_jax_jit, trackdata=HashableArrayWrapper(trackdata))
pmeasure(target_fn)

print 処理の結果を示します.
こちらは計算グラフのコンパイル時間はあまり変わらず,1回目は361.7ミリ秒程度,2回目以降は2.0ミリ秒程度という結果になりました.

Time of first call.
Overall summary: Max 361.726ms, Min 361.726ms, Mean +/- Std 361.726ms +/-   0s
Sum of error:-2.2037927038809357e-13
Time after second call.
Overall summary: Max 349.023ms, Min 2.00301ms, Mean +/- Std 5.59448ms +/- 34.5168ms
Top 10 summary: Max 2.03389ms, Min 2.00301ms, Mean +/- Std 2.0225ms +/- 9.97664µs

次に,データ形状が変わった場合の挙動を見てみます.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
# Matrix-based.
# The 1st call may be slow because of the computation graph construction.
print(f"Time of first call.")
start = time.time()
newtrack = partsbased_interp_jax_jit(HashableArrayWrapper(trackdata[:-1]))
interval = time.time() - start
print_perf_time(np.array([interval]))

diff = (reftrack[:-1] - newtrack).sum()
print(f"Sum of error:{diff}")

print("Time after second call.")
target_fn = partial(partsbased_interp_jax_jit, trackdata=HashableArrayWrapper(trackdata[:-1]))
pmeasure(target_fn)

print 処理の結果を示します.
1回目は345.8ミリ秒程度,2回目以降は2.1ミリ秒程度という結果になりました.

Time of first call.
Overall summary: Max 345.807ms, Min 345.807ms, Mean +/- Std 345.807ms +/-   0s
Sum of error:-2.2037927038809357e-13
Time after second call.
Overall summary: Max 333.857ms, Min 2.06338ms, Mean +/- Std 5.57588ms +/- 32.9948ms
Top 10 summary: Max 2.12235ms, Min 2.06338ms, Mean +/- Std 2.10369ms +/- 18.2378µs

今回はJAXを用いた線形補間処理の実装を試みましたが,如何でしたでしょうか?
現在のJAXは動的な入力や,データの形状や中身に依存して処理ロジックを切り替えることが難しいです.
今回お見せしたように強引に動かすことは可能ですが,性能は活かすことはできませんでした.
現在のJAXを使用する場合は,制約を理解して適切な使い所を把握して利用することが重要だと感じました.

私もJAXを触ったのは今回が初めてで,まだ色々と勉強中です.
勘違いや誤りがあった場合はご指摘いただけましたら助かります.

今回紹介した話が,補間処理やJAXの取り扱いなどでお悩みの方に何か参考になれば幸いです.