PyTorch-独自データセット(custom dataset)の作り方【初級 深層学習講座】
PyTorchで用意されているDatasetクラスでは対応できない場合、カスタムデータセットを自作する必要があります。この記事では、PyTorchでカスタムデータセットを作成する方法について、実際のコード例とともに詳しく解説します。自作データセットによるデータローダー(dataloader)を使いたい方は、一読ください。
PytorchのDataset, Dataloaderとは
PyTorchでは、データセット(Dataset)、データローダー(Dataloader)を使って機械学習に必要なデータを読み込みます。以下、2つについて簡単に説明します。
データセット(Dataset)
データセットは、データの集合を管理するクラスです。データセットは、インデックスで指定されたデータを返します。PyTorchでは、MNISTやCIFER-10といったデータセットや、一定のフォーマットに従ったデータに対してのデータセットが用意されています。
データの集合を管理し、インデックスに従ってデータの1つを返すクラスがデータセットです
データローダー(Dataloader)
データローダーは、データセットからデータを読み込み、ミニバッチに分割するクラスです。このクラスを利用することで、データセットをミニバッチに分割して読み込むことができます。
データセットは、バッチ単位でデータを読み込むためのクラスです
カスタムデータセットが必要なケース
クラスごとにフォルダが分類された画像データなどは、torchvision
を使えばデータセットとして読み込むことができたり、ある程度のデータフォーマットに対しては標準でデータセットクラスが用意されています。
どのような場合に自作のデータセットが必要になるかというと、入力するデータが独自のフォーマットで格納されていたり、特殊な出力を行う場合などです。このようなケースでは、用意されているデータセットではカバーできません。
この場合は、カスタムデータセットを作成する必要があります。
- 入力データのフォーマットが特殊な場合
- 特殊な出力を行う場合
ここでは、カスタムデータセットを作成する方法について解説します。
カスタムデータセットを使った学習コード
サンプルコードについて
ここでは、MNISTの手書き文字データセットを読み込むカスタムデータセットを作成し、それを使って学習・推論を行うコードを作成します。
解説は、Google Colab / jupyter notebookで動かすことを前提にしていますので、それ以外で動作させる場合は注意してください。
Colabで動作するサンプルコードをこちらに用意しました。
パッケージのインストール
モデルの作成にTIMM(PyTorch Image Models)を利用するので、timmをインストールします。
!pip install timm
※上記はノートブックからインストールする場合の例です。コマンドラインからインストールする場合は、!
は必要ありません
MNISTのデータセットをダウンロード
MINSTの手書き文字データセットをダウンロードします。
ダウンロード
!wget https://pjreddie.com/media/files/mnist_train.tar.gz
!wget https://pjreddie.com/media/files/mnist_test.tar.gz
展開
データセットを./datasetフォルダに展開します。
!mkdir dataset
%cd dataset
!tar -xf ../mnist_train.tar.gz
!tar -xf ../mnist_test.tar.gz
%cd ..
以上でデータセットの準備は完了です。
正しく動作した場合は、dataset
フォルダにtrain
とtest
フォルダが作成されているはずです。
ライブラリのインポート
今回のプログラムで利用するライブラリをインポートします。
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import glob
import multiprocessing
from tqdm.notebook import tqdm
import timm
カスタムデータセットの定義
今回の記事のメインの部分になります。
カスタムデータセット(CustomDataset
)を定義します。カスタムのデータセットを作る場合、torch.utils.data.Dataset
を継承したクラスを定義します。
また、__init__
, __len__
, __getitem__
の3つの関数を定義します。
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, files, mode):
self.files = files
self.mode = mode # mode 0...train, 1...valid, 2...test(without label)
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
file = self.files[idx]
image = Image.open(file)
image = torch.from_numpy(np.array(image).astype(np.float32))
image = image.unsqueeze(0)
if self.mode == 0:
pass # ここに、データ拡張などの処理を書く
if self.mode == 2:
label = None
else:
label = int(file[-5])
return image, label
__init__()
初期化関数です。引数は、自由に設定できます。今回のデータセットでは、手書き文字画像のファイル名を格納したfiles
と、mode
を引数をして受け取っています。mode
は、0,1,2のいずれかで、0の場合は訓練データのデータセット、1の場合は検証データのデータセット、2の場合はラベルのないデータセットとしています。
__init__
にモード切り替えのパラメータ設定を加えておくことで、カスタムデータセットを、訓練、検証等のデータにより動作を切り替えることができます。
__len__()
データセットのデータ数を返すメソッドです。このデータセットに含まれるデータ数を返すようにします。
__getitem__()
データセット内のデータを返すメソッドです。idx
番号に対応するデータを返します。入力はidx
で、戻り値は自由ですが、基本的には、データとラベルなどを返します。
また、戻り値は、torchのテンソルにしておくと楽です。
なお、今回は機械学習のモデルにTIMMを利用するので、timmのモデルの画像の入力フォーマット(ch, h, w)の形式に合わせるためにunsqeeze(0)で次元を追加しています(読み込んだデータは28×28ですが、これを1x28x28のフォーマットに変換しています)。
カスタムデータセットの作成では、上の3つのメソッドを実装します
train, testデータセットを定義
訓練用、検証用のデータセットを作成します。
glob.glob("dataset/train/*")
で、訓練画像のファイル名の一覧が作成されますので、これをCustomDataset
に渡しています。
train_file = glob.glob("dataset/train/*")
test_file = glob.glob("dataset/test/*")
train_dataset = CustomDataset(train_file, mode = 0)
test_dataset = CustomDataset(test_file, mode = 1)
ここで、実際に、画像が読み込めるかのチェックを行います。
img, label = train_dataset[100]
plt.imshow(img[0])
print(label)
上のコードを実行して、以下のような画像が表示されればOKです。
各種パラメータ設定
学習に利用する各種パラメータを定義しておきます。
EPOCHS = 10
TRAIN_BATCH_SIZE = 32
TEST_BATCH_SIZE = 32
cpus = multiprocessing.cpu_count()
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
データローダーの定義
データローダーは、カスタムする必要は少ないです。今回も、用意されたデータローダーを利用します。
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = TRAIN_BATCH_SIZE, shuffle = True, num_workers = cpus)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = TEST_BATCH_SIZE, shuffle = False, num_workers = cpus)
データローダーの動作を確認します。
Iter = iter(train_loader)
imgs, labels = next(Iter)
fix, ax = plt.subplots(4, 8)
for i in range(4):
for j in range(8):
ax[i, j].imshow(imgs[i*8+j][0])
ax[i, j].set_title(int(labels[i*8+j]))
ax[i, j].axis('off')
plt.show()
上記のように出力されればOKです。なお、train_loader
はshuffle=True
になっているので、実行するたびに、出力内容が変化するはずですので試してみてください。
ここまでが、カスタムデータセットの作り方です。一応学習コードなどもつけておきますが、この記事のポイントはここまでです。
学習
ここでは、timmのresnet18を使って学習してみます。create_model
では、データセットに合わせて、チャネル数を1チャネル(in_chans = 1
)、クラス数を10(num_classes = 10
)にしてモデルを生成しています。
TIMMの使い方については、以下の記事を参考にしてください
model = timm.create_model("resnet18", pretrained = False, num_classes = 10, in_chans = 1)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
以下は学習コードです。訓練→評価をEPOCHS数分だけ繰り返しています。
best_loss = 10**18
for epoch in range (EPOCHS) :
train_loss = 0
train_cnt = 0
model.train()
print(f"EPOCH{epoch}")
print("[Train]")
for imgs, labels in tqdm(train_loader) :
optimizer.zero_grad()
output = model(imgs.to(device))
loss = F.cross_entropy(output, labels.to(device))
loss.backward()
train_loss += loss.cpu().detach().numpy()
train_cnt += 1
optimizer.step()
valid_loss = 0
valid_cnt = 0
model.eval()
print("[Valid]")
for imgs, labels in tqdm(test_loader) :
with torch.no_grad() :
output = model(imgs.to(device))
loss = F.cross_entropy(output, labels.to(device))
valid_loss += loss.cpu().detach().numpy()
valid_cnt += 1
if valid_loss < best_loss :
best_loss = valid_loss
torch.save(model.state_dict(), "best.pt")
print("[Score]")
print(f"train_loss: {train_loss/train_cnt} valid_loss: {valid_loss/valid_cnt}")
動作させると以下のようなログが出力されます(画像はログの一部)
推論と評価
学習したモデルを使って推論します。下記のコードでy_predに予測値が、y_trueに真値が格納されます。
y_pred = []
y_true = []
model.load_state_dict(torch.load("best.pt"))
model.eval()
for imgs, labels in tqdm(test_loader) :
with torch.no_grad() :
output = model(imgs.to(device))
y_pred += (torch.argmax(output, axis=1).tolist())
y_true += labels.tolist()
混同行列を求めてみます。
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, classification_report
cm = confusion_matrix(y_true, y_pred)
for e in cm :
print(e)
結果を見ると若干ミスがあるようです。
precision, recall, f1-scoreなどを求めてみます。
print(classification_report(y_true, y_pred))
出力結果は以下になります。
precision recall f1-score support
0 1.00 0.99 0.99 980
1 0.99 1.00 1.00 1135
2 1.00 0.99 1.00 1032
3 0.99 0.99 0.99 1010
4 1.00 0.99 0.99 982
5 0.99 0.99 0.99 892
6 0.99 0.99 0.99 958
7 0.99 0.99 0.99 1028
8 0.99 0.99 0.99 974
9 0.99 0.98 0.99 1009
accuracy 0.99 10000
macro avg 0.99 0.99 0.99 10000
weighted avg 0.99 0.99 0.99 10000
結果を見ると、すべての文字について0.99以上の正解率になっていることがわかります。
データセットのテンプレート
最後にデータセットのテンプレートを書いておきます。
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, 必要な引数):
# パラメータの設定などを行う
self.size = データ数
self.xxxx = xxxxx
:
def __len__(self):
return len(self.size)
def __getitem__(self, idx):
# idxに対応した、データを返す
# ※torchのテンソルにしておくと楽
return 返すデータ、返すデータ2
カスタムデータセットを作成するのはそれほど難しくありません
私は、後で色々試行錯誤することが多いので、大体の場合カスタムデータセットを作っています
まとめ
PyTorchでカスタムデータセットを作る方法について解説しました。サンプルコードのデータではカスタムデータセットを作る必要はありませんが、このような場合でも、後で拡張しやすいカスタムデータセットを作っておくことをおすすめします。