機械学習
記事内に商品プロモーションを含む場合があります

PyTorch-TIMMでモデル作成、モデル一覧を取得する方法(create_modelチートシート)

timmのcreate_modelチートシート
tadanori

この記事では、TIMM(PyTorch Image Models)のモデルを生成する方法と、利用できるモデル一覧の調べ方などをまとめました。

実際の使用例につていは、以下の記事を参考にしてください。

【pytoch】 timmでクラス分類(犬猫分類)にトライ
【pytoch】 timmでクラス分類(犬猫分類)にトライ

はじめに

timm(PyTorch Image Models)のcreate_modelのチートシートです。このライブラリはかなり便利で、いろいろなパターンでバックボーンとなるモデルを作ることができます。

結構な頻度で使い方を検索するのですが、まとまって書かれたサイトを見つけられなかったので、チートシート(早見表)を作っておきます。

ここに書かれた以外もいろいろな機能・使い方があリます。とりあえず、自分が使うパターンのみ列挙しました。

インストールとインポート

インストール

インストールは以下のコマンドでpip/pip3でインストール可能です。

pip install timm

インポート

timmのインポートは以下の通りです

import timm

create_modelチートシート

基本パターン

とりあえず、モデルを生成する場合はこれ。

model = timm.create_model('モデル名')
model = timm.create_model('resnet50')

といったパターンで呼び出します。

学習済みモデルをダウンロード

学習済みのパラメーターが設定された状態でモデルを使いたい場合に設定。基本、pretrained=Trueで使う。

model = timm.create_model('モデル名', pretrained = True)

クラス数を変更

num_classesでモデルの出力のクラス数(N)を変更する。resnet50などはデフォルトが1000クラスなので、利用に合わせて調整する。

model = timm.create_model('モデル名', num_classes = N)

特徴量を取り出したい

ヘッドの部分を独自のものに変更したい場合は、num_classes=0にする。

model = timm.create_model('モデル名', num_classes = 0)

resnet50の場合は、2048の特徴量が出力されるようになる。各モデルの特徴量はnum_featuresで確認可能。

コード:
import torch
import timm
model = timm.create_model('resnet50', num_classes = 0)
X = torch.randn((1,3,244,244))
y = model(X)
print(y.shape)

出力
torch.Size([1, 2048])

プーリングされていない特徴量を取り出す

model = timm.create_model('モデル名', num_classes = 0, global_pool= '')

num_classes=0, global_pool=''で、プーリングされていない特徴量が取り出せる。

コード:
import torch
import timm
model = timm.create_model('resnet50', num_classes = 0, global_pool='')
X = torch.randn((1,3,244,244))
y = model(X)
print(y.shape)

出力
torch.Size([1, 2048, 8, 8])
global_pool=” あり・なしでの変化

なし・ありでモデルの最終段部分が変化します

なしの場合

(global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))
(fc): Identity()

ありの場合

(global_pool): SelectAdaptivePool2d (pool_type=, flatten=Identity())
(fc): Identity()

入力チャネル数を変更

in_chansで入力チャネルをNチャネルに設定。グレー画像を入力する場合とか、他チャンネル画像を入力する場合に指定。なお、元が3チャンネルで、1チャネルの場合は3チャネルを合成したフィルタが、それ以外の場合は、0、1、2、0、1、2・・・各チャネルのフィルタが巡回した形で初期設定される(らしい。未確認です)

model = timm.create_model('モデル名', in_chans = N)

生成したモデルの情報が知りたい

モデルの情報を表示する。num_featursはヘッド手前の特徴量が格納されている。おそらく最も参照する情報。feature_infoは特徴マップ関連の情報。default_cfgは各種情報。

model = timm.create_model('モデル名')
print(model.num_features)
print(model.feature_info)
print(model.default_cfg)
コード:
import torch
import timm
model = timm.create_model('resnet50', num_classes = 0, global_pool='')
print(model.num_features)
print("-"*10)
print(model.feature_info)
print("-"*10)
print(model.default_cfg)

出力
2048
----------
[{'num_chs': 64, 'reduction': 2, 'module': 'act1'}, {'num_chs': 256, 'reduction': 4, 'module': 'layer1'}, {'num_chs': 512, 'reduction': 8, 'module': 'layer2'}, {'num_chs': 1024, 'reduction': 16, 'module': 'layer3'}, {'num_chs': 2048, 'reduction': 32, 'module': 'layer4'}]
----------
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth', 'hf_hub_id': 'timm/resnet50.a1_in1k', 'architecture': 'resnet50', 'tag': 'a1_in1k', 'custom_load': False, 'input_size': (3, 224, 224), 'test_input_size': (3, 288, 288), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 0.95, 'test_crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'num_classes': 1000, 'pool_size': (7, 7), 'first_conv': 'conv1', 'classifier': 'fc', 'origin_url': 'https://github.com/huggingface/pytorch-image-models', 'paper_ids': 'arXiv:2110.00476'}

resnetの場合は、default_cfgにarXivbの番号も入っていました。

各ブロックの特徴マップを取り出したい

feature_only=Trueで特徴マップと取得する。U-netとかのバックボーンとして利用する場合に便利。

model = timm.create_model('モデル名', features_only = True)

resnet50の場合は、5ブロック(層)あるので5つ取得できる。各ブロックがどんなサイズかは、feature_infoで確認可能。

コード:
import torch
import timm
model = timm.create_model('resnet50', features_only=True)
X = torch.randn((1,3,244,244))
y = model(X)
print(len(y))
for e in y:
  print(e.shape)

出力
5
torch.Size([1, 64, 122, 122])
torch.Size([1, 256, 61, 61])
torch.Size([1, 512, 31, 31])
torch.Size([1, 1024, 16, 16])
torch.Size([1, 2048, 8, 8])

特徴マップを取り出すブロックを指定したい

out_indicesで指定したブロックの情報だけ取り出す。

model = timm.create_model('モデル名', features_only = True, out_indices = [a,b,...])
コード:
import torch
import timm
model = timm.create_model('resnet50', features_only=True, out_indices=[0,4])
X = torch.randn((1,3,244,244))
y = model(X)
print(len(y))
for e in y:
  print(e.shape)

出力
2
torch.Size([1, 64, 122, 122])
torch.Size([1, 2048, 8, 8])

モデル一覧の取得

使えるモデルを知りたい

とりあえず、モデル名に指定できるものを取得したい。

timm.list_models()

トレーニング済みモデルが存在するものだけ知りたい

事前学習済のモデル一覧を取得したい。

timm.list_models(pretrained = True)

自分の使っているtimm 0.9.2ではpretrained=Trueとした方が、表示されるモデルが多いです。

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました