機械学習
記事内に商品プロモーションを含む場合があります

Cleanlabを使ってデータセット中のノイズ画像を検出する|ノイズラベル検出の実践ガイド

tadanori

この記事では、Cleanlabというノイズラベルを検出するライブラリを使って、データセット中のノイズラベルを検出する方法について紹介します

Cleanlabとは

Cleanlabとは

ディープラーニングのモデルの性能を引き出すには、学習させるデータの品質も重要です。しかしながら、実際に学習に疲れわれるデータには、ラベルのつけ間違いなどのノイズが含まれており、その影響により精度が低下してしまうことがあります。

このような、ラベル誤りや異常なデータを特定するツールとしてCleanlabがあります。

公式サイト:https://docs.cleanlab.ai/v2.0.0/index.html

Cleanlabは、データセット中のノイズラベルを検出するためのPythonライブラリです。このライブラリを使うことで、データのクレンジング処理を簡単化し、モデルのトレーニングを効率的に行うことができます。

利用手順

Cleanlabの使用手順は以下の通りです。

  1. データセットを準備し、モデルをトレーニングする。
  2. Cleanlabを用いてモデルの予測結果と実際のラベルを比較し、ノイズラベルを特定する。
  3. 特定されたノイズラベルを修正または除去し、クリーンなデータセットを再構築する。

Cleanlabは、機械学習の実務においてデータ品質を維持するための強力なツールです。ここでは、MNISTの手書き文字データを使って実際にノイズ画像の検出をやってみたいと思います。

MNISTの学習

MNISTの学習については過去の記事で解説していますのでそちらを参照してください。

あわせて読みたい
手書き文字(MNIST)認識をCNNでやってみる【初級 深層学習講座】
手書き文字(MNIST)認識をCNNでやってみる【初級 深層学習講座】

ここでは、上記記事の手順で学習後のモデルに対してcleanlabを使ってノイズ画像を検出します。

なお、今回の実験コードはこちらにあります(Google Colabで動作します)

cleanlabによるノイズ画像検出

cleanlabを利用するには、ライブラリのインストールが必要です。以下のコマンドを実行してcleanlabをインストールしてください。

!pip install cleanlab

ノイズ画像の検出

ノイズ画像を検出するために、推論モードでモデルを実行しoutputsと正解ラベルy_gtを保存しておきます。

今回は、訓練データの中のノイズデータを検索したいので、test_dataloaderが参照するデータセットにtrain_datasetを設定しています。

下記のコードを実行することで、全ての訓練画像のy_outputsに予測結果が、y_gtに正解ラベルが格納されます。

outputssoftmaxをかけて0~1のデータにしています。これは、cleanlabが0~1範囲の入力を期待しているためです。

test_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=False)


model.eval()
y_pred = []
y_gt = []
y_outputs = None
for batch in tqdm(test_loader):
    inputs, targets = batch
    with torch.no_grad() :
      outputs = model(inputs.to(device))
    y_gt += targets.tolist()
    y_pred += outputs.argmax(axis=1).tolist()
    if y_outputs is None:
      y_outputs = outputs.detach().cpu().softmax(dim=1).numpy()
    else:
      y_outputs = np.vstack((y_outputs, outputs.detach().cpu().softmax(dim=1).numpy()))

次に、cleanlabfind_label_issues関数を実行します。

引数は、正解ラベル、予測結果、filter_by=xxxでモードです。モードはいくつかありますが、今回は「不正解のデータ(predicted_neq_given)」と、「prune_by_noise_rate」の2つのモードを利用してみました。

他にもオプションが多数あります。詳しくは公式サイトを確認してください。

import cleanlab
from cleanlab.filter import find_label_issues

# valid = find_label_issues(y_gt, y_outputs, filter_by="confident_learning")

print(f"正解と異なる推論データは{sum([1 if x != y else 0 for x, y in zip(y_gt, y_pred)])}個です")
valid = find_label_issues(y_gt, y_outputs, filter_by="predicted_neq_given")
print(f"predicted_neq_givenモードでノイズ画像を{sum(valid)}個発見しました")

valid = find_label_issues(y_gt, y_outputs, filter_by="prune_by_noise_rate")
print(f"prune_by_noise_rateモードでノイズ画像を{sum(valid)}個発見しました")

実行すると、正解と推論結果が異なるものが215個あることがわかります。predicted_neq_givenモードでは、異なるもの全てが検出されます。

prune_by_noise_rateでは、さらに少ない3つがノイズ画像として検出されました。

正解と異なる推論データは215個です
predicted_neq_givenモードでノイズ画像を215個発見しました
prune_by_noise_rateモードでノイズ画像を3個発見しました

検出したノイズ画像を表示する

実際に、ノイズ画像と判定された画像を表示してみます。以下は、表示させるコードです。

for idx, (gt, flg) in enumerate(zip(y_gt, valid)):
  if flg == True :
    img, label = train_dataset[idx]
    plt.imshow(img.squeeze(), cmap="gray")
    plt.title(f"gt={gt}")
    plt.show()

結果は以下になります。

1つ目はにしか見えないですね。

ノイズ画像1

2つ目はに見えます

ノイズ画像2

3つめはに見えます

ノイズ画像3

検出結果を見ると、概ねノイズデータと考えても良さそうなデータが選択されています。

実際には、検出したデータを削除して再度学習を回していくことで精度を上げていきいます。

まとめ

データセットの中からノイズデータを検出するcleanlabの使い方について解説しました。比較的手軽に利用できるので、使ってみるのも良いかもしれません。


おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました