機械学習
記事内に商品プロモーションを含む場合があります

物体検出をサポートしたデータ拡張の使い方【torchvision.transofrms.v2】

tadanori

torchvisionのtransforms.v2では、データ拡張(データオーグメンテーション)に物体検出の枠(bounding box)やセグメンテーションマスクをサポートしています。ここでは、transforms.v2について解説します。

torchvision.transforms.v2とは

2023年10月5日に公開されたTorchVision 0.16が公開され、transforms.v2のドキュメントなども充実しました。まだベータ版ですが、こちらのバージョンが主流になるようですので、新しく学習コードを書く場合は、こちらを使うようにした方が良いかもしれません。

v2の特徴

リリースノートを見ると以下のようになっています。

  • 高速化(10~40%)、mps対応
  • CutMixとMixUpのサポート
  • セグメンテーションマスク、物体検出のBBox、動画のサポート

個人的には、物体検知のバウンディングボックス(BBOX)のサポートと、セグメンテーションマスクのサポートの追加が嬉しいです。

この追加により、YOLOなどの物体検出タスクのデータオーグメンテーション(データ拡張)でも、torchvisionが使えるようになりました。

V2のリファレンス:TRANSFORMING AND AUGMENTING IMAGES

torchvision.transformsの切り替え

これまで、torchvision.transformsを使っていた場合は、transformsの後ろに.v2 をつけ加えるだけでOKです。

もし、以下のように宣言して使っていた場合は、変更はインポートだけですみます。

from torchvision import transforms

上記のように書いていたのを、以下のように直すだけです。

import torchvision.transforms.v2 as transforms

今回は移行ではないので、以下のように宣言します

from torchvision.transforms import v2

v2と書くだけの方が楽です。

物体検出の場合の使い方

この記事では、物体検出タスクでの使い方を中心に説明したいと思います。

ヘルパー関数

公式のgithubにあるヘルパー関数をそのまま流用します。

この関数は、画像とバウンディングボックスやマスクをリスト形式で受け取って、それを並べて表示するものです。

torchvisionのutilsが使われています。utilsの使い方については、こちらの記事を参照してください。

PyTorchの可視化機能を紹介|物体検出枠とセグメンテーション結果を可視化
PyTorchの可視化機能を紹介|物体検出枠とセグメンテーション結果を可視化
# from https://github.com/pytorch/vision/blob/main/gallery/transforms/helpers.py
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


def plot(imgs, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            boxes = None
            masks = None
            if isinstance(img, tuple):
                img, target = img
                if isinstance(target, dict):
                    boxes = target.get("boxes")
                    masks = target.get("masks")
                elif isinstance(target, tv_tensors.BoundingBoxes):
                    boxes = target
                else:
                    raise ValueError(f"Unexpected target type: {type(target)}")
            img = F.to_image(img)
            if img.dtype.is_floating_point and img.min() < 0:
                # Poor man's re-normalization for the colors to be OK-ish. This
                # is useful for images coming out of Normalize()
                img -= img.min()
                img /= img.max()

            img = F.to_dtype(img, torch.uint8, scale=True)
            if boxes is not None:
                img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
            if masks is not None:
                img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)

            ax = axs[row_idx, col_idx]
            ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

パッケージのインポート

必要なパッケージをインポートします。今回試すv2以外に、画像読み込みの関数もインポートしておきます。

import torch
import matplotlib.pyplot as plt
from torchvision.transforms import v2
from torchvision.io import read_image, ImageReadMode

画像の読み込み

今回利用したcat2.pngは、ここからダウンロードできます。Jupyter notebookやGoogle Colabで試す場合は、以下のコードを実行することでダウンロードされます。

!wget -O cat2.png https://github.com/aruaru0/SAM-TEST/blob/main/cat2.png?raw=true 

画像が準備できたら、以下のコードで読み込みます。

img = read_image('cat2.png', ImageReadMode.RGB)

検出枠なしのデータ拡張

クラス分類などで利用する場合は、使い方は基本的にこれまでのtorchvision.transformsと同じで、Composeを使って変換を列挙していきます。

変換は、Composeを使って生成したオブジェクト(transform)に画像を渡すことで列挙した処理が順番に実行され、結果が戻ります。

以下のコードでは、ランダムに、「リサイズ」、「切り抜き」、「フリップ」が行われ、画素値が標準化されて出力されます。

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
out = transforms(img)

plot([img, out])

左側が元の画像、右側がtransformされた画像です。

ランダムに切り抜かれているので、実行するたびに右側の画像は変化します。

実行例

検出枠(BBOX)がある場合のデータ拡張

BBOXがある場合も、transformの定義は同じです。

入力画像のBBOXの設定方法は以下の通りです。

BBOXの設定

バウンディングボックス(BBOX)は、tv_tensors.BoundingBoxesで定義します。

バウンディングボックスのデータフォーマットについて

バウンディングボックス(bbox)のデータは、tv_tensors.BoundingBoxesで渡すことになります。座標は、画像サイズに合わせたものです。フォーマットは3つから選ぶことができますが、座標はピクセルです。

最初の引数が、BBOXの座標のデータです。この座標データの形式は、formatで指定します。下のプログラムではフォーマットが”XYXY“となっていますので、座標データは始点と終点の座標となります。

なお、formatには、以下の3つが指定できます。

formatに指定できる文字列
  • XYXY :始点と終点
  • XYWH:始点と幅・高さ
  • CXCYWH:中心と幅・高さ

引数canvas_sizeは画像のサイズです。

from torchvision import tv_tensors 

boxes = tv_tensors.BoundingBoxes(
    [
        [40, 50, 380, 640],
        [80, 550, 512, 768],
        [100, 80, 350, 320]
    ],
    format="XYXY", canvas_size=img.shape[-2:])

plot([(img, boxes)])

上のプログラムで指定したBBOXの場所を可視化すると、以下のようになります。

それぞれ、「猫全体」、「本」、「猫の顔」の枠です。

BBOXの位置

bboxのフォーマットがYOLOとは異なるので、YOLOで使う場合は、変換を入れる必要がある部分に注意が必要です。

YOLOで使うデータ拡張をしたい場合は、albumentationの方が楽かもしれません。

物体検出でも使えるAlbumentationsの使い方(画像データ拡張, Data Augmentation)
物体検出でも使えるAlbumentationsの使い方(画像データ拡張, Data Augmentation)

画像とbboxだけを入力して処理する場合

transforms(img, boxes)と、画像とbboxを入力することで、変換が行われます。

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
])
out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))

plot([(img, boxes), (out_img, out_boxes)])

左が元画像、右が変換後の画像です。リサイズと切り抜きに合わせてBBOXの座標が修正されていることがわかります。

枠付きのオーグメンテーションの例

画像と辞書型で入力する場合

ラベルなどのbbox以外の情報もtransformに渡すことが可能です。

この場合は、辞書型で入力します。例では、bboxlabelを入力しています。

例では、Composeで行う変換に、SanitizeBoundingBoxesを追加しています。これを入れることで、変換によりmin_sizeよりも小さくなってしまった枠や、はみ出してしまった枠が出力から除外されます。

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
    v2.SanitizeBoundingBoxes(min_size=1),
])

target = {
    "boxes": boxes,
    "labels": torch.arange(boxes.shape[0]),
}

# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)

plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
out_target

下記の実行結果では、猫の顔の枠が画像外にはみ出してしまっています。

このため、BBOXが1つ削除されて、返ってきたBBOXは2つになっています。

labelsを見ると[0,1,2]というラベルが付けられていたものが、[0,1]だけとなり、2(=猫の顔)が消えていることがわかります。

このように、領域外にはみ出してしまった枠を削除することも可能です。

枠に合わせてラベルも消えるます。これを使えば枠とクラスIDが削除により矛盾することはありません

出力結果
{'boxes': BoundingBoxes([[ 56,   0, 224, 177],
                [  0, 126, 224, 224]], format=BoundingBoxFormat.XYXY, canvas_size=(224, 224)),
 'labels': tensor([0, 1])}

おわりに

torchvisionのtransforms.v2について紹介しました。どのフレームワークもどんどん改良が進んで便利になってきています。

高速化に関して、mpsもサポートしたということなのでM1/M2/M3を搭載したMacでも高速に動作することが期待できます。

データ拡張は、学習時のデータ読み込みで実行されるので、高速化されると学習時間が短くなるメリットがあります

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました