機械学習
記事内に商品プロモーションを含む場合があります

PyTorch Lightningを使ったクラス分類(犬猫分類)を実践

アイキャッチ画像
Aru

この記事では、前の記事で作成したクラス分類(犬猫分類)をPyTorch Lightningを使って実装し直してみます。PyTorch Lightningを使うと、シンプルな記述で学習コードを書くことが可能です。PyTorch Lightningに慣れておくと後々楽だと思いますのでPyTorch Lightningによる学習コードの実装に慣れておくと良いです。

はじめに

以下の記事では、画像をダウンロードし、モデルの作成と学習までの手順について解説しました。ここでは、Pytorch Lightningを使ってこれを書き直してみます。

PyTorch Lightningとは、PyTorchの上に構築されたオープンソースの軽量フレームワークで、これを使うことでモデルのトレーニングを簡潔に記述ですることができ、コードの可読性と再利用性を高めることができる優れ物です。

あわせて読みたい
PyTorch+TIMMでクラス分類(犬猫分類)にチャレンジ
PyTorch+TIMMでクラス分類(犬猫分類)にチャレンジ

コードはGoogle Colabotoryで実行できる形式としています。githubにコードを置いていますので、そちらも参考にしてください。

サンプルコード(Pytorch Lightning)へのリンク(github)

Google Colabotoryで実行する場合の注意点

サンプルコードはGPU必須ではありませんが、GPUありで処理することをおすすめします。

Google Colabを使う場合は、ランタイム→ランタイムタイプの変更で、ハードウェアのアクセラレータにGPUを設定するのを忘れないように。GPUを利用しないとかなり実行時間がかかります。

分類タスクのコード実装

実装する分類タスクは、前回と同様「犬」と「猫」の2クラス分類です。同じものをやる方が比較しやすいと思いますので、前回と同じ題材にしました。

ここでは、処理の流れに従って説明して行きたいと思います。なお、実行環境はGoogle Colabotoryです。

ライブラリのインストール

timmとicrawlerはインストールされていないので、これをインストールします。Colabの場合は、以下のコードをコードに書きます(”!”がついている点に注意)

!pip install timm
!pip install icrawler
!!pip install pytorch-lightning

これで、timmとicrawler、pytorch-lightningがインストールされます。

とりあえず、使いそうなライブラリをインポートしておきます。今回は、ここでインポートしていないライブラリで必要な都度インポートしてますが、最初にまとめてインポートすることをおすすめします

import torch
import timm
import numpy as np
from icrawler.builtin import GoogleImageCrawler

from torch.utils.data import DataLoader

import pytorch_lightning as pl
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score

今回は、sklearn関係のインポートもここに移動させました。

画像ダウンロード

画像のダウンロードは、前回と同じです。前回の記事を読まれている場合は、読み飛ばしてください。

まず、必要な画像データセットをダウンロードします。画像のダウンロードにはicrawlerを使います。詳しくは以下の記事を参照してください。

あわせて読みたい
iCrawlerで画像収集を自動化する|Pythonでクローリングする方法
iCrawlerで画像収集を自動化する|Pythonでクローリングする方法

まず、犬の画像をGoogleから取得し、images/dogフォルダに格納します。100枚指定しますが、取得エラーなどで100枚より少なくなることがあるので注意してください

google_crawler = GoogleImageCrawler(
    storage={'root_dir': 'images/dog'})
google_crawler.crawl(keyword='dog', max_num=100)

同様にして、猫の画像をimages/catフォルダに格納します。こちらも数が少ない可能性があります。

google_crawler = GoogleImageCrawler(
    storage={'root_dir': 'images/cat'})
google_crawler.crawl(keyword='cat', max_num=100)

次に、取得した画像を訓練用(train)と、検証用(valid)に分割します。以下では、コマンドを実行して分割をしてます。以下のコードを実行すると、1〜9枚めの画像が検証用になり、それ以外が訓練用に割り振られます。

!mkdir images/train images/valid images/train/cat images/train/dog images/valid/cat images/valid/dog
!mv images/cat/00000?.jpg images/valid/cat
!mv images/dog/00000?.jpg images/valid/dog
!mv images/cat images/train
!mv images/dog images/train

以上の操作で、訓練画像がimages/trainに、評価画像がimages/validに格納されます。

データセットの作成(dataset/dataloader)

画像を準備してしまえば、あとは定形処理です。まずデータセットを作成します。次にデータローダーを作成します。クラス分離の場合は、上記のようなフォルダ構成にしておけば、timmを使って簡単に記述することができます。

まず、必要なライブラリをインポートします。今回は、create_transformを使ってリサイズをしています。他にもいろいろな処理が可能ですので、興味があったら調べてみてください(参考リンク)。

本来はデータ拡張(変形・色調変更、ランダムな切り出しなど)も行うのですが、今回はリサイズだけ利用しています。timmでは、基本的なデータ拡張をcreate_transformを使って簡単に実装できます。

from timm.data import create_dataset, create_loader
from timm.data.transforms_factory import create_transform

まず、データセットを作成します。引数は、データセットの名前、rootは画像のルートフォルダ、class_mapはクラス名→番号への変換辞書です。今回は、訓練用データを./images/trainに、評価用データを./images/validに格納しているので、rootにはそれを指定します。また、class_mapにはdog=0, cat=1への変換を指定します。class_mapを指定することで、それぞれのフォルダにクラス番号が割り振られます。

また、ここで画像サイズの変換等も行っています。transform=create_transform(224)がその設定で、画像はリサイズ・クリッピングされて224×224のサイズに変換されます。また、出力はtorchのテンソル形式に変換されます

dataset_train = create_dataset('train', root="./images/train", class_map={'dog':0, 'cat':1}, transform=create_transform(224))
dataset_valid = create_dataset('valid', root="./images/valid", class_map={'dog':0, 'cat':1}, transform=create_transform(224))

以下は、データセットの画像の確認コードサンプルです。このようにして、画像表示とラベル確認ができます。

import matplotlib.pyplot as plt
img, label = dataset_train[0]
plt.imshow(img.permute(1,2,0))
label

次にデータローダの定義です。前回はtimmのcreate_loaderを使って実装しましたが、今回はtorchのDataLoaderを使いました。

create_loaderを利用するとGPUがないときにエラーになりましたので変更しました。

batch_sizeは、一度に読み出すサイズ(バッチサイズ)を指定します。shuffleは、読み出し時に並べ替えるかどうかのフラグで、普通は訓練用データの場合Trueに、訓練用データではFalseにしておきます。

dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True)
dataloader_valid = DataLoader(dataset_valid, batch_size=8, shuffle=False)

以下はデータローダの動作確認コードです、画像(バッチサイズ枚)がXに、ラベルがyに入力されます。ここでは、画像を表示すると大変なので、ラベルyのみ表示しています。

# 確認
for X, y in dataloader_valid:
  print(y)

以上でデータセットの準備は完了です。

パラメータ設定

パラメータ設定をしておきます。今回のコードでは、epoch数と、deviceだけです。deviceは、create_loaderがGPUが存在する前提となっているので、”cuda“に固定しています。

epoch数はとりあえず、10に設定しました。

num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

LightningModuleを定義

ここから、Pytorch Lightningらしい部分になります。

まずは、LightningModuleを定義します。基本的には、訓練ループ、検証ループで行っていた処理をそれぞれのメソッドに記述していきます。

なお、Lightningでは「GPUあり・なし」などを自動判定してくれるので、.to(device)を書く必要はありません。

class myModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = timm.create_model('resnet18', pretrained=True, num_classes=2)
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self.model(x)
        loss = self.loss_fn(pred, y)
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("train_loss", epoch_mean, prog_bar=True)
        self.training_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self.model(x)
        loss = self.loss_fn(pred, y)
        self.validation_step_outputs.append(loss)
        return loss

    def on_validation_epoch_end(self):
        epoch_mean = torch.stack(self.validation_step_outputs).mean()
        self.log("valid_loss", epoch_mean, prog_bar=True)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=0.0001)

以下、それぞれのメソッドの役割と、処理内容になります。

__init__

初期化ブロックです、ここではtimmを用いて、resnet18の学習済みモデルを、出力クラス =2クラスにして読み込んでいます。また、損失関数の定義を行っています。残りの2つの変数は、ログを保存するためのものです。

training_step

訓練データへの1バッチ分の処理を記載します。backwrodとかoptimizerの処理を書かなくて良い部分が異なります。

on_train_epoch_end

訓練の毎EPOCH終了時に呼び出されます。ここでは、training_stepで保存していたlossを集計して表示させています。prog_bar=Trueを設定しているのでプログレスバーに表示されます。

validation_step

検証データへの1バッチ分の処理を記載します。

on_validation_epoch_end

検証の毎EPOCH終了時に呼び出されます。こちらも、lossの表示を行っています。

configure_optimizers

最適化アルゴリズムを設定しています。前回と同じくAdamを利用しました。

学習

Lightningえば、学習コードがとても短くなります。CPU/GPUへの対応は自動ですし、パラメータ指定するだけで色々な設定を自動で行ってくれます。この、学習コードにまつわる、煩雑な部分を自動・半自動で行ってくれるのがLightningのメリットになります。

module = myModule()
trainer = pl.Trainer(max_epochs=num_epochs)

trainer.fit(model=module,
            train_dataloaders = dataloader_train,
            val_dataloaders = dataloader_valid)

model.train()model.eval()も自動設定してくれます

Mixed Precisionとかの設定も簡単に行うことができます

評価

個人的な趣味ですが、推論部分ではLightningを使いません。どうするかというと、学習したモデルを読み出します。コードは以下になります。

model = module.model

今回は、module.model=timm.create_model('resnet18' ... )と定義していましたので、resnet18の部分だけ切り出して取り出すことができます。

私のコードでは、modelをpytorchので定義して、それをLightningの__init__に引数として渡すようなものが多いです。modelだけ後で使いたいので。

modelに代入したら、あとは、普段と同じです。

学習が終わった後、評価データを使用してモデルの評価結果を出力します。outputsは1枚の画像に対して2つの値(犬と猫のスコア)が入っているため、argmax()関数を使って犬と猫のスコアがより大きい方のインデックスを選択し、それをy_predとして格納します。これと真値(y_gt)を比較することで評価、正解率などを算出することができます。

from tqdm import tqdm
model.to(device)
model.eval()
y_pred = []
y_gt = []
for batch in tqdm(dataloader_valid):
    inputs, targets = batch
    with torch.no_grad() :
      outputs = model(inputs.to(device))
    y_gt += targets.tolist()
    y_pred += outputs.argmax(axis=1).tolist()

指標計算

y_predy_gtを使って、評価指標を計算してみます。これには、scikit-learnを利用します。まず、評価関数(confusion_matrix, accuracy_score, recall_score, precision_score, f1_score)をインポートします。

from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score

あとは、それぞれの関数を呼び出すだけで指標の計算を行うことができます。

混合行列(confution matrix)

confusion_matrix(y_gt, y_pred)

正解率(Accuracy)、適合率(Precision)、再現率(Recall)、F値(F-measure)

acc, recall, prec, fs = accuracy_score(y_gt, y_pred), recall_score(y_gt, y_pred), precision_score(y_gt, y_pred), f1_score(y_gt, y_pred)
print(f"acc={acc}, recall={recall}, precition = {prec},  f-score={fs}")

それぞれの指標に関しては、以下などが参考になります。

Qiita: 【入門者向け】機械学習の分類問題評価指標解説(正解率・適合率・再現率など)

間違った画像を確認

間違った画像は以下のコードで確認できます。

for idx, (x, y) in enumerate(zip(y_pred, y_gt)):
  if x != y :
    img, label = dataset_valid[idx]
    plt.imshow(img.transpose(1,2,0))
    plt.show()

少ないエポック数でも予想以上に学習が進んでいると感じるかもしれません。これは、pretrained=Trueで事前学習済みモデルを読み込んでいるからです。事前学習モデルは犬と猫の分類を事前に学習しているため、ここでの学習は微調整(特に、ラベルを0と1に変更する部分)が主な作業となっている可能性が大きいです。

まとめ

以上、Pytorch Lightningを使ってクラス分類の学習コードを書いてみました。このコードだとLightningの良さが伝わらなかったかもしれません。Lightningのいいところは、GPU/CPUの切り替えや、32bit/16bitの切り替えなどのメイン以外の部分のコードを書かなくて良くなると言う部分です。今回のものもlogは自動的に保存されていきますし、設定すればtensorboardなどで学習のログを確認することもできるようになります。また、マルチGPUなどへの対応も簡単になります。

また、pytorchのコードから簡単にコンバートすることができます(今回もLightningの機能を使い切っているわけではなく、一部だけポートした感じです)。

かなり便利なので使ってみてください。

メインのコードに集中し、周辺コードを書かなくて良くなると言うのがPytorch Lightningのメリット

最初はpytorchで書いて、パラメータ探索などの試行錯誤をするときに、Lightningに持っていくって言うことも結構多いね

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました