検出枠に対応!torchvisionのデータ拡張(v2)の使い方を解説
torchvisionのtransforms.v2は、データ拡張(データオーグメンテーション)に物体検出に必要な検出枠(bounding box)やセグメンテーションマスク(mask)のサポートが追加されています。この記事では、transforms.v2
を物体検出タスクで利用する方法について詳しく解説します。
torchvision.transforms.v2とは
2023年10月5日にTorchVision 0.16が公開され、transforms.v2のドキュメントも充実してきました。現在はまだベータ版ですが、今後主流となる可能性が高いため、新しく学習コードを書く際にはこのバージョンを使用した方がよいかもしれません。特に、物体検出タスクに対応したデータ拡張機能が大きく、こちらを利用する機会も増えそうです。
v2の特徴
リリースノートを見ると以下のようになっています。
- 高速化(10~40%)、mps対応
- CutMixとMixUpのサポート
- セグメンテーションマスク、物体検出のBBox、動画のサポート
個人的には、物体検知のバウンディングボックス(BBOX)のサポートと、セグメンテーションマスクのサポートの追加が嬉しいです。
この追加により、YOLOなどの物体検出タスクのデータオーグメンテーション(データ拡張)でも、torchvisionが使えるようになりました。
V2のリファレンス:TRANSFORMING AND AUGMENTING IMAGES
torchvision.transformsから移行する場合
これまで、torchvision.transformsを使っていたコードをv2に修正する場合は、transformsの後ろに.v2
をつけ加えるだけでOKです。
仮に、以下のように宣言して使っていた場合は、変更はインポートだけですみます。
from torchvision import transforms
上記のインポートを以下のように修正します。これでv2に切り替わります。
import torchvision.transforms.v2 as transforms
今回は移行ではないので、以下のように宣言します
from torchvision.transforms import v2
v2
と書くだけの方が楽です。
物体検出タスクでの利用方法
この記事では、物体検出タスクでの使い方を中心に説明したいと思います。
ヘルパー関数
この記事では実際いtransforms.v2
を動かして動作確認していきますが、結果の表示にはヘルパー関数が必要です。ここでは、公式のgithubにあるヘルパー関数をそのまま流用します。
この関数は、画像とバウンディングボックスやマスクをリスト形式で受け取って、それを並べて表示するものです。
torchvisionのutilsが使われています。utils
の使い方については、こちらの記事を参照してください。
# from https://github.com/pytorch/vision/blob/main/gallery/transforms/helpers.py
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
def plot(imgs, row_title=None, **imshow_kwargs):
if not isinstance(imgs[0], list):
# Make a 2d grid even if there's just 1 row
imgs = [imgs]
num_rows = len(imgs)
num_cols = len(imgs[0])
_, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
for row_idx, row in enumerate(imgs):
for col_idx, img in enumerate(row):
boxes = None
masks = None
if isinstance(img, tuple):
img, target = img
if isinstance(target, dict):
boxes = target.get("boxes")
masks = target.get("masks")
elif isinstance(target, tv_tensors.BoundingBoxes):
boxes = target
else:
raise ValueError(f"Unexpected target type: {type(target)}")
img = F.to_image(img)
if img.dtype.is_floating_point and img.min() < 0:
# Poor man's re-normalization for the colors to be OK-ish. This
# is useful for images coming out of Normalize()
img -= img.min()
img /= img.max()
img = F.to_dtype(img, torch.uint8, scale=True)
if boxes is not None:
img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
if masks is not None:
img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)
ax = axs[row_idx, col_idx]
ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout()
パッケージのインポート
必要なパッケージをインポートします。今回試すv2
以外に、画像読み込みの関数もインポートしておきます。
import torch
import matplotlib.pyplot as plt
from torchvision.transforms import v2
from torchvision.io import read_image, ImageReadMode
画像の読み込み
今回利用したcat2.png
は、ここからダウンロードできます。Jupyter notebookやGoogle Colabで試す場合は、以下のコードを実行することでダウンロードされます。
!wget -O cat2.png https://github.com/aruaru0/SAM-TEST/blob/main/cat2.png?raw=true
画像が準備できたら、以下のコードで読み込みます。
img = read_image('cat2.png', ImageReadMode.RGB)
検出枠が無いデータ拡張(クラス分類などの場合)
クラス分類などで利用する場合は、使い方はこれまでのtorchvision.transforms
と同じです。具体的には、Compose
を使って変換を列挙していきます。
変換は、Compose
を使って生成したオブジェクト(transform
)に画像を渡すことで列挙した処理が順番に実行され、結果が戻ります。
以下のコードでは、ランダムに、「リサイズ」、「切り抜き」、「フリップ」が行われ、画素値が標準化されて出力されます。
transforms = v2.Compose([
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomHorizontalFlip(p=0.5),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
out = transforms(img)
plot([img, out])
左側が元の画像、右側がtransform
された画像です。
ランダムに切り抜かれているので、実行するたびに右側の画像は変化します。
検出枠(BBOX)がある場合のデータ拡張
BBOXがある場合も、transform
の定義は同じですが、BBOXを入力する必要があります。以下BBOXの設定方法について解説します。
BBOXの設定
バウンディングボックス(BBOX)は、tv_tensors.BoundingBoxes
で定義します。
バウンディングボックス(bbox)のデータは、tv_tensors.BoundingBoxes
で渡すことになります。座標は、画像サイズに合わせたものです。フォーマットは3つから選ぶことができますが、座標はピクセルです。
最初の引数が、BBOXの座標のデータです。この座標データの形式は、formatで指定します。下のプログラムではフォーマットが”XYXY
“となっていますので、座標データは始点と終点の座標となります。
なお、format
には、以下の3つが指定できます。
format
に指定できる文字列XYXY
:始点と終点XYWH
:始点と幅・高さCXCYWH
:中心と幅・高さ
引数canvas_size
は画像のサイズです。
from torchvision import tv_tensors
boxes = tv_tensors.BoundingBoxes(
[
[40, 50, 380, 640],
[80, 550, 512, 768],
[100, 80, 350, 320]
],
format="XYXY", canvas_size=img.shape[-2:])
plot([(img, boxes)])
上のプログラムで指定したBBOXの場所を可視化すると、以下のようになります。
それぞれ、「猫全体」、「本」、「猫の顔」の枠です。
bboxのフォーマットがYOLOとは異なるので、YOLOで使う場合は、変換を入れる必要がある部分に注意が必要です。
YOLOで使うデータ拡張をしたい場合は、albumentationの方が楽かもしれません。
画像とBBOXだけを入力して処理する場合
transforms(img, boxes)
と、画像とBBOXを入力することで、変換が行われます。
transforms = v2.Compose([
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomPhotometricDistort(p=1),
v2.RandomHorizontalFlip(p=1),
])
out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))
plot([(img, boxes), (out_img, out_boxes)])
左が元画像、右が変換後の画像です。リサイズと切り抜きに合わせてBBOXの座標が修正されていることがわかります。
画像と辞書型で入力する場合
ラベルなどのbbox以外の情報もtransform
に渡すことが可能です。
この場合は、辞書型で入力します。例では、bboxとlabelを入力しています。
例では、Composeで行う変換に、SanitizeBoundingBoxes
を追加しています。これを入れることで、変換によりmin_size
よりも小さくなってしまった枠や、はみ出してしまった枠が出力から除外されます。
transforms = v2.Compose([
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomPhotometricDistort(p=1),
v2.RandomHorizontalFlip(p=1),
v2.SanitizeBoundingBoxes(min_size=1),
])
target = {
"boxes": boxes,
"labels": torch.arange(boxes.shape[0]),
}
# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)
plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
out_target
{'boxes': BoundingBoxes([[ 56, 0, 224, 177],
[ 0, 126, 224, 224]], format=BoundingBoxFormat.XYXY, canvas_size=(224, 224)),
'labels': tensor([0, 1])}
実行結果をみると、猫の顔の枠が画像外にはみ出してしまっています。
このため、BBOXが1つ削除されて、返ってきたBBOXは2つになっています。
labels
を見ると[0,1,2]というラベルが付けられていたものが、[0,1]だけとなり、2(=猫の顔)が消えていることがわかります。
このように、領域外にはみ出してしまった枠を削除することも可能です。
枠に合わせてラベルも消えるます。これを使えば枠とクラスIDが削除により矛盾することはありません
おわりに
torchvisionのtransforms.v2について紹介しました。どのフレームワークもどんどん改良が進んで便利になってきています。
高速化に関して、mpsもサポートしたということなのでM1/M2/M3を搭載したMacでも高速に動作することが期待できます。
データ拡張は、学習時のデータ読み込みで実行されるので、高速化されると学習時間が短くなるメリットがあります