目次
こんにちは.高山です.
今回は Neural network (NN) モデルを実装する際の小技を紹介したいと思います.
具体的には,Pydantic というライブラリを用いて NNモデルのハイパーパラメータをまとめることで,NNモデルのインタフェースを簡潔にする方法を紹介します.
更新履歴 (大きな変更のみ記載しています)
- 2024/10/19: 第4節 - 辞書からクラスのインスタンスへにおいて,
model_validate()
にインスタンスを渡した場合は,model_post_init()
が呼ばれないことを注記しました.
1. やりたいこと
NNモデルを実装する際は,下記のようにハイパーパラメータを引数で渡して実装することが多いと思います.
1 2 3 4 |
|
この方法はシンプルですし,色々な事情 (ドキュメンテーションなど) から上記の実装形態を選ばざるを得ない場合もあります.
ただしこの実装方法の場合,NNモデルが複雑になってくるにつれて,下記のように大量のハイパーパラメータが引数に並ぶようになります.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
|
どうでしょう? ゲンナリしてきませんか? (^^;)
上記の状況を説明すると図1のようになります.
NNモデルは内部に様々なレイヤや処理ブロックを持ちます.
処理ブロックは特定の機能や処理でレイヤ郡をまとめたかたまりを指します.
レイヤや処理ブロックはそれぞれ設定値を持つため,呼び出し側から制御するためには親モデルの引数を介して値を渡す必要があります.
結果として,親モデルの引数は子レイヤのパラメータを全て並べたような形になります.
このような実装を避けるための方法の一つとして,図2に示すようにパラメータ郡をクラス化する方法があります.
図2の左側はモデルのパラメータを変数として保持するクラスです.
このクラスはデータを整理して持つことが主な役割で,計算処理などは基本的に行いません.
図2の右側はモデル本体で,引数としてパラメータのクラスを受け取り自身をインスタンス化します.
このように設計をすると,次のコードのようにモデルの引数を簡潔にすることができます.
1 2 3 4 5 |
|
このような設計を実装する方法は,辞書型の変数,通常のクラス,Dataclass
などいくつかありますが,今回は Pydantic というライブラリを用いた実装方法を紹介したいと思います.
2. Pydanticとは?
Pydantic はデータ検証ライブラリ (Data validation library) の一つです.
データ検証とは,データの型や値の範囲が事前に定義した仕様を満たしているかをチェックすることを意味します.
Pydantic ではクラスを定義する際に下記のように型や値の範囲を定義することで,インスタンス化時に自動的に値をチェックすることができます.
1 2 3 4 5 6 7 8 9 |
|
BaseModel
は Pydantic のデータ検証機能を備えた基底クラスです.
自分が実装するクラスで BaseModel
を継承することで,データ検証機能が働くようになります.
本記事では基本的な使い方だけを紹介します.
全機能については,公式サイトの Models の項目をご参照ください.
型アノテーションによる型チェック
in_channels: int
のように,変数名右側の指定は "型アノテーション" と呼びます.
ここでは変数 in_channels
が int
型であることを示しています.
型アノテーション自体はコードを分かりやすくするための記述法なのですが,Pydantic では型アノテーションを活用して入力値をチェックすることができます.
まずは普通にクラスをインスタンス化してみます.
1 2 |
|
in_channels=64 out_channels=128 activation='relu'
問題無く動作しますね.
数値を文字列で与えた場合はどうでしょうか.
1 2 3 4 5 |
|
in_channels=64 out_channels=128 activation='relu'
文字列が数値に変換可能な場合は,Pydantic が自動で変換してくれます.
便利ですね(^^).
では,数値じゃない文字列を与えた場合はどうでしょう.
1 2 3 4 5 |
|
1 validation error for ModelParams
out_channels
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='abc', input_type=str]
For further information visit https://errors.pydantic.dev/2.9/v/int_parsing
in_channels=64 out_channels=128 activation='relu'
この場合は,out_channels
が数値に変換できないためエラーが発生します.
この機能によって,不正な入力値によるバグを早期に発見することができます.
Fieldによる値チェック
Pydantic の Field
という機能を使うと,より細かな制御ができるようになります.
ここでは値の範囲を制限する方法を紹介します.
全機能については,公式サイトの Field の項目をご参照ください.
変数 activation
については,下記のように Field
を定義しています.
activation: str = Field(default="relu", pattern=r"relu|tanh|sigmoid")
default
はデフォルト値を示し,pattern
は入力可能な文字列を正規表現で指定します.
上記の例は,activation
のデフォルト値が relu
で,入力可能な値は relu
, tanh
, sigmoid
のいずれかであることを意味します.
試しに範囲外の値を入力してみましょう.
1 2 3 4 5 |
|
1 validation error for ModelParams
activation
String should match pattern 'relu|tanh|sigmoid' [type=string_pattern_mismatch, input_value='silu', input_type=str]
For further information visit https://errors.pydantic.dev/2.9/v/string_pattern_mismatch
in_channels=64 out_channels=128 activation='relu'
silu
は入力可能な値ではないので,エラーが発生します.
この機能によって,下記のようにモデル内で細かく値のチェックをする必要が無くなります.
1 2 3 4 5 6 7 8 |
|
3. Pydantic モデルを用いてモデルをリファクタ
Pydantic を利用してモデルをリファクタしてみます.
ここでは,手話入門記事 第二回で説明した,シンプルな孤立手話単語認識を例として用います.
このモデルは Linear層,活性化関数,および Global average pooling層からなっており,元々のコードは下記のように実装されていました.
1 2 3 4 5 6 7 8 9 10 |
|
1 2 3 4 5 6 7 8 |
|
ここではインスタンス化に関係する部分だけを記載しています.
3.1 GPoolRecognitionHeadのリファクタリング
ではリファクタリングをしてみます.
まず,下記のコードで GPoolRecognitionHead
をリファクタリングします.
1 2 3 4 5 6 |
|
1 2 3 4 5 6 7 8 9 |
|
ここはそこまで難しくなく,設定値をそのまま与えてインスタンス化しているだけです.
GPoolRecognitionHeadSettings
の build_layer()
で,設定値からレイヤのインスタンス化ができるようにしています.
(インスタンス化する際に一々 GPoolRecognitionHead(settings)
とやるのが面倒なので (^^;))
3.2 SimpleISLRのリファクタリング
次に,下記のコードで SimpleISLR
用のパラメータクラスを実装します.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
|
先ほど実装した GPoolRecognitionHeadSettings
を変数として定義することで,パラメータを入れ子にすることができます.
パラメータ用のクラスを用いると,SimpleISLR
は下記のように実装できます.
1 2 3 4 5 6 7 8 9 10 11 12 |
|
元々のモデルがシンプルなので効果があまり実感できないかもしれないですね (^^;).
より複雑なモデルなってくるとかなり簡潔な印象を持つと思います.
default_factoryについて
クラスのように複雑な変数に対して初期値を与える場合は,default_factory
に初期化関数を定義します.
少し細かな話ですが,default_factory
に入力する値は Callable
オブジェクトでないといけません.
これは,
settings = GPoolRecognitionHeadSettings()
とした場合に,関数のように settings()
で呼び出して何かの処理を行えるオブジェクトであることを意味します.
クラスの場合は,__call__()
メソッドを実装するとインスタンスが Callable
オブジェクトになります.
各クラスに __call__()
メソッドを実装するのは面倒なので,ここでは lambda
というその場で関数を定義する (無名関数と言います) 機能を使っています.
上の lambda: GPoolRecognithonHeadSettings()
は下記の関数と同じような働きをします.
def func():
return GPoolRecognitionHeadSettings()
model_post_initについて
SimpleISLR
では,Linear
の出力チャネルと,GPoolRecognitionHead
の入力チャネルが連動しています.
もちろん,インスタンス化時に注意深く値を与えればよいのですが,model_post_init()
関数を利用すると連動した値を制御しやすくなります.
model_post_init()
は Pydantic の BaseModel
で定義されているメソッドで,この関数はインスタンス化後に行う追加処理を実装するために用意されています.
ここでは,
def model_post_init(self, __context):
self.head_settings.in_channels = self.inter_channels
self.head_settings.out_channels = self.out_channels
# Propagate.
self.head_settings.model_post_init(__context)
のようにして,head_settings
に冗長な値を入力しなくても自動的に値が決まるようにしています.
head_settings.model_post_init()
は今回の場合は省略可能です
パラメータが複雑になってくると子パラメータ内でも連動するようなパラメータが出てくるため,実装スタイルとして必ず呼び出すようにしています.
なお,__context
は BaseModel
側で定義されているので,未使用でも引数に入れなくてはいけません.
(公式サイトでもほとんど説明がありませんが,with
文や decorator 経由でインスタンス化する際に使用する値のようです)
4. 辞書型との相互変換
Pydantic の BaseModel
には便利な機能が多数あります.
ここでは辞書型との相互変換機能を紹介します.
この機能は,学習に使用したパラメータをファイルに保存する場合や,逆にファイルからパラメータを読み込んで実験条件を再現する場合などに使えます.
では,実際に試してみましょう.
先ほど実装した SimpleISLRSettings
をインスタンス化してみます.
params = SimpleISLRSettings(in_channels=64, inter_channels=128, out_channels=100)
print(params)
SimpleISLRSettings(in_channels=64, inter_channels=128, out_channels=100,
head_settings=GPoolRecognitionHeadSettings(in_channels=128, out_channels=100))
クラスのインスタンスから辞書へ
Pydantic のクラスのインスタンスから辞書への変換は model_dump()
メソッドを呼び出すことで行えます.
dict_params = params.model_dump()
print(type(dict_params))
print(dict_params)
<class 'dict'>
{'in_channels': 64, 'inter_channels': 128, 'out_channels': 100, 'head_settings': {'in_channels': 128, 'out_channels': 100}}
辞書からクラスのインスタンスへ
辞書からクラスのインスタンスへの変換は,クラスメソッド model_validate()
を呼び出すことで行えます.
動きが分かりやすいように,まず値の一部を変更してみます.
dict_params["out_channels"] = 200
print(dict_params)
{'in_channels': 64, 'inter_channels': 128, 'out_channels': 200,
'head_settings': {'in_channels': 128, 'out_channels': 100}}
辞書データをいじっただけなので,head_settings.out_channels
は古い値のままである点に注意してください.
では変換してみます.
update_params = SimpleISLRSettings.model_validate(dict_params)
print(type(update_params))
print(update_params)
<class '__main__.SimpleISLRSettings'>
in_channels=64 inter_channels=128 out_channels=200
head_settings=GPoolRecognitionHeadSettings(in_channels=128, out_channels=200)
model_validate()
はデータを検証した上でインスタンスを返します.
内部で model_post_init()
が呼び出されるので,head_settings.out_channels
の値が上書きされています.
なお,model_validate()
にインスタンスを渡した場合は model_post_init()
は呼び出されないようなので,注意してください.
今回は,Pydantic を用いて NNモデルのハイパーパラメータをまとめる方法を紹介しましたが,如何でしょうか?
Pydantic 自体は深層学習とは関係が無いので,色々なタスクで利用できます.
また,便利なライブラリやフレームワークは他にもありますのでいずれ紹介できたらいいなと思っています.
今回紹介した話が,同じようなことで悩んでいる方に何か参考になれば幸いです.