PyTorch TIMMでモデル生成・一覧取得の方法(create_modelチートシート)
この記事では、PyTorch Image Models(TIMM)のcreate_model関数を使ってモデルを生成する手順を、チートシート形式でわかりやすくまとめています。また、TIMMで利用可能なモデルの一覧を取得する方法も詳しく解説します。TIMMは多彩な画像認識モデルを簡単に扱える便利なライブラリなので、その活用法を理解しておくことをお勧めします。
はじめに
この記事は、PyTorch Image Models(TIMM)のcreate_model
関数に関する簡易リファレンスとして、利用可能なパラメータや使い方をまとめたものになります。
create_model
を使えば、画像認識のバックボーンモデルを簡単に作成し、プロジェクトに組み込むことができます。
「とりあえずcreate_model
でモデルを作成し、最終段を変更する」といったように、自分でモデルを加工してカスタマイズすることも可能です。
しかし、create_model
関数には便利なパラメータが数多く用意されており、モデルを自由に修正できるため、自身によるカスタマイズを最小限に抑えることが可能です。
検索した範囲では、create_model
のパラメータについて詳しく書かれているサイトは少な買ったので、その情報をチートシート(早見表)形式で整理しました。ぜひ、活用してください。
ここに書かれた以外もいろいろな機能・使い方があリます。とりあえず、自分が使うパターンのみ列挙しました。
インストールとインポート
インストール
インストールは以下のコマンドでpip
/pip3
でインストール可能です。
pip install timm
インポート
timmのインポートは以下の通りです
import timm
create_model
チートシート
基本パターン
とりあえず、モデルを生成する場合はこれ。
model = timm.create_model('モデル名')
model = timm.create_model('resnet50')
といったパターンで呼び出します。
学習済みモデルをダウンロード
学習済みのパラメーターが設定された状態でモデルを使いたい場合に設定。基本、pretrained=True
で使う。
model = timm.create_model('モデル名', pretrained = True)
クラス数を変更
num_classes
でモデルの出力のクラス数(N)を変更する。resnet50
などはデフォルトが1000クラスなので、利用に合わせて調整する。
model = timm.create_model('モデル名', num_classes = N)
特徴量を取り出したい
ヘッドの部分を独自のものに変更したい場合は、num_classes=0
にする。
model = timm.create_model('モデル名', num_classes = 0)
resnet50
の場合は、2048の特徴量が出力されるようになる。各モデルの特徴量はnum_features
で確認可能。
Pythonコード
import torch
import timm
model = timm.create_model('resnet50', num_classes = 0)
X = torch.randn((1,3,244,244))
y = model(X)
print(y.shape)
出力
torch.Size([1, 2048])
プーリングされていない特徴量を取り出す
model = timm.create_model('モデル名', num_classes = 0, global_pool= '')
num_classes=0, global_pool=''
で、プーリングされていない特徴量が取り出せる。
Pythonコード
import torch
import timm
model = timm.create_model('resnet50', num_classes = 0, global_pool='')
X = torch.randn((1,3,244,244))
y = model(X)
print(y.shape)
出力
torch.Size([1, 2048, 8, 8])
なし・ありでモデルの最終段部分が変化します
なしの場合
(global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))
(fc): Identity()
ありの場合
(global_pool): SelectAdaptivePool2d (pool_type=, flatten=Identity())
(fc): Identity()
入力チャネル数を変更
in_chans
で入力チャネルをNチャネルに設定。グレー画像を入力する場合とか、他チャンネル画像を入力する場合に指定。なお、元が3チャンネルで、1チャネルの場合は3チャネルを合成したフィルタが、それ以外の場合は、0、1、2、0、1、2・・・各チャネルのフィルタが巡回した形で初期設定される(らしい。未確認です)
model = timm.create_model('モデル名', in_chans = N)
生成したモデルの情報が知りたい
モデルの情報を表示する。num_featurs
はヘッド手前の特徴量が格納されている。おそらく最も参照する情報。feature_info
は特徴マップ関連の情報。default_cfg
は各種情報。
model = timm.create_model('モデル名')
print(model.num_features)
print(model.feature_info)
print(model.default_cfg)
Pythonコード
import torch
import timm
model = timm.create_model('resnet50', num_classes = 0, global_pool='')
print(model.num_features)
print("-"*10)
print(model.feature_info)
print("-"*10)
print(model.default_cfg)
出力
2048
----------
[{'num_chs': 64, 'reduction': 2, 'module': 'act1'}, {'num_chs': 256, 'reduction': 4, 'module': 'layer1'}, {'num_chs': 512, 'reduction': 8, 'module': 'layer2'}, {'num_chs': 1024, 'reduction': 16, 'module': 'layer3'}, {'num_chs': 2048, 'reduction': 32, 'module': 'layer4'}]
----------
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth', 'hf_hub_id': 'timm/resnet50.a1_in1k', 'architecture': 'resnet50', 'tag': 'a1_in1k', 'custom_load': False, 'input_size': (3, 224, 224), 'test_input_size': (3, 288, 288), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 0.95, 'test_crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'num_classes': 1000, 'pool_size': (7, 7), 'first_conv': 'conv1', 'classifier': 'fc', 'origin_url': 'https://github.com/huggingface/pytorch-image-models', 'paper_ids': 'arXiv:2110.00476'}
resnet
の場合は、default_cfg
にarXivbの番号も入っていました。
各ブロックの特徴マップを取り出したい
feature_only=True
で特徴マップと取得する。U-netとかのバックボーンとして利用する場合に便利。
model = timm.create_model('モデル名', features_only = True)
resnet50
の場合は、5ブロック(層)あるので5つ取得できる。各ブロックがどんなサイズかは、feature_info
で確認可能。
Pythonコード
import torch
import timm
model = timm.create_model('resnet50', features_only=True)
X = torch.randn((1,3,244,244))
y = model(X)
print(len(y))
for e in y:
print(e.shape)
出力
5
torch.Size([1, 64, 122, 122])
torch.Size([1, 256, 61, 61])
torch.Size([1, 512, 31, 31])
torch.Size([1, 1024, 16, 16])
torch.Size([1, 2048, 8, 8])
特徴マップを取り出すブロックを指定したい
out_indices
で指定したブロックの情報だけ取り出す。
model = timm.create_model('モデル名', features_only = True, out_indices = [a,b,...])
Pythonコード
import torch
import timm
model = timm.create_model('resnet50', features_only=True, out_indices=[0,4])
X = torch.randn((1,3,244,244))
y = model(X)
print(len(y))
for e in y:
print(e.shape)
出力
2
torch.Size([1, 64, 122, 122])
torch.Size([1, 2048, 8, 8])
モデル一覧の取得
使えるモデルを知りたい
とりあえず、モデル名に指定できるものを取得したい。
timm.list_models()
トレーニング済みモデルが存在するものだけ知りたい
事前学習済のモデル一覧を取得したい。
timm.list_models(pretrained = True)
自分の使っているtimm 0.9.2ではpretrained=Trueとした方が、表示されるモデルが多いです。
モデルの計算量
モデルの計算量を知りたい場合は、以下の記事を参考にしてください。