SegFormerによる学習と推論:Transformerモデルでセグメンテーション
SegFormerは、セグメンテーション(Segmentation)タスクに特化したTransformerモデルです。この記事では、SegFormerを用いて独自のカスタムデータを使った学習と推論の手順を詳しく解説します。記事内のコードはGoogle Colabで実行できますので、Colab上で実際に動かして結果を確認することができます。
SegFormerとは
SegFormerは、セグメンテーション(Segmentation)のタスクを行うためのモデルです。セグメンテーションは、画像の各ピクセルにクラスを割り当てるもので、画像中に存在する人やオブジェクトを画素毎にクラス分けするタスクになります。
SegFormerは、このタスクにTransformerを応用したもので、既存のモデルと比較して計算コストを抑えて制度を向上したモデルになります。
論文(arXiv):https://arxiv.org/abs/2105.15203
GitHub: https://github.com/NVlabs/SegFormer
今回は、SegFormerを使って学習・推論させてみます。
SegFormerで学習させてみる
コードはGoogle Colabでテストできます。コードをコードセルに貼り付けながら実行してください
SegfFrmerのインストール
SegFormerはHugging Faceで提供されていますので、今回はそれを利用します。利用するには以下のパッケージのインストールが必要です。
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install git+https://github.com/huggingface/accelerate
重要:インストールが終わったら、Google Colabの「ランタイム」→「ランタイムを再起動」を行ってリセットしてください。
再起動しないと、エラーが発生します。
データセットの読み込み
データセットは、HuggingFaceで準備されているデータセットを利用します。
split="train[:xx]"
のようにオプションを追加すると、データの一部だけをロードできます。今回は、チュートリアルに習って50個だけを読み込むことにしました。
ちなみに、あとで書きますが50個のデータではちゃんと学習できません。すべてのデータを利用すると、Google ColabのT4(GPU)で35時間と表示されました。今回は動作確認のために小さなデータで試しています。
ダウンロードはすべて行われているようです。
from datasets import load_dataset
# データセットをロード。高速化のために一部だけを利用
ds = load_dataset("scene_parse_150", split="train[:50]")
ds
データをtrainとtestデータに分割します
ds = ds.train_test_split(test_size=0.2)
train_ds = ds["train"]
test_ds = ds["test"]
train_dsの中身を確認します。Colabでは以下のコードでデータを確認できます。
train_ds[0]["image"]
アノテーションデータも確認しておきます。アノテーションデータは、以下のようにすれば画像としてみることができます(普通に画像として表示させると真っ黒になります)。
import matplotlib.pyplot as plt
plt.imshow(train_ds[0]["annotation"])
アノテーションデータには、150種類のクラスがありますが、これを辞書型にいれておきます。ここで作成した辞書id2label
, label2id
は、モデルのオプションに指定していますが、これは行わなくても良いです。
# 150種類のラベルのラベル名とIDの対応
import json
from huggingface_hub import cached_download, hf_hub_url
repo_id = "huggingface/label-files"
filename = "ade20k-id2label.json"
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename, repo_type="dataset")), "r"))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)
プリプロセス(データ拡張)
次に画像の事前処理を行うimage_processor
を定義します。
データセットのクラスが1始まりなので、0始まりに直す必要がありますが、これはreduce_labels=True
で行うことができます。checkpoint
は利用するモデルと合わせます。今回はnvidia/mit-b0
を利用しますので、これを設定しています。
from transformers import AutoImageProcessor
checkpoint = "nvidia/mit-b0"
image_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)
学習時にデータ拡張として、画像にジッターを付加します。データ拡張を行うことで学習効果を高めることができます。まず、ジッターを付加する関数を定義します。
from torchvision.transforms import ColorJitter
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
次に、データを変換する関数を定義します。処理的には、入力された画像とアノテーションデータに対して、image_processorを通す諸rになります。trainはジッターを付加しますが、testはジッターを付加しません。
def train_transforms(example_batch):
images = [jitter(x) for x in example_batch["image"]]
labels = [x for x in example_batch["annotation"]]
inputs = image_processor(images, labels)
return inputs
def val_transforms(example_batch):
images = [x for x in example_batch["image"]]
labels = [x for x in example_batch["annotation"]]
inputs = image_processor(images, labels)
return inputs
これをデータセットの変換関数として設定してデータセットの定義は終了です。
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)
学習(トレーニング)
メトリックの定義
評価メトリックを定義します。これも用意されているものを利用します。
import evaluate
metric = evaluate.load("mean_iou")
定義した、compute_metrics
関数はTrainerの引数で与えます。
import numpy as np
import torch
from torch import nn
def compute_metrics(eval_pred):
with torch.no_grad():
logits, labels = eval_pred
logits_tensor = torch.from_numpy(logits)
logits_tensor = nn.functional.interpolate(
logits_tensor,
size=labels.shape[-2:],
mode="bilinear",
align_corners=False,
).argmax(dim=1)
pred_labels = logits_tensor.detach().cpu().numpy()
metrics = metric.compute(
predictions=pred_labels,
references=labels,
num_labels=num_labels,
ignore_index=255,
reduce_labels=False,
)
for key, value in metrics.items():
if type(value) is np.ndarray:
metrics[key] = value.tolist()
return metrics
モデルの読み込み
モデルの読み込みは以下のようになります。
from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer
model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)
学習
まず、TrainingArgumentsで学習のパラメータを設定します。出力が大きいのでバッチサイズは小さくしないとメモリが不足しますので注意してください。
最後にTrainer
にtrain_ds
、test_ds
と、compute_metrics
を指定し、trainer.train()
で学習を進めます。
training_args = TrainingArguments(
output_dir="./",
learning_rate=6e-5,
num_train_epochs=50,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
save_total_limit=3,
evaluation_strategy="steps",
save_strategy="steps",
save_steps=20,
eval_steps=20,
logging_steps=1,
eval_accumulation_steps=5,
remove_unused_columns=False,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=test_ds,
compute_metrics=compute_metrics,
)
trainer.train()
データセットのサイズを大きくすると、メモリでエラーが発生します。バッチサイズ等はうまく設定しないと動かすことができなさそうです。
実行すると以下のようなログが出力されます。
最初の6個の数値はStep, Training Loss, Mean Loss, Mean Accuracy, Overall Accuracyと全体に対しての結果です。次の2つのブロックは、Per Category IouとPerCategory Accuracyです。今回は50枚だけを学習させたため、含まれないクラスが存在し、それらの結果はnan
となっています。
下記は、終了時の結果ですが、カテゴリーのIOUの精度も、クラス分類もまだまだ学習できていないことがわかります。学習させるには、もっとデータを増やすかEPOCH回数を増やす必要がありそうです。
1000枚くらいまで増やして学習させてみましたが、まだ学習できていませんでした
後にも書いていますがkaggleのコンペで利用したときはちゃんと学習できていたので、データセットかEPOCH数の問題だと思われます。
学習結果の保存
学習したモデルの保存します。./trained-model
という名前で保存されます。
trainer.save_model("./trained-model")
学習結果の利用(推論コード)
ログを見る限り学習できてなさそうですが、予測を行ってみます。まず、画像を読み込みます。
ds = load_dataset("scene_parse_150", split="train[:50]")
image = ds[0]["image"]
image
次に学習したモデルとimage_processor
を読み込みます。image_processor
は学習で使ったものと同じものを指定します。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSemanticSegmentation.from_pretrained("./trained-model", id2label=id2label, label2id=label2id)
model = model.to(device)
checkpoint = "nvidia/mit-b0"
image_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)
画像を処理してモデルで予測させます
encoding = image_processor(image, return_tensors="pt")
pixel_values = encoding.pixel_values.to(device)
outputs = model(pixel_values=pixel_values)
logits = outputs.logits.cpu()
画像として表示させる
結果を画像として表示させます
import matplotlib.pyplot as plt
import numpy as np
def ade_palette():
"""ADE20K palette that maps each class to RGB values."""
return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
[102, 255, 0], [92, 0, 255]]
upsampled_logits = nn.functional.interpolate(
logits,
size=image.size[::-1],
mode="bilinear",
align_corners=False,
)
pred_seg = upsampled_logits.argmax(dim=1)[0]
color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)
palette = np.array(ade_palette())
for label, color in enumerate(palette):
color_seg[pred_seg == label, :] = color
color_seg = color_seg[..., ::-1] # convert to BGR
img = np.array(image) * 0.5 + color_seg * 0.5 # plot the image with the segmentation map
img = img.astype(np.uint8)
plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()
結果を見ると、なんとなくセグメンテーションを行おうとしていることはわかりますが、まだまだ学習できていないことが確認できます。もう少しデータを増やすか、Epoch数を増やす必要がありそうです。
その他情報
SegFormerはkaggleのコンテスト「HuBMAP + HPA – Hacking the Human Body」で実際に利用しました。EfficientNetを利用したセグメンテーションモデルと比較しても高い性能でした。
実際に利用した学習コードは以下にあります
GitHub : https://github.com/aruaru0/SegFormer-test-codes
学習した結果を使った、推論コードは以下を参考にしてください。
Kaggleノートブック
まとめ
セグメンテーションモデルSegFormerの学習・推論について解説しました。今回は、サブセットで学習させたため、ちゃんと学習できていないですが、手順は説明できたのではないかと思います。