機械学習
記事内に商品プロモーションを含む場合があります

SegFormerでの学習・推論方法を解説 | セグメンテーション向けのtransformerを使う

tadanori

SefFormerはセグメンテーションタスク向けのtransformerです。この記事では、SegFormerで独自データを学習する方法、および、推論手順について解説しています。記事中のコードをGoogle Colabで実行させながら動作確認を行うことができます。

YOLOv8でセグメンテーションを行う記事も参考にしてください

あわせて読みたい
YOLOv8でセグメンテーション|学習と推論を実践
YOLOv8でセグメンテーション|学習と推論を実践

SegFormerとは

SegFormerは、セグメンテーションのタスクを行うためのモデルです。セグメンテーションは、画像の各ピクセルにクラスを割り当てるもので、画像中に存在する人やオブジェクトを画素毎にクラス分けするタスクになります。

SegFormerは、このタスクにTransformerを応用したもので、既存のモデルと比較して計算コストを抑えて制度を向上したモデルになります。

論文(arXiv):https://arxiv.org/abs/2105.15203
GitHub: https://github.com/NVlabs/SegFormer

今回は、SegFormerを使って学習・推論させてみます。

SegFormerで学習させてみる

コードはGoogle Colabでテストできます。各コードをコードセルに貼り付けながら実行していってください。

SegfFrmerのインストール

SegFormerはHugging Faceで提供されていますので、今回はそれを利用します。利用するには以下のパッケージのインストールが必要です。

!pip install transformers
!pip install datasets
!pip install evaluate
!pip install git+https://github.com/huggingface/accelerate
重要

重要:インストールが終わったら、Google Colabの「ランタイム」→「ランタイムを再起動」を行ってリセットしてください。

再起動しないと、エラーが発生します。

データセットの読み込み

データセットは、HuggingFaceで準備されているデータセットを利用します。

split="train[:xx]"のようにオプションを追加すると、データの一部だけをロードできます。今回は、チュートリアルに習って50個だけを読み込むことにしました。

ちなみに、あとで書きますが50個のデータではちゃんと学習できません。すべてのデータを利用すると、Google ColabのT4(GPU)で35時間と表示されました。今回は動作確認のために小さなデータで試しています。

ダウンロードはすべて行われているようです。

from datasets import load_dataset
# データセットをロード。高速化のために一部だけを利用
ds = load_dataset("scene_parse_150", split="train[:50]")
ds

データをtrainとtestデータに分割します

ds = ds.train_test_split(test_size=0.2)
train_ds = ds["train"]
test_ds = ds["test"]

train_dsの中身を確認します。Colabでは以下のコードでデータを確認できます。

train_ds[0]["image"]
訓練画像

アノテーションデータも確認しておきます。アノテーションデータは、以下のようにすれば画像としてみることができます(普通に画像として表示させると真っ黒になります)。

import matplotlib.pyplot as plt
plt.imshow(train_ds[0]["annotation"])
アノテーションデータ

アノテーションデータには、150種類のクラスがありますが、これを辞書型にいれておきます。ここで作成した辞書id2label, label2idは、モデルのオプションに指定していますが、これは行わなくても良いです。

# 150種類のラベルのラベル名とIDの対応
import json
from huggingface_hub import cached_download, hf_hub_url

repo_id = "huggingface/label-files"
filename = "ade20k-id2label.json"
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename, repo_type="dataset")), "r"))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)

プリプロセス(データ拡張)

次に画像の事前処理を行うimage_processorを定義します。

データセットのクラスが1始まりなので、0始まりに直す必要がありますが、これはreduce_labels=Trueで行うことができます。checkpointは利用するモデルと合わせます。今回はnvidia/mit-b0を利用しますので、これを設定しています。

from transformers import AutoImageProcessor

checkpoint = "nvidia/mit-b0"
image_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)

学習時にデータ拡張として、画像にジッターを付加します。データ拡張を行うことで学習効果を高めることができます。まず、ジッターを付加する関数を定義します。

from torchvision.transforms import ColorJitter

jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

次に、データを変換する関数を定義します。処理的には、入力された画像とアノテーションデータに対して、image_processorを通す諸rになります。trainはジッターを付加しますが、testはジッターを付加しません。

def train_transforms(example_batch):
    images = [jitter(x) for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor(images, labels)
    return inputs


def val_transforms(example_batch):
    images = [x for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor(images, labels)
    return inputs

これをデータセットの変換関数として設定してデータセットの定義は終了です。

train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)

学習(トレーニング)

メトリックの定義

評価メトリックを定義します。これも用意されているものを利用します。

import evaluate

metric = evaluate.load("mean_iou")

定義した、compute_metrics関数はTrainerの引数で与えます。

import numpy as np
import torch
from torch import nn

def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        logits_tensor = nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=num_labels,
            ignore_index=255,
            reduce_labels=False,
        )
        for key, value in metrics.items():
            if type(value) is np.ndarray:
                metrics[key] = value.tolist()
        return metrics

モデルの読み込み

モデルの読み込みは以下のようになります。

from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer

model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)

学習

まず、TrainingArgumentsで学習のパラメータを設定します。出力が大きいのでバッチサイズは小さくしないとメモリが不足しますので注意してください。

最後にTrainertrain_dstest_dsと、compute_metricsを指定し、trainer.train()で学習を進めます。

training_args = TrainingArguments(
    output_dir="./", 
    learning_rate=6e-5,
    num_train_epochs=50,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)

trainer.train()

データセットのサイズを大きくすると、メモリでエラーが発生します。バッチサイズ等はうまく設定しないと動かすことができなさそうです。

実行すると以下のようなログが出力されます。最初の6個の数値はStep, Training Loss, Mean Loss, Mean Accuracy, Overall Accuracyと全体に対しての結果です。次の2つのブロックは、Per Category IouとPerCategory Accuracyです。今回は50枚だけを学習させたため、含まれないクラスが存在し、それらの結果はnanとなっています。

下記は、終了時の結果ですが、カテゴリーのIOUの精度も、クラス分類もまだまだ学習できていないことがわかります。学習させるには、もっとデータを増やすかEPOCH回数を増やす必要がありそうです。

ログ

1000枚くらいまで増やして学習させてみましたが、まだまだ学習できていませんでした。

学習結果の保存

学習したモデルの保存します。./trained-modelという名前で保存されます。

trainer.save_model("./trained-model")

学習結果の利用(推論コード)

ログを見る限り学習できてなさそうですが、予測を行ってみます。まず、画像を読み込みます。

ds = load_dataset("scene_parse_150", split="train[:50]")
image = ds[0]["image"]
image

次に学習したモデルとimage_processorを読み込みます。image_processorは学習で使ったものと同じものを指定します。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
model = AutoModelForSemanticSegmentation.from_pretrained("./trained-model", id2label=id2label, label2id=label2id)
model = model.to(device)
checkpoint = "nvidia/mit-b0"
image_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)

画像を処理してモデルで予測させます

encoding = image_processor(image, return_tensors="pt")
pixel_values = encoding.pixel_values.to(device)
outputs = model(pixel_values=pixel_values)
logits = outputs.logits.cpu()

画像として表示させる

結果を画像として表示させます

import matplotlib.pyplot as plt
import numpy as np

def ade_palette():
    """ADE20K palette that maps each class to RGB values."""
    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
            [102, 255, 0], [92, 0, 255]]
upsampled_logits = nn.functional.interpolate(
    logits,
    size=image.size[::-1],
    mode="bilinear",
    align_corners=False,
)

pred_seg = upsampled_logits.argmax(dim=1)[0]


color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)
palette = np.array(ade_palette())
for label, color in enumerate(palette):
    color_seg[pred_seg == label, :] = color
color_seg = color_seg[..., ::-1]  # convert to BGR

img = np.array(image) * 0.5 + color_seg * 0.5  # plot the image with the segmentation map
img = img.astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()

結果を見ると、なんとなくセグメンテーションを行おうとしていることはわかりますが、まだまだ学習していないことも確認できます。

その他情報

SegFormerはkaggleのコンテスト「HuBMAP + HPA – Hacking the Human Body」で実際に利用しました。EfficientNetを利用したセグメンテーションモデルと比較しても高い性能でした。

実際に利用した学習コードは以下にあります

GitHub : https://github.com/aruaru0/SegFormer-test-codes

学習した結果を使った、推論コードは以下を参考にしてください。

Kaggleノートブック

まとめ

セグメンテーションモデルSegFormerの学習・推論について解説しました。今回は、サブセットで学習させたため、ちゃんと学習できていないですが、手順は説明できたのではないかと思います。

ゼロショットのセグメンテーションを実現するSAMについては以下を参照してください

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました