初級ディープラーニング
記事内に商品プロモーションを含む場合があります

PyTorch-独自データセット(custom dataset)の作り方【初級 深層学習講座】

Aru

PyTorchで用意されているDatasetクラスでは対応できない場合、カスタムデータセットを自作する必要があります。この記事では、PyTorchでカスタムデータセットを作成する方法について、実際のコード例とともに詳しく解説します。自作データセットによるデータローダー(dataloader)を使いたい方は、一読ください。

PytorchのDataset, Dataloaderとは

PyTorchでは、データセット(Dataset)、データローダー(Dataloader)を使って機械学習に必要なデータを読み込みます。以下、2つについて簡単に説明します。

データセット(Dataset)

データセットは、データの集合を管理するクラスです。データセットは、インデックスで指定されたデータを返します。PyTorchでは、MNISTやCIFER-10といったデータセットや、一定のフォーマットに従ったデータに対してのデータセットが用意されています。

データの集合を管理し、インデックスに従ってデータの1つを返すクラスがデータセットです

データローダー(Dataloader)

データローダーは、データセットからデータを読み込み、ミニバッチに分割するクラスです。このクラスを利用することで、データセットをミニバッチに分割して読み込むことができます。

データセットは、バッチ単位でデータを読み込むためのクラスです

カスタムデータセットが必要なケース

クラスごとにフォルダが分類された画像データなどは、torchvisionを使えばデータセットとして読み込むことができたり、ある程度のデータフォーマットに対しては標準でデータセットクラスが用意されています。

どのような場合に自作のデータセットが必要になるかというと、入力するデータが独自のフォーマットで格納されていたり特殊な出力を行う場合などです。このようなケースでは、用意されているデータセットではカバーできません。

この場合は、カスタムデータセットを作成する必要があります。

カスタムデータセットを作成する必要があるケース
  • 入力データのフォーマットが特殊な場合
  • 特殊な出力を行う場合

ここでは、カスタムデータセットを作成する方法について解説します。

カスタムデータセットを使った学習コード

サンプルコードについて

ここでは、MNISTの手書き文字データセットを読み込むカスタムデータセットを作成し、それを使って学習・推論を行うコードを作成します。

解説は、Google Colab / jupyter notebookで動かすことを前提にしていますので、それ以外で動作させる場合は注意してください。

Colabで動作するサンプルコードをこちらに用意しました。

パッケージのインストール

モデルの作成にTIMM(PyTorch Image Models)を利用するので、timmをインストールします。

!pip install timm

※上記はノートブックからインストールする場合の例です。コマンドラインからインストールする場合は、!は必要ありません

MNISTのデータセットをダウンロード

MINSTの手書き文字データセットをダウンロードします。

ダウンロード

!wget https://pjreddie.com/media/files/mnist_train.tar.gz
!wget https://pjreddie.com/media/files/mnist_test.tar.gz

展開

データセットを./datasetフォルダに展開します。

!mkdir dataset
%cd dataset
!tar -xf ../mnist_train.tar.gz 
!tar -xf ../mnist_test.tar.gz
%cd ..

以上でデータセットの準備は完了です。

正しく動作した場合は、datasetフォルダにtraintestフォルダが作成されているはずです。

ライブラリのインポート

今回のプログラムで利用するライブラリをインポートします。

import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import glob
import multiprocessing

from tqdm.notebook import tqdm
import timm

カスタムデータセットの定義

今回の記事のメインの部分になります。

カスタムデータセット(CustomDataset)を定義します。カスタムのデータセットを作る場合、torch.utils.data.Datasetを継承したクラスを定義します。

また、__init__, __len__, __getitem__の3つの関数を定義します。

class CustomDataset(torch.utils.data.Dataset):
  def __init__(self, files, mode):
    self.files = files
    self.mode = mode # mode 0...train, 1...valid, 2...test(without label)

  def __len__(self):
    return len(self.files)

  def __getitem__(self, idx):
    file = self.files[idx]
    image = Image.open(file)
    image = torch.from_numpy(np.array(image).astype(np.float32))
    image = image.unsqueeze(0)

    if self.mode == 0: 
      pass # ここに、データ拡張などの処理を書く
    
    if self.mode == 2:
      label = None
    else:
      label = int(file[-5])


    return image, label

__init__()

初期化関数です。引数は、自由に設定できます。今回のデータセットでは、手書き文字画像のファイル名を格納したfilesと、modeを引数をして受け取っています。modeは、0,1,2のいずれかで、0の場合は訓練データのデータセット、1の場合は検証データのデータセット、2の場合はラベルのないデータセットとしています。

__init__モード切り替えのパラメータ設定を加えておくことで、カスタムデータセットを、訓練、検証等のデータにより動作を切り替えることができます

__len__()

データセットのデータ数を返すメソッドです。このデータセットに含まれるデータ数を返すようにします。

__getitem__()

データセット内のデータを返すメソッドです。idx番号に対応するデータを返します。入力はidxで、戻り値は自由ですが、基本的には、データとラベルなどを返します。

また、戻り値は、torchのテンソルにしておくと楽です。

なお、今回は機械学習のモデルにTIMMを利用するので、timmのモデルの画像の入力フォーマット(ch, h, w)の形式に合わせるためにunsqeeze(0)で次元を追加しています(読み込んだデータは28×28ですが、これを1x28x28のフォーマットに変換しています)。

カスタムデータセットの作成では、上の3つのメソッドを実装します

train, testデータセットを定義

訓練用、検証用のデータセットを作成します。

glob.glob("dataset/train/*")で、訓練画像のファイル名の一覧が作成されますので、これをCustomDatasetに渡しています。

train_file = glob.glob("dataset/train/*")
test_file = glob.glob("dataset/test/*")

train_dataset = CustomDataset(train_file, mode = 0)
test_dataset = CustomDataset(test_file, mode = 1)

ここで、実際に、画像が読み込めるかのチェックを行います。

img, label = train_dataset[100]
plt.imshow(img[0])
print(label)

上のコードを実行して、以下のような画像が表示されればOKです。

データローダーの画像出力

各種パラメータ設定

学習に利用する各種パラメータを定義しておきます。

EPOCHS = 10
TRAIN_BATCH_SIZE = 32
TEST_BATCH_SIZE = 32
cpus = multiprocessing.cpu_count()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

データローダーの定義

データローダーは、カスタムする必要は少ないです。今回も、用意されたデータローダーを利用します。

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = TRAIN_BATCH_SIZE, shuffle = True, num_workers = cpus)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = TEST_BATCH_SIZE, shuffle = False, num_workers = cpus)

データローダーの動作を確認します。

Iter = iter(train_loader)
imgs, labels = next(Iter)

fix, ax = plt.subplots(4, 8)
for i in range(4):
  for j in range(8):
    ax[i, j].imshow(imgs[i*8+j][0])
    ax[i, j].set_title(int(labels[i*8+j]))
    ax[i, j].axis('off')
plt.show()
データローダーの動作確認

上記のように出力されればOKです。なお、train_loadershuffle=Trueになっているので、実行するたびに、出力内容が変化するはずですので試してみてください。

ここまでが、カスタムデータセットの作り方です。一応学習コードなどもつけておきますが、この記事のポイントはここまでです。

学習

ここでは、timmのresnet18を使って学習してみます。create_modelでは、データセットに合わせて、チャネル数を1チャネル(in_chans = 1)、クラス数を10(num_classes = 10)にしてモデルを生成しています。

model = timm.create_model("resnet18", pretrained = False, num_classes = 10, in_chans = 1)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

以下は学習コードです。訓練→評価をEPOCHS数分だけ繰り返しています。

best_loss = 10**18

for epoch in range (EPOCHS) :
  train_loss = 0
  train_cnt = 0
  model.train()
  print(f"EPOCH{epoch}")
  print("[Train]")
  for imgs, labels in tqdm(train_loader) :
    optimizer.zero_grad()
    output = model(imgs.to(device))
    loss = F.cross_entropy(output, labels.to(device))
    loss.backward()
    train_loss += loss.cpu().detach().numpy()
    train_cnt += 1
    optimizer.step()

  valid_loss = 0
  valid_cnt = 0
  model.eval()
  print("[Valid]")
  for imgs, labels in tqdm(test_loader) :
    with torch.no_grad() :
      output = model(imgs.to(device))
    loss = F.cross_entropy(output, labels.to(device))
    valid_loss += loss.cpu().detach().numpy()
    valid_cnt += 1

  if valid_loss < best_loss :
    best_loss = valid_loss
    torch.save(model.state_dict(), "best.pt")

  print("[Score]")
  print(f"train_loss: {train_loss/train_cnt} valid_loss: {valid_loss/valid_cnt}")

動作させると以下のようなログが出力されます(画像はログの一部)

出力ログの一部

推論と評価

学習したモデルを使って推論します。下記のコードでy_predに予測値が、y_trueに真値が格納されます。

y_pred = []
y_true = []

model.load_state_dict(torch.load("best.pt"))
model.eval()

for imgs, labels in tqdm(test_loader) :
  with torch.no_grad() :
    output = model(imgs.to(device))
  y_pred += (torch.argmax(output, axis=1).tolist())
  y_true += labels.tolist()

混同行列を求めてみます。

from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, classification_report

cm = confusion_matrix(y_true, y_pred)
for e in cm :
  print(e)

結果を見ると若干ミスがあるようです。

混同行列

precision, recall, f1-scoreなどを求めてみます。

print(classification_report(y_true, y_pred)) 

出力結果は以下になります。

             precision    recall  f1-score   support

           0       1.00      0.99      0.99       980
           1       0.99      1.00      1.00      1135
           2       1.00      0.99      1.00      1032
           3       0.99      0.99      0.99      1010
           4       1.00      0.99      0.99       982
           5       0.99      0.99      0.99       892
           6       0.99      0.99      0.99       958
           7       0.99      0.99      0.99      1028
           8       0.99      0.99      0.99       974
           9       0.99      0.98      0.99      1009

    accuracy                           0.99     10000
   macro avg       0.99      0.99      0.99     10000
weighted avg       0.99      0.99      0.99     10000

結果を見ると、すべての文字について0.99以上の正解率になっていることがわかります。

データセットのテンプレート

最後にデータセットのテンプレートを書いておきます。

class CustomDataset(torch.utils.data.Dataset):
  def __init__(self, 必要な引数):
    # パラメータの設定などを行う
    self.size = データ数
    self.xxxx = xxxxx
        :

  def __len__(self):
    return len(self.size)

  def __getitem__(self, idx):
    # idxに対応した、データを返す
    # ※torchのテンソルにしておくと楽

    return 返すデータ、返すデータ2

カスタムデータセットを作成するのはそれほど難しくありません

私は、後で色々試行錯誤することが多いので、大体の場合カスタムデータセットを作っています

 

まとめ

PyTorchでカスタムデータセットを作る方法について解説しました。サンプルコードのデータではカスタムデータセットを作る必要はありませんが、このような場合でも、後で拡張しやすいカスタムデータセットを作っておくことをおすすめします。

初級 深層学習講座シリーズはこちら
ディープラーニングに関する記事一覧はこちら
ディープラーニング関連の記事一覧
ディープラーニング関連の記事一覧

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました