PyTorch Lightningを使ったクラス分類(犬猫分類)を実践
この記事では、前の記事で作成したクラス分類(犬猫分類)をPyTorch Lightningを使って実装し直してみます。PyTorch Lightningを使うと、シンプルな記述で学習コードを書くことが可能です。PyTorch Lightningに慣れておくと後々楽だと思いますのでPyTorch Lightningによる学習コードの実装に慣れておくと良いです。
はじめに
以下の記事では、画像をダウンロードし、モデルの作成と学習までの手順について解説しました。ここでは、Pytorch Lightningを使ってこれを書き直してみます。
PyTorch Lightningとは、PyTorchの上に構築されたオープンソースの軽量フレームワークで、これを使うことでモデルのトレーニングを簡潔に記述ですることができ、コードの可読性と再利用性を高めることができる優れ物です。
コードはGoogle Colabotoryで実行できる形式としています。githubにコードを置いていますので、そちらも参考にしてください。
サンプルコード(Pytorch Lightning)へのリンク(github)
サンプルコードはGPU必須ではありませんが、GPUありで処理することをおすすめします。
Google Colabを使う場合は、ランタイム→ランタイムタイプの変更で、ハードウェアのアクセラレータにGPUを設定するのを忘れないように。GPUを利用しないとかなり実行時間がかかります。
分類タスクのコード実装
実装する分類タスクは、前回と同様「犬」と「猫」の2クラス分類です。同じものをやる方が比較しやすいと思いますので、前回と同じ題材にしました。
ここでは、処理の流れに従って説明して行きたいと思います。なお、実行環境はGoogle Colabotoryです。
ライブラリのインストール
timmとicrawlerはインストールされていないので、これをインストールします。Colabの場合は、以下のコードをコードに書きます(”!
”がついている点に注意)
!pip install timm
!pip install icrawler
!!pip install pytorch-lightning
これで、timmとicrawler、pytorch-lightningがインストールされます。
とりあえず、使いそうなライブラリをインポートしておきます。今回は、ここでインポートしていないライブラリで必要な都度インポートしてますが、最初にまとめてインポートすることをおすすめします。
import torch
import timm
import numpy as np
from icrawler.builtin import GoogleImageCrawler
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
今回は、sklearn関係のインポートもここに移動させました。
画像ダウンロード
画像のダウンロードは、前回と同じです。前回の記事を読まれている場合は、読み飛ばしてください。
まず、必要な画像データセットをダウンロードします。画像のダウンロードにはicrawlerを使います。詳しくは以下の記事を参照してください。
まず、犬の画像をGoogleから取得し、images/dog
フォルダに格納します。100枚指定しますが、取得エラーなどで100枚より少なくなることがあるので注意してください。
google_crawler = GoogleImageCrawler(
storage={'root_dir': 'images/dog'})
google_crawler.crawl(keyword='dog', max_num=100)
同様にして、猫の画像をimages/cat
フォルダに格納します。こちらも数が少ない可能性があります。
google_crawler = GoogleImageCrawler(
storage={'root_dir': 'images/cat'})
google_crawler.crawl(keyword='cat', max_num=100)
次に、取得した画像を訓練用(train)と、検証用(valid)に分割します。以下では、コマンドを実行して分割をしてます。以下のコードを実行すると、1〜9枚めの画像が検証用になり、それ以外が訓練用に割り振られます。
!mkdir images/train images/valid images/train/cat images/train/dog images/valid/cat images/valid/dog
!mv images/cat/00000?.jpg images/valid/cat
!mv images/dog/00000?.jpg images/valid/dog
!mv images/cat images/train
!mv images/dog images/train
以上の操作で、訓練画像がimages/trainに、評価画像がimages/validに格納されます。
データセットの作成(dataset/dataloader)
画像を準備してしまえば、あとは定形処理です。まずデータセットを作成します。次にデータローダーを作成します。クラス分離の場合は、上記のようなフォルダ構成にしておけば、timmを使って簡単に記述することができます。
まず、必要なライブラリをインポートします。今回は、create_transform
を使ってリサイズをしています。他にもいろいろな処理が可能ですので、興味があったら調べてみてください(参考リンク)。
本来はデータ拡張(変形・色調変更、ランダムな切り出しなど)も行うのですが、今回はリサイズだけ利用しています。timmでは、基本的なデータ拡張をcreate_transform
を使って簡単に実装できます。
from timm.data import create_dataset, create_loader
from timm.data.transforms_factory import create_transform
まず、データセットを作成します。引数は、データセットの名前、rootは画像のルートフォルダ、class_mapはクラス名→番号への変換辞書です。今回は、訓練用データを./images/train
に、評価用データを./images/valid
に格納しているので、root
にはそれを指定します。また、class_map
にはdog=0, cat=1
への変換を指定します。class_map
を指定することで、それぞれのフォルダにクラス番号が割り振られます。
また、ここで画像サイズの変換等も行っています。transform=create_transform(224)
がその設定で、画像はリサイズ・クリッピングされて224×224のサイズに変換されます。また、出力はtorchのテンソル形式に変換されます。
dataset_train = create_dataset('train', root="./images/train", class_map={'dog':0, 'cat':1}, transform=create_transform(224))
dataset_valid = create_dataset('valid', root="./images/valid", class_map={'dog':0, 'cat':1}, transform=create_transform(224))
以下は、データセットの画像の確認コードサンプルです。このようにして、画像表示とラベル確認ができます。
import matplotlib.pyplot as plt
img, label = dataset_train[0]
plt.imshow(img.permute(1,2,0))
label
次にデータローダの定義です。前回はtimmのcreate_loader
を使って実装しましたが、今回はtorchのDataLoader
を使いました。
create_loader
を利用するとGPUがないときにエラーになりましたので変更しました。
batch_size
は、一度に読み出すサイズ(バッチサイズ)を指定します。shuffle
は、読み出し時に並べ替えるかどうかのフラグで、普通は訓練用データの場合True
に、訓練用データではFalse
にしておきます。
dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True)
dataloader_valid = DataLoader(dataset_valid, batch_size=8, shuffle=False)
以下はデータローダの動作確認コードです、画像(バッチサイズ枚)がXに、ラベルがyに入力されます。ここでは、画像を表示すると大変なので、ラベルy
のみ表示しています。
# 確認
for X, y in dataloader_valid:
print(y)
以上でデータセットの準備は完了です。
パラメータ設定
パラメータ設定をしておきます。今回のコードでは、epoch数と、device
だけです。device
は、create_loaderがGPUが存在する前提となっているので、”cuda“に固定しています。
epoch数はとりあえず、10に設定しました。
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LightningModuleを定義
ここから、Pytorch Lightningらしい部分になります。
まずは、LightningModuleを定義します。基本的には、訓練ループ、検証ループで行っていた処理をそれぞれのメソッドに記述していきます。
なお、Lightningでは「GPUあり・なし」などを自動判定してくれるので、.to(device)
を書く必要はありません。
class myModule(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = timm.create_model('resnet18', pretrained=True, num_classes=2)
self.loss_fn = torch.nn.CrossEntropyLoss()
self.training_step_outputs = []
self.validation_step_outputs = []
def training_step(self, batch, batch_idx):
x, y = batch
pred = self.model(x)
loss = self.loss_fn(pred, y)
self.training_step_outputs.append(loss)
return loss
def on_train_epoch_end(self):
epoch_mean = torch.stack(self.training_step_outputs).mean()
self.log("train_loss", epoch_mean, prog_bar=True)
self.training_step_outputs.clear()
def validation_step(self, batch, batch_idx):
x, y = batch
pred = self.model(x)
loss = self.loss_fn(pred, y)
self.validation_step_outputs.append(loss)
return loss
def on_validation_epoch_end(self):
epoch_mean = torch.stack(self.validation_step_outputs).mean()
self.log("valid_loss", epoch_mean, prog_bar=True)
self.validation_step_outputs.clear()
def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=0.0001)
以下、それぞれのメソッドの役割と、処理内容になります。
__init__
初期化ブロックです、ここではtimmを用いて、resnet18
の学習済みモデルを、出力クラス =2クラスにして読み込んでいます。また、損失関数の定義を行っています。残りの2つの変数は、ログを保存するためのものです。
training_step
訓練データへの1バッチ分の処理を記載します。backwrodとかoptimizerの処理を書かなくて良い部分が異なります。
on_train_epoch_end
訓練の毎EPOCH終了時に呼び出されます。ここでは、training_step
で保存していたlossを集計して表示させています。prog_bar=True
を設定しているのでプログレスバーに表示されます。
validation_step
検証データへの1バッチ分の処理を記載します。
on_validation_epoch_end
検証の毎EPOCH終了時に呼び出されます。こちらも、lossの表示を行っています。
configure_optimizers
最適化アルゴリズムを設定しています。前回と同じくAdamを利用しました。
学習
Lightningえば、学習コードがとても短くなります。CPU/GPUへの対応は自動ですし、パラメータ指定するだけで色々な設定を自動で行ってくれます。この、学習コードにまつわる、煩雑な部分を自動・半自動で行ってくれるのがLightningのメリットになります。
module = myModule()
trainer = pl.Trainer(max_epochs=num_epochs)
trainer.fit(model=module,
train_dataloaders = dataloader_train,
val_dataloaders = dataloader_valid)
model.train()
、model.eval()
も自動設定してくれます
Mixed Precisionとかの設定も簡単に行うことができます
評価
個人的な趣味ですが、推論部分ではLightningを使いません。どうするかというと、学習したモデルを読み出します。コードは以下になります。
model = module.model
今回は、module.model=timm.create_model('resnet18' ... )
と定義していましたので、resnet18
の部分だけ切り出して取り出すことができます。
私のコードでは、modelをpytorchので定義して、それをLightningの__init__に引数として渡すようなものが多いです。modelだけ後で使いたいので。
modelに代入したら、あとは、普段と同じです。
学習が終わった後、評価データを使用してモデルの評価結果を出力します。outputs
は1枚の画像に対して2つの値(犬と猫のスコア)が入っているため、argmax()
関数を使って犬と猫のスコアがより大きい方のインデックスを選択し、それをy_pred
として格納します。これと真値(y_gt
)を比較することで評価、正解率などを算出することができます。
from tqdm import tqdm
model.to(device)
model.eval()
y_pred = []
y_gt = []
for batch in tqdm(dataloader_valid):
inputs, targets = batch
with torch.no_grad() :
outputs = model(inputs.to(device))
y_gt += targets.tolist()
y_pred += outputs.argmax(axis=1).tolist()
指標計算
y_pred
、y_gt
を使って、評価指標を計算してみます。これには、scikit-learn
を利用します。まず、評価関数(confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
)をインポートします。
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
あとは、それぞれの関数を呼び出すだけで指標の計算を行うことができます。
混合行列(confution matrix)
confusion_matrix(y_gt, y_pred)
正解率(Accuracy)、適合率(Precision)、再現率(Recall)、F値(F-measure)
acc, recall, prec, fs = accuracy_score(y_gt, y_pred), recall_score(y_gt, y_pred), precision_score(y_gt, y_pred), f1_score(y_gt, y_pred)
print(f"acc={acc}, recall={recall}, precition = {prec}, f-score={fs}")
それぞれの指標に関しては、以下などが参考になります。
間違った画像を確認
間違った画像は以下のコードで確認できます。
for idx, (x, y) in enumerate(zip(y_pred, y_gt)):
if x != y :
img, label = dataset_valid[idx]
plt.imshow(img.transpose(1,2,0))
plt.show()
少ないエポック数でも予想以上に学習が進んでいると感じるかもしれません。これは、pretrained=True
で事前学習済みモデルを読み込んでいるからです。事前学習モデルは犬と猫の分類を事前に学習しているため、ここでの学習は微調整(特に、ラベルを0と1に変更する部分)が主な作業となっている可能性が大きいです。
まとめ
以上、Pytorch Lightningを使ってクラス分類の学習コードを書いてみました。このコードだとLightningの良さが伝わらなかったかもしれません。Lightningのいいところは、GPU/CPUの切り替えや、32bit/16bitの切り替えなどのメイン以外の部分のコードを書かなくて良くなると言う部分です。今回のものもlogは自動的に保存されていきますし、設定すればtensorboardなどで学習のログを確認することもできるようになります。また、マルチGPUなどへの対応も簡単になります。
また、pytorchのコードから簡単にコンバートすることができます(今回もLightningの機能を使い切っているわけではなく、一部だけポートした感じです)。
かなり便利なので使ってみてください。
メインのコードに集中し、周辺コードを書かなくて良くなると言うのがPytorch Lightningのメリット
最初はpytorchで書いて、パラメータ探索などの試行錯誤をするときに、Lightningに持っていくって言うことも結構多いね