機械学習
記事内に商品プロモーションを含む場合があります

検出枠に対応!torchvisionのデータ拡張(v2)の使い方を解説

Aru

torchvisionのtransforms.v2は、データ拡張(データオーグメンテーション)に物体検出に必要な検出枠(bounding box)やセグメンテーションマスク(mask)のサポートが追加されています。この記事では、transforms.v2物体検出タスクで利用する方法について詳しく解説します。

torchvision.transforms.v2とは

2023年10月5日にTorchVision 0.16が公開され、transforms.v2のドキュメントも充実してきました。現在はまだベータ版ですが、今後主流となる可能性が高いため、新しく学習コードを書く際にはこのバージョンを使用した方がよいかもしれません。特に、物体検出タスクに対応したデータ拡張機能が大きく、こちらを利用する機会も増えそうです。

v2の特徴

リリースノートを見ると以下のようになっています。

  • 高速化(10~40%)、mps対応
  • CutMixとMixUpのサポート
  • セグメンテーションマスク、物体検出のBBox、動画のサポート

個人的には、物体検知のバウンディングボックス(BBOX)のサポートと、セグメンテーションマスクのサポートの追加が嬉しいです。

この追加により、YOLOなどの物体検出タスクのデータオーグメンテーション(データ拡張)でも、torchvisionが使えるようになりました

V2のリファレンス:TRANSFORMING AND AUGMENTING IMAGES

torchvision.transformsから移行する場合

これまで、torchvision.transformsを使っていたコードをv2に修正する場合は、transformsの後ろに.v2 をつけ加えるだけでOKです。

仮に、以下のように宣言して使っていた場合は、変更はインポートだけですみます。

from torchvision import transforms

上記のインポートを以下のように修正します。これでv2に切り替わります。

import torchvision.transforms.v2 as transforms

今回は移行ではないので、以下のように宣言します

from torchvision.transforms import v2

v2と書くだけの方が楽です。

物体検出タスクでの利用方法

この記事では、物体検出タスクでの使い方を中心に説明したいと思います。

ヘルパー関数

この記事では実際いtransforms.v2を動かして動作確認していきますが、結果の表示にはヘルパー関数が必要です。ここでは、公式のgithubにあるヘルパー関数をそのまま流用します。

この関数は、画像とバウンディングボックスやマスクをリスト形式で受け取って、それを並べて表示するものです。

torchvisionのutilsが使われています。utilsの使い方については、こちらの記事を参照してください。

PyTorchの可視化機能を紹介|物体検出枠とセグメンテーション結果を可視化
PyTorchの可視化機能を紹介|物体検出枠とセグメンテーション結果を可視化
# from https://github.com/pytorch/vision/blob/main/gallery/transforms/helpers.py
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


def plot(imgs, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            boxes = None
            masks = None
            if isinstance(img, tuple):
                img, target = img
                if isinstance(target, dict):
                    boxes = target.get("boxes")
                    masks = target.get("masks")
                elif isinstance(target, tv_tensors.BoundingBoxes):
                    boxes = target
                else:
                    raise ValueError(f"Unexpected target type: {type(target)}")
            img = F.to_image(img)
            if img.dtype.is_floating_point and img.min() < 0:
                # Poor man's re-normalization for the colors to be OK-ish. This
                # is useful for images coming out of Normalize()
                img -= img.min()
                img /= img.max()

            img = F.to_dtype(img, torch.uint8, scale=True)
            if boxes is not None:
                img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
            if masks is not None:
                img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)

            ax = axs[row_idx, col_idx]
            ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

パッケージのインポート

必要なパッケージをインポートします。今回試すv2以外に、画像読み込みの関数もインポートしておきます。

import torch
import matplotlib.pyplot as plt
from torchvision.transforms import v2
from torchvision.io import read_image, ImageReadMode

画像の読み込み

今回利用したcat2.pngは、ここからダウンロードできます。Jupyter notebookやGoogle Colabで試す場合は、以下のコードを実行することでダウンロードされます。

!wget -O cat2.png https://github.com/aruaru0/SAM-TEST/blob/main/cat2.png?raw=true 

画像が準備できたら、以下のコードで読み込みます。

img = read_image('cat2.png', ImageReadMode.RGB)

検出枠が無いデータ拡張(クラス分類などの場合)

クラス分類などで利用する場合は、使い方はこれまでのtorchvision.transformsと同じです。具体的には、Composeを使って変換を列挙していきます。

変換は、Composeを使って生成したオブジェクト(transform)に画像を渡すことで列挙した処理が順番に実行され、結果が戻ります。

以下のコードでは、ランダムに、「リサイズ」、「切り抜き」、「フリップ」が行われ、画素値が標準化されて出力されます。

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
out = transforms(img)

plot([img, out])

左側が元の画像、右側がtransformされた画像です。

ランダムに切り抜かれているので、実行するたびに右側の画像は変化します。

実行例

検出枠(BBOX)がある場合のデータ拡張

BBOXがある場合も、transformの定義は同じですが、BBOXを入力する必要があります。以下BBOXの設定方法について解説します。

BBOXの設定

バウンディングボックス(BBOX)は、tv_tensors.BoundingBoxesで定義します。

バウンディングボックスのデータフォーマットについて

バウンディングボックス(bbox)のデータは、tv_tensors.BoundingBoxesで渡すことになります。座標は、画像サイズに合わせたものです。フォーマットは3つから選ぶことができますが、座標はピクセルです。

最初の引数が、BBOXの座標のデータです。この座標データの形式は、formatで指定します。下のプログラムではフォーマットが”XYXY“となっていますので、座標データは始点と終点の座標となります。

なお、formatには、以下の3つが指定できます。

formatに指定できる文字列
  • XYXY :始点と終点
  • XYWH:始点と幅・高さ
  • CXCYWH:中心と幅・高さ

引数canvas_sizeは画像のサイズです。

from torchvision import tv_tensors 

boxes = tv_tensors.BoundingBoxes(
    [
        [40, 50, 380, 640],
        [80, 550, 512, 768],
        [100, 80, 350, 320]
    ],
    format="XYXY", canvas_size=img.shape[-2:])

plot([(img, boxes)])

上のプログラムで指定したBBOXの場所を可視化すると、以下のようになります。

それぞれ、「猫全体」、「本」、「猫の顔」の枠です。

BBOXの位置

bboxのフォーマットがYOLOとは異なるので、YOLOで使う場合は、変換を入れる必要がある部分に注意が必要です。

YOLOで使うデータ拡張をしたい場合は、albumentationの方が楽かもしれません。

Albumentations:物体検出(枠)にも対応したデータ拡張ライブラリを解説
Albumentations:物体検出(枠)にも対応したデータ拡張ライブラリを解説

画像とBBOXだけを入力して処理する場合

transforms(img, boxes)と、画像とBBOXを入力することで、変換が行われます。

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
])
out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))

plot([(img, boxes), (out_img, out_boxes)])

左が元画像、右が変換後の画像です。リサイズと切り抜きに合わせてBBOXの座標が修正されていることがわかります。

枠付きのオーグメンテーションの例

画像と辞書型で入力する場合

ラベルなどのbbox以外の情報もtransformに渡すことが可能です。

この場合は、辞書型で入力します。例では、bboxlabelを入力しています。

例では、Composeで行う変換に、SanitizeBoundingBoxesを追加しています。これを入れることで、変換によりmin_sizeよりも小さくなってしまった枠や、はみ出してしまった枠が出力から除外されます。

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
    v2.SanitizeBoundingBoxes(min_size=1),
])

target = {
    "boxes": boxes,
    "labels": torch.arange(boxes.shape[0]),
}

# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)

plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
out_target
出力結果
{'boxes': BoundingBoxes([[ 56,   0, 224, 177],
                [  0, 126, 224, 224]], format=BoundingBoxFormat.XYXY, canvas_size=(224, 224)),
 'labels': tensor([0, 1])}

実行結果をみると、猫の顔の枠が画像外にはみ出してしまっています

このため、BBOXが1つ削除されて、返ってきたBBOXは2つになっています。

labelsを見ると[0,1,2]というラベルが付けられていたものが、[0,1]だけとなり、2(=猫の顔)が消えていることがわかります。

このように、領域外にはみ出してしまった枠を削除することも可能です。

枠に合わせてラベルも消えるます。これを使えば枠とクラスIDが削除により矛盾することはありません

おわりに

torchvisionのtransforms.v2について紹介しました。どのフレームワークもどんどん改良が進んで便利になってきています。

高速化に関して、mpsもサポートしたということなのでM1/M2/M3を搭載したMacでも高速に動作することが期待できます。

データ拡張は、学習時のデータ読み込みで実行されるので、高速化されると学習時間が短くなるメリットがあります

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました